Files
FCES-native/python/adapter_moe_router.py

182 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Expert Adapter Router and Mixture-of-Experts (MoE) for QLoRA.
Hooks into a base model to dynamically route hidden states through multiple
extracted expert adapters using fuzzy text-based domain priors or dynamic
learnable token-level gates.
"""
from __future__ import annotations
import re
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn as nn
from parasitic_qlora import ExpertAdapter
class LearnableGate(nn.Module):  # type: ignore[misc]
    """Two-layer MLP that produces one routing score per adapter.

    The hidden layer is half as wide as the input (integer division). Scores
    are returned unnormalised; no softmax is applied here.
    """

    def __init__(self, in_features: int, num_adapters: int) -> None:
        super().__init__()
        hidden_features = in_features // 2
        stages = [
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_adapters),
        ]
        # Kept as a single Sequential named ``gate`` so parameter names stay
        # ``gate.0.*`` / ``gate.2.*``.
        self.gate = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``[..., in_features]`` inputs to ``[..., num_adapters]`` logits."""
        logits = self.gate(x)
        return logits
class ExpertAdapterRouter:
    """Manages dynamic MoE-style routing over a library of LoRA adapters.

    Forward hooks are attached to the ``nn.Linear`` modules of ``base_model``
    whose names match a key in some adapter's ``layers`` mapping. Each hook
    adds a gate-weighted sum of the matching adapters' low-rank deltas to the
    layer output. Gate weights come from explicitly set priors (see
    ``set_active_routing`` / ``compute_fuzzy_priors``) or, when none are set,
    from the token-level learnable gate.
    """

    def __init__(
        self,
        base_model: nn.Module,
        adapter_library: List[ExpertAdapter],
        in_features: int = 768,  # must equal the input width of every hooked layer
    ) -> None:
        """Create the router and its learnable gate.

        Args:
            base_model: Model whose linear layers will be hooked.
            adapter_library: Expert adapters. Each must expose ``layers`` (a
                mapping from layer name to an object with ``lora_A`` and
                ``lora_B``) and ``domain_tags``.
            in_features: Input width of the learnable gate. In learnable-gate
                mode the gate is applied to the input of every hooked layer,
                so this must match those layers' ``in_features``.
        """
        self.base_model = base_model
        self.adapter_library = adapter_library
        self.num_adapters = len(adapter_library)
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        # Place the gate next to the model's parameters. A parameter-less model
        # falls back to CPU instead of raising StopIteration.
        first_param = next(base_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        self.gate = LearnableGate(in_features, self.num_adapters).to(device)
        # Routing weights for the next forward pass. Allowed values:
        #   None                    -> use the learnable gate;
        #   [num_adapters]          -> shared by the whole batch;
        #   [batch, num_adapters]   -> one row per batch item.
        self.current_gate_weights: Optional[torch.Tensor] = None

    def compute_fuzzy_priors(self, text: str) -> torch.Tensor:
        """Return softmax-normalised routing priors from keyword matches in *text*.

        Every adapter starts at a score of 0.1. It gains 0.8 for each entry of
        its ``domain_tags`` whose heuristic fires on the lower-cased text:

        * ``statute_recall``: a section sign followed by digits, e.g. ``§ 823``.
        * ``logic_reasoning``: "firt", "reasoning" or "analyze".
        * ``style_gutachtenstil``: "gutachtenstil", "gutachten" or "klausur".

        Returns:
            A 1-D float tensor of shape ``[num_adapters]`` that sums to 1.
        """
        text_lower = text.lower()
        # NOTE(review): "firt" is matched literally; it may be a typo for another
        # keyword. Confirm the intended term.
        tag_fires = {
            "statute_recall": bool(re.search(r"§\s*\d+", text_lower)),
            "logic_reasoning": any(
                kw in text_lower for kw in ("firt", "reasoning", "analyze")
            ),
            "style_gutachtenstil": any(
                kw in text_lower for kw in ("gutachtenstil", "gutachten", "klausur")
            ),
        }
        priors = torch.zeros(self.num_adapters)
        for idx, adapter in enumerate(self.adapter_library):
            score = 0.1  # base prior so unmatched adapters keep some mass
            for tag in adapter.domain_tags:
                if tag_fires.get(tag, False):
                    score += 0.8
            priors[idx] = score
        return torch.softmax(priors, dim=0)

    def set_active_routing(self, fuzzy_priors: Optional[torch.Tensor] = None) -> None:
        """Set the routing weights used by subsequent forward passes.

        Args:
            fuzzy_priors: Weights of shape ``[num_adapters]`` or
                ``[batch, num_adapters]``. Pass ``None`` to fall back to the
                learnable gate.
        """
        self.current_gate_weights = fuzzy_priors

    @staticmethod
    def _strip_param_suffix(key: str) -> str:
        """Drop one trailing ``.weight`` or ``.bias`` from a parameter name."""
        for suffix in (".weight", ".bias"):
            if key.endswith(suffix):
                return key[: -len(suffix)]
        return key

    def register_hooks(self) -> None:
        """Attach a forward hook to every ``nn.Linear`` that some adapter targets.

        Any previously registered hooks are removed first.

        Matching rules for adapter keys:

        * A trailing ``.weight`` or ``.bias`` is ignored.
        * Wrapper prefixes may be omitted: key ``layers.0.dense.weight``
          matches module ``model.layers.0.dense``.
        * Matching is on whole dotted components, so key ``1.dense`` matches
          ``layers.1.dense`` but not ``layers.11.dense``.
        """
        self.unregister_hooks()

        adapter_layers: set[str] = set()
        for adapter in self.adapter_library:
            adapter_layers.update(adapter.layers.keys())
        # Strip suffixes once instead of once per visited module.
        candidates = [(self._strip_param_suffix(key), key) for key in adapter_layers]

        for name, module in self.base_model.named_modules():
            if not isinstance(module, nn.Linear):
                continue
            for clean_name, layer_name in candidates:
                if name == clean_name or name.endswith("." + clean_name):
                    hook = module.register_forward_hook(self._make_hook_fn(layer_name))
                    self.hooks.append(hook)
                    break

    def unregister_hooks(self) -> None:
        """Removes all registered hooks from the base model."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

    def _make_hook_fn(self, layer_name: str) -> Callable[..., torch.Tensor]:
        """Build the forward hook for the linear layer keyed by *layer_name*.

        The hook returns ``output + sum_i g_i * (x @ A_i.T) @ B_i.T``. The sum
        runs over the adapters whose ``layers`` contain *layer_name*, where
        ``x`` is the layer input and ``g_i`` is the routing weight of
        adapter ``i``.

        LoRA factors and priors are cast to ``x``'s device and dtype, so
        half-precision or double models work with float32 adapters.
        """

        def hook_fn(
            module: nn.Module,
            input_tensor: Tuple[torch.Tensor, ...],
            output_tensor: torch.Tensor,
        ) -> torch.Tensor:
            # Forward hooks only receive positional inputs, so the layer must
            # be called with its input positionally. Assumes dim 0 of x is the
            # batch dim.
            x = input_tensor[0]

            # Build `weights` with the same rank as x, last dim = num_adapters.
            if self.current_gate_weights is not None:
                # Explicit priors: one weight row per batch item, broadcast over
                # every non-batch dim of x (sequence, etc.).
                priors = self.current_gate_weights.to(device=x.device, dtype=x.dtype)
                if priors.dim() == 1:
                    priors = priors.unsqueeze(0).expand(x.shape[0], -1)
                extra_dims = [1] * max(x.dim() - 2, 0)
                weights = priors.reshape(priors.shape[0], *extra_dims, priors.shape[-1])
            else:
                # Learnable gate, per position of x. It runs in the gate's own
                # device/dtype, and the result is brought back to x's.
                gate_param = next(self.gate.parameters())
                gate_logits = self.gate(
                    x.to(device=gate_param.device, dtype=gate_param.dtype)
                )
                weights = torch.softmax(gate_logits, dim=-1).to(
                    device=x.device, dtype=x.dtype
                )

            adapter_output = torch.zeros_like(output_tensor)
            for i, adapter in enumerate(self.adapter_library):
                if layer_name not in adapter.layers:
                    continue
                lm = adapter.layers[layer_name]
                lora_A = lm.lora_A.to(device=x.device, dtype=x.dtype)
                lora_B = lm.lora_B.to(device=x.device, dtype=x.dtype)
                # Keep a trailing singleton dim so g broadcasts over out_features.
                g = weights[..., i : i + 1]
                y_proj = torch.matmul(torch.matmul(x, lora_A.t()), lora_B.t())
                adapter_output += g * y_proj
            return output_tensor + adapter_output

        return hook_fn