"""Expert Adapter Router and Mixture-of-Experts (MoE) for QLoRA. Hooks into a base model to dynamically route hidden states through multiple extracted expert adapters using fuzzy text-based domain priors or dynamic learnable token-level gates. """ from __future__ import annotations import re from typing import Any, Callable, List, Optional, Tuple import torch import torch.nn as nn from parasitic_qlora import ExpertAdapter from representation_engineering import SkillVectorLibrary, ProcessVectorLibrary class LearnableGate(nn.Module): # type: ignore[misc, unused-ignore] """A lightweight learnable MLP gating network for routing.""" def __init__(self, in_features: int, num_adapters: int) -> None: super().__init__() self.in_features = in_features self.gate = nn.Sequential( nn.Linear(in_features, in_features // 2), nn.ReLU(), nn.Linear(in_features // 2, num_adapters), ) def forward(self, x: torch.Tensor) -> torch.Tensor: # Input: [..., in_features] or other dimension to be pooled # Output: [..., num_adapters] (unnormalized scores) if x.shape[-1] != self.in_features: import torch.nn.functional as F orig_shape = x.shape flat_x = x.view(-1, orig_shape[-1]).unsqueeze(1) pooled_x = F.adaptive_avg_pool1d(flat_x, self.in_features) x = pooled_x.squeeze(1).view(*orig_shape[:-1], self.in_features) return self.gate(x) # type: ignore[no-any-return, unused-ignore] class ExpertAdapterRouter: """Manages dynamic MoE-style routing over a library of LoRA adapters and representation vectors.""" def __init__( self, base_model: nn.Module, adapter_library: Optional[List[ExpertAdapter]] = None, in_features: int = 768, # Match model hidden dim (e.g. Pythia-70m) skill_library: Optional[SkillVectorLibrary] = None, process_library: Optional[ProcessVectorLibrary] = None, steering_alpha: float = 1.0, steering_mode: str = "token", # "token" or "prompt" ) -> None: self.base_model = base_model self.adapter_library = adapter_library or [] self.num_adapters = len(self.adapter_library) self.skill_library = skill_library self.process_library = process_library self.steering_alpha = steering_alpha self.steering_mode = steering_mode self.hooks: List[torch.utils.hooks.RemovableHandle] = [] self.active_process_id: Optional[str] = None self.active_process_step: Optional[int] = None # Sorted list of skill IDs for index-based routing self.skill_ids = ( sorted(list(self.skill_library.vectors.keys())) if self.skill_library else [] ) # Learnable gating network for adapters if self.num_adapters > 0: self.gate = LearnableGate(in_features, self.num_adapters).to( next(base_model.parameters()).device ) else: self.gate = None # Learnable gating network for skills if len(self.skill_ids) > 0: self.skill_gate = LearnableGate(in_features, len(self.skill_ids)).to( next(base_model.parameters()).device ) else: self.skill_gate = None # Active weights for current forward pass (batch size × num_adapters/skills) self.current_gate_weights: Optional[torch.Tensor] = None def compute_fuzzy_priors(self, text: str) -> torch.Tensor: """Determines static routing priors based on keyword matching in input text.""" priors = torch.zeros(self.num_adapters) # Heuristics for legal exam domains text_lower = text.lower() has_statute = bool(re.search(r"§\s*\d+", text_lower)) has_logic = ( "firt" in text_lower or "reasoning" in text_lower or "analyze" in text_lower ) has_style = ( "gutachtenstil" in text_lower or "gutachten" in text_lower or "klausur" in text_lower ) for idx, adapter in enumerate(self.adapter_library): score = 0.1 # base prior for tag in adapter.domain_tags: if tag == "statute_recall" and has_statute: score += 0.8 elif tag == "logic_reasoning" and has_logic: score += 0.8 elif tag == "style_gutachtenstil" and has_style: score += 0.8 priors[idx] = score # Softmax normalize return torch.softmax(priors, dim=0) def set_active_routing(self, fuzzy_priors: Optional[torch.Tensor] = None) -> None: """Explicitly sets the routing weights for the next forward pass.""" self.current_gate_weights = fuzzy_priors def register_hooks(self) -> None: """Attaches forward hooks to linear layers (adapters) and transformer blocks (steering).""" self.unregister_hooks() # 1. Bind adapter hooks if adapters are present if self.num_adapters > 0: adapter_layers: set[str] = set() for adapter in self.adapter_library: adapter_layers.update(adapter.layers.keys()) for name, module in self.base_model.named_modules(): matching_adapter_name = None for layer_name in adapter_layers: clean_layer_name = layer_name.replace(".weight", "").replace( ".bias", "" ) if name.endswith(clean_layer_name) or name == clean_layer_name: matching_adapter_name = layer_name break if matching_adapter_name and isinstance(module, nn.Linear): hook = module.register_forward_hook( self._make_hook_fn(matching_adapter_name) ) self.hooks.append(hook) # 2. Bind steering hooks if skill_library or process_library is present if self.skill_library or self.process_library: transformer_layers = [] for name, module in self.base_model.named_modules(): match = re.match(r".*layers?\.(\d+)$", name) if match: layer_idx = int(match.group(1)) transformer_layers.append((layer_idx, name, module)) # Sort by layer_idx to ensure consistent mapping transformer_layers.sort(key=lambda x: x[0]) for layer_idx, name, module in transformer_layers: hook = module.register_forward_hook( self._make_steering_hook_fn(layer_idx) ) self.hooks.append(hook) def unregister_hooks(self) -> None: """Removes all registered hooks from the base model.""" for hook in self.hooks: hook.remove() self.hooks.clear() def _make_hook_fn(self, layer_name: str) -> Callable[..., torch.Tensor]: """Creates the hook function for a specific linear layer.""" def hook_fn( module: nn.Module, input_tensor: Tuple[torch.Tensor, ...], output_tensor: torch.Tensor, ) -> torch.Tensor: x = input_tensor[0] # [batch, seq_len, in_features] # Calculate gate weights if self.current_gate_weights is not None: # Use manually set priors (e.g. fuzzy text-based) # Expand to match batch size batch_size = x.shape[0] weights = ( self.current_gate_weights.to(x.device) .unsqueeze(0) .expand(batch_size, -1) ) else: # Compute dynamically per token via learnable gate if self.gate is not None: gate_logits = self.gate(x) weights = torch.softmax(gate_logits, dim=-1) else: return output_tensor # Compute combined low-rank contribution adapter_output = torch.zeros_like(output_tensor) for i, adapter in enumerate(self.adapter_library): if layer_name in adapter.layers: lm = adapter.layers[layer_name] lora_A = lm.lora_A.to(device=x.device, dtype=x.dtype) lora_B = lm.lora_B.to(device=x.device, dtype=x.dtype) if len(weights.shape) == 3: g = weights[..., i : i + 1] else: g = weights[:, i].view(-1, 1, 1) x_proj = torch.matmul(x, lora_A.t()) y_proj = torch.matmul(x_proj, lora_B.t()) adapter_output += g * y_proj return output_tensor + adapter_output return hook_fn def _make_steering_hook_fn(self, layer_idx: int) -> Callable[..., Any]: """Creates a hook function to inject activation steering vectors at a specific layer.""" def hook_fn( module: nn.Module, input_tensor: Tuple[torch.Tensor, ...], output_tensor: Any, ) -> Any: is_tuple = isinstance(output_tensor, tuple) x = output_tensor[0] if is_tuple else output_tensor # Sequential process/workflow steering if self.active_process_id is not None and self.process_library is not None: step_idx = self.active_process_step or 0 step_vector = self.process_library.get_process_step( self.active_process_id, step_idx ) if step_vector and layer_idx in step_vector.layer_vectors: v = step_vector.layer_vectors[layer_idx].to( device=x.device, dtype=x.dtype ) steered_x = x + self.steering_alpha * v if is_tuple: return (steered_x,) + output_tensor[1:] return steered_x return output_tensor # Dynamic skill routing if self.skill_library and len(self.skill_ids) > 0: weights = None if self.current_gate_weights is not None: batch_size = x.shape[0] weights = ( self.current_gate_weights.to(x.device) .unsqueeze(0) .expand(batch_size, -1) ) elif self.skill_gate is not None: if self.steering_mode == "token": gate_logits = self.skill_gate(x) weights = torch.softmax(gate_logits, dim=-1) else: x_mean = x.mean(dim=1) if len(x.shape) == 3 else x gate_logits = self.skill_gate(x_mean) weights = torch.softmax(gate_logits, dim=-1) if weights is not None: steer_contribution = torch.zeros_like(x) for i, skill_id in enumerate(self.skill_ids): vec = self.skill_library.get_vector(skill_id) if vec and layer_idx in vec.layer_vectors: v = vec.layer_vectors[layer_idx].to( device=x.device, dtype=x.dtype ) if len(weights.shape) == 3: g = weights[..., i : i + 1] else: g = weights[:, i].view(-1, 1, 1) steer_contribution += g * v steered_x = x + self.steering_alpha * steer_contribution if is_tuple: return (steered_x,) + output_tensor[1:] return steered_x return output_tensor return hook_fn