"""Expert Adapter Router and Mixture-of-Experts (MoE) for QLoRA.

Hooks into a base model to dynamically route hidden states through multiple
extracted expert adapters using fuzzy text-based domain priors or dynamic
learnable token-level gates.
"""

from __future__ import annotations

import re
from typing import Any, Callable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from parasitic_qlora import ExpertAdapter
from representation_engineering import SkillVectorLibrary, ProcessVectorLibrary


class LearnableGate(nn.Module):  # type: ignore[misc, unused-ignore]
    """A lightweight learnable MLP gating network for routing.

    The network is ``Linear(in_features, in_features // 2) -> ReLU ->
    Linear(in_features // 2, num_adapters)`` and produces one unnormalized
    score per routing target for every leading index of the input.
    """

    def __init__(self, in_features: int, num_adapters: int) -> None:
        """Build the two-layer gate.

        Args:
            in_features: Width of the feature vectors the MLP consumes.
            num_adapters: Number of scores (routing targets) to emit.
        """
        super().__init__()
        self.in_features = in_features
        self.gate = nn.Sequential(
            nn.Linear(in_features, in_features // 2),
            nn.ReLU(),
            nn.Linear(in_features // 2, num_adapters),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score the routing targets for ``x``.

        Args:
            x: Tensor of shape ``[..., features]``. When ``features`` differs
                from ``in_features`` the last dimension is adaptively
                average-pooled to ``in_features`` before the MLP is applied.

        Returns:
            Unnormalized scores of shape ``[..., num_adapters]``. For a
            floating-point input whose dtype differs from the gate
            parameters, the scores are cast back to the input dtype.
        """
        in_dtype = x.dtype
        param_dtype = self.gate[0].weight.dtype
        # The gate parameters keep their own precision; feeding them a tensor
        # of another dtype (e.g. half-precision activations) would raise a
        # dtype-mismatch error inside the linear layers.
        if in_dtype != param_dtype:
            x = x.to(param_dtype)

        if x.shape[-1] != self.in_features:
            leading_shape = x.shape[:-1]
            # reshape (not view) so non-contiguous inputs, e.g. transposed
            # activations, are handled instead of raising.
            flat_x = x.reshape(-1, 1, x.shape[-1])
            pooled_x = F.adaptive_avg_pool1d(flat_x, self.in_features)
            x = pooled_x.reshape(*leading_shape, self.in_features)

        logits = self.gate(x)
        if in_dtype != param_dtype and in_dtype.is_floating_point:
            logits = logits.to(in_dtype)
        return logits  # type: ignore[no-any-return, unused-ignore]
class ExpertAdapterRouter:
    """Manages dynamic MoE-style routing over a library of LoRA adapters and representation vectors.

    Two kinds of forward hooks are attached to ``base_model``:

    * adapter hooks on ``nn.Linear`` modules whose name matches a layer key
      of an adapter; they add a gate-weighted sum of low-rank updates to the
      layer output;
    * steering hooks on modules whose name ends in ``layer.<idx>`` or
      ``layers.<idx>``; they add process-step or gate-weighted skill vectors
      to the module output.

    Gate weights come either from ``current_gate_weights`` (set through
    ``set_active_routing``) or from the learnable gates.
    """

    # Parameter suffixes removed from adapter layer keys to get module names.
    _PARAM_SUFFIXES = (".weight", ".bias")
    # Module names that end in "layer.<idx>" / "layers.<idx>" get steering hooks.
    _LAYER_NAME_RE = re.compile(r".*layers?\.(\d+)$")

    def __init__(
        self,
        base_model: nn.Module,
        adapter_library: Optional[List[ExpertAdapter]] = None,
        in_features: int = 768,  # Match model hidden dim (e.g. Pythia-70m)
        skill_library: Optional[SkillVectorLibrary] = None,
        process_library: Optional[ProcessVectorLibrary] = None,
        steering_alpha: float = 1.0,
        steering_mode: str = "token",  # "token" or "prompt"
    ) -> None:
        """Store the libraries and build the learnable gates.

        Args:
            base_model: Model whose modules will be hooked.
            adapter_library: Adapters to route over; ``None`` means none.
            in_features: Input width of the learnable gates.
            skill_library: Optional library exposing ``vectors`` and
                ``get_vector``; its sorted keys define the skill indices.
            process_library: Optional library exposing ``get_process_step``.
            steering_alpha: Scale applied to injected steering vectors.
            steering_mode: ``"token"`` gates skills per position; any other
                value gates on the mean over dimension 1 of 3-D activations.
        """
        self.base_model = base_model
        self.adapter_library = adapter_library or []
        self.num_adapters = len(self.adapter_library)
        self.skill_library = skill_library
        self.process_library = process_library
        self.steering_alpha = steering_alpha
        self.steering_mode = steering_mode
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        self.active_process_id: Optional[str] = None
        self.active_process_step: Optional[int] = None

        # Sorted list of skill IDs for index-based routing
        self.skill_ids = (
            sorted(self.skill_library.vectors.keys())
            if self.skill_library is not None
            else []
        )

        # Gates live on the device of the model's first parameter; a model
        # without parameters falls back to CPU instead of raising.
        first_param = next(base_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

        # Learnable gating network for adapters
        self.gate: Optional[LearnableGate] = (
            LearnableGate(in_features, self.num_adapters).to(device)
            if self.num_adapters > 0
            else None
        )

        # Learnable gating network for skills
        self.skill_gate: Optional[LearnableGate] = (
            LearnableGate(in_features, len(self.skill_ids)).to(device)
            if self.skill_ids
            else None
        )

        # Manually set weights for the current forward pass: either one
        # weight per target ([n]) or per sample ([batch, n]).
        self.current_gate_weights: Optional[torch.Tensor] = None

    def compute_fuzzy_priors(self, text: str) -> torch.Tensor:
        """Determines static routing priors based on keyword matching in input text.

        Every adapter starts at a base score of 0.1 and gains 0.8 for each of
        its ``domain_tags`` whose keyword heuristic fires on ``text``.

        Returns:
            Softmax-normalized 1-D tensor with ``num_adapters`` entries.
        """
        priors = torch.zeros(self.num_adapters)

        # Heuristics for legal exam domains
        text_lower = text.lower()
        # NOTE(review): "firt" may be a typo (e.g. for "first") or a project
        # acronym; left unchanged because the intent cannot be confirmed here.
        tag_triggered = {
            "statute_recall": bool(re.search(r"§\s*\d+", text_lower)),
            "logic_reasoning": any(
                keyword in text_lower for keyword in ("firt", "reasoning", "analyze")
            ),
            "style_gutachtenstil": any(
                keyword in text_lower
                for keyword in ("gutachtenstil", "gutachten", "klausur")
            ),
        }

        for idx, adapter in enumerate(self.adapter_library):
            score = 0.1  # base prior
            for tag in adapter.domain_tags:
                if tag_triggered.get(tag, False):
                    score += 0.8
            priors[idx] = score

        # Softmax normalize
        return torch.softmax(priors, dim=0)

    def set_active_routing(self, fuzzy_priors: Optional[torch.Tensor] = None) -> None:
        """Explicitly sets the routing weights for the next forward pass.

        Passing ``None`` clears the manual weights so the learnable gates are
        used again.
        """
        self.current_gate_weights = fuzzy_priors

    def register_hooks(self) -> None:
        """Attaches forward hooks to linear layers (adapters) and transformer blocks (steering).

        Previously registered hooks are removed first, so calling this method
        repeatedly does not stack hooks.
        """
        self.unregister_hooks()

        # 1. Bind adapter hooks if adapters are present
        if self.num_adapters > 0:
            layer_keys = {
                key for adapter in self.adapter_library for key in adapter.layers.keys()
            }
            # (module-name, original key) pairs; longest module name first so
            # the most specific key wins, with the key as a deterministic
            # tie-breaker (set iteration order is arbitrary).
            candidates = sorted(
                ((self._strip_param_suffix(key), key) for key in layer_keys),
                key=lambda pair: (-len(pair[0]), pair[1]),
            )

            for name, module in self.base_model.named_modules():
                if not isinstance(module, nn.Linear):
                    continue
                matching_adapter_name = next(
                    (
                        key
                        for clean_name, key in candidates
                        # Match whole dotted components only, so that e.g.
                        # "up_dense" is not mistaken for "dense".
                        if clean_name
                        and (name == clean_name or name.endswith("." + clean_name))
                    ),
                    None,
                )
                if matching_adapter_name is None:
                    continue
                self.hooks.append(
                    module.register_forward_hook(
                        self._make_hook_fn(matching_adapter_name)
                    )
                )

        # 2. Bind steering hooks if skill_library or process_library is present
        if self.skill_library is not None or self.process_library is not None:
            transformer_layers = []
            for name, module in self.base_model.named_modules():
                match = self._LAYER_NAME_RE.match(name)
                if match:
                    transformer_layers.append((int(match.group(1)), name, module))

            # Sort by layer_idx to ensure consistent mapping
            transformer_layers.sort(key=lambda entry: entry[0])

            for layer_idx, _name, module in transformer_layers:
                self.hooks.append(
                    module.register_forward_hook(
                        self._make_steering_hook_fn(layer_idx)
                    )
                )

    def unregister_hooks(self) -> None:
        """Removes all registered hooks from the base model."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

    @classmethod
    def _strip_param_suffix(cls, layer_name: str) -> str:
        """Return ``layer_name`` without a trailing ``.weight`` / ``.bias``."""
        for suffix in cls._PARAM_SUFFIXES:
            if layer_name.endswith(suffix):
                return layer_name[: -len(suffix)]
        return layer_name

    def _resolve_weights(
        self,
        x: torch.Tensor,
        gate: Optional[nn.Module],
        num_targets: int,
        pool_sequence: bool = False,
    ) -> Optional[torch.Tensor]:
        """Return gate weights for ``x`` or ``None`` when none are available.

        Manually set weights take precedence, but only when their last
        dimension equals ``num_targets``: the same attribute feeds both the
        adapter and the skill hooks, and weights sized for one must not be
        used to index the other. Otherwise the learnable ``gate`` is
        evaluated and softmax-normalized.
        """
        manual = self.current_gate_weights
        if manual is not None and manual.dim() >= 1 and manual.shape[-1] == num_targets:
            weights = manual.to(device=x.device, dtype=x.dtype)
            if weights.dim() == 1 and x.dim() >= 2:
                # One weight per target: share it across the batch.
                weights = weights.unsqueeze(0).expand(x.shape[0], -1)
            return weights

        if gate is None:
            return None

        gate_input = x
        if pool_sequence and x.dim() == 3:
            gate_input = x.mean(dim=1)
        # Run the gate on its own device/dtype; activations may live on
        # another device or use reduced precision.
        gate_param = next(gate.parameters(), None)
        if gate_param is not None:
            gate_input = gate_input.to(device=gate_param.device, dtype=gate_param.dtype)
        weights = torch.softmax(gate(gate_input), dim=-1)
        return weights.to(device=x.device, dtype=x.dtype)

    @staticmethod
    def _gate_slice(
        weights: torch.Tensor, index: int, target: torch.Tensor
    ) -> torch.Tensor:
        """Select target ``index`` from ``weights``, broadcastable to ``target``."""
        if weights.dim() == target.dim():
            # Same rank as the activations: keep the trailing axis.
            return weights[..., index : index + 1]
        # Per-sample weights [batch, n]: broadcast over all remaining axes.
        return weights[:, index].reshape(-1, *([1] * (target.dim() - 1)))

    def _make_hook_fn(self, layer_name: str) -> Callable[..., torch.Tensor]:
        """Creates the hook function for a specific linear layer.

        Args:
            layer_name: Key looked up in each adapter's ``layers`` mapping.
        """

        def hook_fn(
            module: nn.Module,
            input_tensor: Tuple[torch.Tensor, ...],
            output_tensor: torch.Tensor,
        ) -> torch.Tensor:
            x = input_tensor[0]

            weights = self._resolve_weights(x, self.gate, self.num_adapters)
            if weights is None:
                return output_tensor

            # Compute combined low-rank contribution
            adapter_output = torch.zeros_like(output_tensor)

            for i, adapter in enumerate(self.adapter_library):
                if layer_name not in adapter.layers:
                    continue
                lm = adapter.layers[layer_name]
                lora_A = lm.lora_A.to(device=x.device, dtype=x.dtype)
                lora_B = lm.lora_B.to(device=x.device, dtype=x.dtype)

                x_proj = torch.matmul(x, lora_A.t())
                y_proj = torch.matmul(x_proj, lora_B.t())

                adapter_output += self._gate_slice(weights, i, y_proj) * y_proj

            return output_tensor + adapter_output

        return hook_fn

    def _make_steering_hook_fn(self, layer_idx: int) -> Callable[..., Any]:
        """Creates a hook function to inject activation steering vectors at a specific layer.

        An active process (``active_process_id``) takes precedence over skill
        routing: when one is set, only its step vector is considered and the
        output is returned unchanged if that step has no vector for this
        layer.
        """

        def hook_fn(
            module: nn.Module,
            input_tensor: Tuple[torch.Tensor, ...],
            output_tensor: Any,
        ) -> Any:
            is_tuple = isinstance(output_tensor, tuple)
            x = output_tensor[0] if is_tuple else output_tensor
            if not isinstance(x, torch.Tensor):
                # Nothing steerable in this output; leave it untouched.
                return output_tensor

            def repack(steered: torch.Tensor) -> Any:
                # Preserve the module's output structure.
                return (steered,) + output_tensor[1:] if is_tuple else steered

            # Sequential process/workflow steering
            if self.active_process_id is not None and self.process_library is not None:
                step_idx = self.active_process_step or 0
                step_vector = self.process_library.get_process_step(
                    self.active_process_id, step_idx
                )
                if step_vector is not None and layer_idx in step_vector.layer_vectors:
                    v = step_vector.layer_vectors[layer_idx].to(
                        device=x.device, dtype=x.dtype
                    )
                    return repack(x + self.steering_alpha * v)
                return output_tensor

            # Dynamic skill routing
            if self.skill_library is None or not self.skill_ids:
                return output_tensor

            weights = self._resolve_weights(
                x,
                self.skill_gate,
                len(self.skill_ids),
                pool_sequence=self.steering_mode != "token",
            )
            if weights is None:
                return output_tensor

            steer_contribution = torch.zeros_like(x)
            for i, skill_id in enumerate(self.skill_ids):
                vec = self.skill_library.get_vector(skill_id)
                if vec is None or layer_idx not in vec.layer_vectors:
                    continue
                v = vec.layer_vectors[layer_idx].to(device=x.device, dtype=x.dtype)
                steer_contribution += self._gate_slice(weights, i, x) * v

            return repack(x + self.steering_alpha * steer_contribution)

        return hook_fn