Files
FCES-native/python/adapter_moe_router.py

297 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Expert Adapter Router and Mixture-of-Experts (MoE) for QLoRA.
Hooks into a base model to dynamically route hidden states through multiple
extracted expert adapters using fuzzy text-based domain priors or dynamic
learnable token-level gates.
"""
from __future__ import annotations
import re
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn as nn
from parasitic_qlora import ExpertAdapter
from representation_engineering import SkillVectorLibrary, ProcessVectorLibrary
class LearnableGate(nn.Module):  # type: ignore[misc, unused-ignore]
    """A lightweight learnable MLP gating network for routing.

    The network is ``Linear(in_features, in_features // 2) -> ReLU ->
    Linear(in_features // 2, num_adapters)`` and produces one unnormalized
    score per routing target for every leading position of the input.

    Args:
        in_features: Width of the feature vectors the MLP operates on.
        num_adapters: Number of routing targets (size of the output's last
            dimension).
    """

    def __init__(self, in_features: int, num_adapters: int) -> None:
        super().__init__()
        self.in_features = in_features
        self.gate = nn.Sequential(
            nn.Linear(in_features, in_features // 2),
            nn.ReLU(),
            nn.Linear(in_features // 2, num_adapters),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score every position of ``x`` against each routing target.

        Args:
            x: Tensor of shape ``[..., width]``. When ``width`` differs from
                ``in_features`` the last dimension is adaptively
                average-pooled to ``in_features`` before scoring.

        Returns:
            Unnormalized scores of shape ``[..., num_adapters]`` in the dtype
            of the gate's own parameters.
        """
        # The gate keeps the dtype it was built with, while hooked activations
        # may arrive in another precision; feeding them straight into
        # nn.Linear would raise a dtype-mismatch error.
        gate_dtype = self.gate[0].weight.dtype
        if x.dtype != gate_dtype:
            x = x.to(gate_dtype)

        width = x.shape[-1]
        if width != self.in_features:
            leading_shape = x.shape[:-1]
            # reshape (not view): the activation may be non-contiguous, e.g.
            # the result of a permute/transpose, which view() rejects.
            flat = x.reshape(-1, 1, width)
            pooled = nn.functional.adaptive_avg_pool1d(flat, self.in_features)
            x = pooled.reshape(*leading_shape, self.in_features)

        return self.gate(x)  # type: ignore[no-any-return, unused-ignore]
class ExpertAdapterRouter:
    """Manages dynamic MoE-style routing over LoRA adapters and steering vectors.

    The router attaches forward hooks to ``base_model``:

    * every ``nn.Linear`` whose module name matches an adapter layer key gets
      a hook that adds the gate-weighted sum of the adapters' low-rank
      updates to the layer output;
    * every module whose name ends in ``layer.<n>`` / ``layers.<n>`` gets a
      hook that adds process or skill steering vectors registered for layer
      index ``n`` to the module output.

    Gate weights are either set manually (``set_active_routing``) or computed
    from the hidden states by the learnable gates ``self.gate`` (adapters)
    and ``self.skill_gate`` (skills).

    Args:
        base_model: Model to hook into.
        adapter_library: Adapters to route over. Each must expose ``layers``
            (mapping layer key -> object with ``lora_A`` / ``lora_B``) and
            ``domain_tags``.
        in_features: Input width of the learnable gates; inputs of a
            different width are average-pooled to this size by the gate.
        skill_library: Optional library exposing ``vectors`` and
            ``get_vector(skill_id)``.
        process_library: Optional library exposing
            ``get_process_step(process_id, step_idx)``.
        steering_alpha: Scale applied to every injected steering vector.
        steering_mode: ``"token"`` gates skills per position; any other value
            gates on the mean over dimension 1 of 3-D activations.
    """

    # Matches a statute reference such as "§ 823".
    _STATUTE_PATTERN = re.compile(r"§\s*\d+")
    # Matches module names ending in "layer.<n>" or "layers.<n>".
    _BLOCK_PATTERN = re.compile(r".*layers?\.(\d+)$")

    def __init__(
        self,
        base_model: nn.Module,
        adapter_library: Optional[List[ExpertAdapter]] = None,
        in_features: int = 768,  # should match the base model's hidden size
        skill_library: Optional[SkillVectorLibrary] = None,
        process_library: Optional[ProcessVectorLibrary] = None,
        steering_alpha: float = 1.0,
        steering_mode: str = "token",  # "token" or "prompt"
    ) -> None:
        self.base_model = base_model
        self.adapter_library = adapter_library or []
        self.num_adapters = len(self.adapter_library)
        self.skill_library = skill_library
        self.process_library = process_library
        self.steering_alpha = steering_alpha
        self.steering_mode = steering_mode
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        self.active_process_id: Optional[str] = None
        self.active_process_step: Optional[int] = None

        # Sorted list of skill IDs for index-based routing
        self.skill_ids = (
            sorted(self.skill_library.vectors.keys()) if self.skill_library else []
        )

        # Place the gates next to the model; a model without parameters has
        # no device to copy, so fall back to CPU instead of raising
        # StopIteration.
        first_param = next(base_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

        # Learnable gating network for adapters
        self.gate = (
            LearnableGate(in_features, self.num_adapters).to(device)
            if self.num_adapters > 0
            else None
        )
        # Learnable gating network for skills
        self.skill_gate = (
            LearnableGate(in_features, len(self.skill_ids)).to(device)
            if self.skill_ids
            else None
        )

        # Manually set weights for the current forward pass: either
        # [num_choices] (shared by the whole batch) or already batched with
        # num_choices as the last dimension.
        self.current_gate_weights: Optional[torch.Tensor] = None

    def compute_fuzzy_priors(self, text: str) -> torch.Tensor:
        """Determine static routing priors from keyword matches in ``text``.

        Every adapter starts at 0.1 and gains 0.8 for each of its domain tags
        whose keyword group is present in the lower-cased text.

        Returns:
            Softmax-normalized 1-D tensor with one entry per adapter.
        """
        text_lower = text.lower()
        # Heuristics for legal exam domains.
        # NOTE(review): "firt" is matched verbatim; confirm it is an intended
        # keyword and not a typo.
        tag_active = {
            "statute_recall": bool(self._STATUTE_PATTERN.search(text_lower)),
            "logic_reasoning": any(
                kw in text_lower for kw in ("firt", "reasoning", "analyze")
            ),
            "style_gutachtenstil": any(
                kw in text_lower for kw in ("gutachtenstil", "gutachten", "klausur")
            ),
        }

        scores = [
            0.1 + sum(0.8 for tag in adapter.domain_tags if tag_active.get(tag, False))
            for adapter in self.adapter_library
        ]
        priors = torch.tensor(scores, dtype=torch.float32)
        return torch.softmax(priors, dim=0)

    def set_active_routing(self, fuzzy_priors: Optional[torch.Tensor] = None) -> None:
        """Explicitly set (or clear with ``None``) the manual routing weights.

        The weights apply to whichever hook family (adapters or skills) has
        as many choices as the tensor's last dimension; the other family
        keeps using its learnable gate.
        """
        self.current_gate_weights = fuzzy_priors

    def register_hooks(self) -> None:
        """Attach forward hooks to linear layers (adapters) and blocks (steering).

        Any previously registered hooks are removed first.
        """
        self.unregister_hooks()

        # 1. Bind adapter hooks if adapters are present
        if self.num_adapters > 0:
            adapter_keys = set()
            for adapter in self.adapter_library:
                adapter_keys.update(adapter.layers.keys())
            # Clean each key once; try the most specific (longest) name first
            # and break ties alphabetically so the match is deterministic.
            candidates = sorted(
                ((self._strip_param_suffix(key), key) for key in adapter_keys),
                key=lambda pair: (-len(pair[0]), pair[1]),
            )
            for name, module in self.base_model.named_modules():
                if not isinstance(module, nn.Linear):
                    continue
                for clean_name, key in candidates:
                    # Match whole dotted components only, so "1.q_proj" does
                    # not also capture "layers.11.q_proj". Empty names would
                    # match every module and are skipped.
                    if clean_name and (
                        name == clean_name or name.endswith("." + clean_name)
                    ):
                        self.hooks.append(
                            module.register_forward_hook(self._make_hook_fn(key))
                        )
                        break

        # 2. Bind steering hooks if skill_library or process_library is present
        if self.skill_library or self.process_library:
            blocks = []
            for name, module in self.base_model.named_modules():
                found = self._BLOCK_PATTERN.match(name)
                if found:
                    blocks.append((int(found.group(1)), module))
            # Sort by layer index to ensure consistent mapping
            blocks.sort(key=lambda entry: entry[0])
            for layer_idx, module in blocks:
                self.hooks.append(
                    module.register_forward_hook(
                        self._make_steering_hook_fn(layer_idx)
                    )
                )

    def unregister_hooks(self) -> None:
        """Remove all registered hooks from the base model."""
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()

    @staticmethod
    def _strip_param_suffix(layer_name: str) -> str:
        """Drop a trailing ``.weight`` / ``.bias`` from an adapter layer key."""
        for suffix in (".weight", ".bias"):
            if layer_name.endswith(suffix):
                return layer_name[: -len(suffix)]
        return layer_name

    def _manual_weights(
        self, x: torch.Tensor, num_choices: int
    ) -> Optional[torch.Tensor]:
        """Return the manual weights prepared for ``x``, or ``None``.

        ``None`` is returned when no weights are set or when their last
        dimension does not equal ``num_choices`` (they were set for the other
        hook family). 1-D weights are expanded over the batch dimension of
        ``x``; batched weights are used as given.
        """
        weights = self.current_gate_weights
        if weights is None or weights.dim() == 0:
            return None
        if weights.shape[-1] != num_choices:
            return None
        weights = weights.to(device=x.device, dtype=x.dtype)
        if weights.dim() == 1:
            weights = weights.unsqueeze(0).expand(x.shape[0], -1)
        return weights

    @staticmethod
    def _select_gate(
        weights: torch.Tensor, index: int, like: torch.Tensor
    ) -> torch.Tensor:
        """Slice choice ``index`` out of ``weights``, broadcastable to ``like``.

        Singleton dimensions are inserted before the last one until the rank
        matches, so per-batch weights broadcast over the middle dimensions
        and per-position weights are used unchanged.
        """
        g = weights[..., index : index + 1]
        while g.dim() < like.dim():
            g = g.unsqueeze(-2)
        return g

    def _make_hook_fn(self, layer_name: str) -> Callable[..., torch.Tensor]:
        """Create the forward hook adding adapter output for one linear layer.

        Args:
            layer_name: Key under which adapters store this layer's LoRA
                matrices in ``adapter.layers``.
        """

        def hook_fn(
            module: nn.Module,
            input_tensor: Tuple[torch.Tensor, ...],
            output_tensor: torch.Tensor,
        ) -> torch.Tensor:
            x = input_tensor[0]  # [..., in_features]

            # Manual priors win when they fit; otherwise gate per position.
            weights = self._manual_weights(x, self.num_adapters)
            if weights is None:
                if self.gate is None:
                    return output_tensor
                weights = torch.softmax(self.gate(x), dim=-1).to(x.dtype)

            # Compute combined low-rank contribution
            adapter_output = torch.zeros_like(output_tensor)
            for i, adapter in enumerate(self.adapter_library):
                if layer_name not in adapter.layers:
                    continue
                lm = adapter.layers[layer_name]
                lora_A = lm.lora_A.to(device=x.device, dtype=x.dtype)
                lora_B = lm.lora_B.to(device=x.device, dtype=x.dtype)
                y_proj = torch.matmul(torch.matmul(x, lora_A.t()), lora_B.t())
                adapter_output += self._select_gate(weights, i, y_proj) * y_proj
            return output_tensor + adapter_output

        return hook_fn

    def _make_steering_hook_fn(self, layer_idx: int) -> Callable[..., Any]:
        """Create the forward hook injecting steering vectors at one layer.

        An active process step takes precedence: when ``active_process_id``
        is set, only the process vector (if any) for ``layer_idx`` is applied
        and skill routing is skipped.
        """

        def hook_fn(
            module: nn.Module,
            input_tensor: Tuple[torch.Tensor, ...],
            output_tensor: Any,
        ) -> Any:
            is_tuple = isinstance(output_tensor, tuple)
            x = output_tensor[0] if is_tuple else output_tensor

            def repack(steered: torch.Tensor) -> Any:
                # Preserve any extra outputs the module returned after x.
                return (steered,) + output_tensor[1:] if is_tuple else steered

            # Sequential process/workflow steering
            if self.active_process_id is not None and self.process_library is not None:
                step_idx = self.active_process_step or 0
                step_vector = self.process_library.get_process_step(
                    self.active_process_id, step_idx
                )
                if step_vector and layer_idx in step_vector.layer_vectors:
                    v = step_vector.layer_vectors[layer_idx].to(
                        device=x.device, dtype=x.dtype
                    )
                    return repack(x + self.steering_alpha * v)
                return output_tensor

            # Dynamic skill routing
            if not (self.skill_library and self.skill_ids):
                return output_tensor

            weights = self._manual_weights(x, len(self.skill_ids))
            if weights is None and self.skill_gate is not None:
                gate_input = x
                if self.steering_mode != "token" and x.dim() == 3:
                    # Prompt-level routing: one gate decision per sequence.
                    gate_input = x.mean(dim=1)
                weights = torch.softmax(self.skill_gate(gate_input), dim=-1).to(
                    x.dtype
                )
            if weights is None:
                return output_tensor

            steer_contribution = torch.zeros_like(x)
            for i, skill_id in enumerate(self.skill_ids):
                vec = self.skill_library.get_vector(skill_id)
                if vec and layer_idx in vec.layer_vectors:
                    v = vec.layer_vectors[layer_idx].to(device=x.device, dtype=x.dtype)
                    steer_contribution += self._select_gate(weights, i, x) * v
            return repack(x + self.steering_alpha * steer_contribution)

        return hook_fn