"""Expert Adapter Router and Mixture-of-Experts (MoE) for QLoRA. Hooks into a base model to dynamically route hidden states through multiple extracted expert adapters using fuzzy text-based domain priors or dynamic learnable token-level gates. """ from __future__ import annotations import re from typing import Callable, List, Optional, Tuple import torch import torch.nn as nn from parasitic_qlora import ExpertAdapter class LearnableGate(nn.Module): # type: ignore[misc, unused-ignore] """A lightweight learnable MLP gating network for routing.""" def __init__(self, in_features: int, num_adapters: int) -> None: super().__init__() self.in_features = in_features self.gate = nn.Sequential( nn.Linear(in_features, in_features // 2), nn.ReLU(), nn.Linear(in_features // 2, num_adapters), ) def forward(self, x: torch.Tensor) -> torch.Tensor: # Input: [..., in_features] or other dimension to be pooled # Output: [..., num_adapters] (unnormalized scores) if x.shape[-1] != self.in_features: import torch.nn.functional as F orig_shape = x.shape flat_x = x.view(-1, orig_shape[-1]).unsqueeze(1) pooled_x = F.adaptive_avg_pool1d(flat_x, self.in_features) x = pooled_x.squeeze(1).view(*orig_shape[:-1], self.in_features) return self.gate(x) # type: ignore[no-any-return, unused-ignore] class ExpertAdapterRouter: """Manages dynamic MoE-style routing over a library of LoRA adapters.""" def __init__( self, base_model: nn.Module, adapter_library: List[ExpertAdapter], in_features: int = 768, # Match model hidden dim (e.g. Pythia-70m) ) -> None: self.base_model = base_model self.adapter_library = adapter_library self.num_adapters = len(adapter_library) self.hooks: List[torch.utils.hooks.RemovableHandle] = [] # Learnable gating network self.gate = LearnableGate(in_features, self.num_adapters).to( next(base_model.parameters()).device ) # Active weights for current forward pass (batch size × num_adapters) self.current_gate_weights: Optional[torch.Tensor] = None def compute_fuzzy_priors(self, text: str) -> torch.Tensor: """Determines static routing priors based on keyword matching in input text.""" priors = torch.zeros(self.num_adapters) # Heuristics for legal exam domains text_lower = text.lower() has_statute = bool(re.search(r"§\s*\d+", text_lower)) has_logic = ( "firt" in text_lower or "reasoning" in text_lower or "analyze" in text_lower ) has_style = ( "gutachtenstil" in text_lower or "gutachten" in text_lower or "klausur" in text_lower ) for idx, adapter in enumerate(self.adapter_library): score = 0.1 # base prior for tag in adapter.domain_tags: if tag == "statute_recall" and has_statute: score += 0.8 elif tag == "logic_reasoning" and has_logic: score += 0.8 elif tag == "style_gutachtenstil" and has_style: score += 0.8 priors[idx] = score # Softmax normalize return torch.softmax(priors, dim=0) def set_active_routing(self, fuzzy_priors: Optional[torch.Tensor] = None) -> None: """Explicitly sets the routing weights for the next forward pass.""" self.current_gate_weights = fuzzy_priors def register_hooks(self) -> None: """Attaches forward hooks to linear layers present in the adapter library.""" self.unregister_hooks() # Find all layers in the base model that have adapters adapter_layers: set[str] = set() for adapter in self.adapter_library: adapter_layers.update(adapter.layers.keys()) # Bind hooks dynamically for name, module in self.base_model.named_modules(): # Check if this specific module has an adapter # We match using suffix to support model wrapping/prefixes matching_adapter_name = None for layer_name in adapter_layers: clean_layer_name = layer_name.replace(".weight", "").replace( ".bias", "" ) if name.endswith(clean_layer_name) or name == clean_layer_name: matching_adapter_name = layer_name break if matching_adapter_name and isinstance(module, nn.Linear): hook = module.register_forward_hook( self._make_hook_fn(matching_adapter_name) ) self.hooks.append(hook) def unregister_hooks(self) -> None: """Removes all registered hooks from the base model.""" for hook in self.hooks: hook.remove() self.hooks.clear() def _make_hook_fn(self, layer_name: str) -> Callable[..., torch.Tensor]: """Creates the hook function for a specific linear layer.""" def hook_fn( module: nn.Module, input_tensor: Tuple[torch.Tensor, ...], output_tensor: torch.Tensor, ) -> torch.Tensor: x = input_tensor[0] # [batch, seq_len, in_features] # Calculate gate weights if self.current_gate_weights is not None: # Use manually set priors (e.g. fuzzy text-based) # Expand to match batch size batch_size = x.shape[0] weights = ( self.current_gate_weights.to(x.device) .unsqueeze(0) .expand(batch_size, -1) ) else: # Compute dynamically per token via learnable gate # We pool over sequence length or route per token # Let's route token-wise: gate_logits has shape [batch, seq_len, num_adapters] gate_logits = self.gate(x) weights = torch.softmax(gate_logits, dim=-1) # Compute combined low-rank contribution # Y_lora = sum_i g_i * (x @ A_i.t()) @ B_i.t() adapter_output = torch.zeros_like(output_tensor) for i, adapter in enumerate(self.adapter_library): if layer_name in adapter.layers: lm = adapter.layers[layer_name] # Ensure tensors are on the correct device lora_A = lm.lora_A.to(device=x.device, dtype=x.dtype) lora_B = lm.lora_B.to(device=x.device, dtype=x.dtype) # Dynamic scaling: gate_weight for this adapter # weights has shape [batch, num_adapters] or [batch, seq_len, num_adapters] if len(weights.shape) == 3: # Token-level routing: shape [batch, seq_len, 1] g = weights[..., i : i + 1] else: # Batch-level routing: shape [batch, 1, 1] g = weights[:, i].view(-1, 1, 1) # Low-rank projection x_proj = torch.matmul(x, lora_A.t()) y_proj = torch.matmul(x_proj, lora_B.t()) # Accumulate scaled delta adapter_output += g * y_proj return output_tensor + adapter_output return hook_fn