feat: expert manifold alignment, MoE router, FCES controller metadata bindings
This commit is contained in:
181
python/adapter_moe_router.py
Normal file
181
python/adapter_moe_router.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Expert Adapter Router and Mixture-of-Experts (MoE) for QLoRA.
|
||||
|
||||
Hooks into a base model to dynamically route hidden states through multiple
|
||||
extracted expert adapters using fuzzy text-based domain priors or dynamic
|
||||
learnable token-level gates.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from parasitic_qlora import ExpertAdapter
|
||||
|
||||
|
||||
class LearnableGate(nn.Module): # type: ignore[misc]
|
||||
"""A lightweight learnable MLP gating network for routing."""
|
||||
|
||||
def __init__(self, in_features: int, num_adapters: int) -> None:
|
||||
super().__init__()
|
||||
self.gate = nn.Sequential(
|
||||
nn.Linear(in_features, in_features // 2),
|
||||
nn.ReLU(),
|
||||
nn.Linear(in_features // 2, num_adapters),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Input: [..., in_features]
|
||||
# Output: [..., num_adapters] (unnormalized scores)
|
||||
return self.gate(x)
|
||||
|
||||
|
||||
class ExpertAdapterRouter:
|
||||
"""Manages dynamic MoE-style routing over a library of LoRA adapters."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_model: nn.Module,
|
||||
adapter_library: List[ExpertAdapter],
|
||||
in_features: int = 768, # Match model hidden dim (e.g. Pythia-70m)
|
||||
) -> None:
|
||||
self.base_model = base_model
|
||||
self.adapter_library = adapter_library
|
||||
self.num_adapters = len(adapter_library)
|
||||
self.hooks: List[torch.utils.hooks.RemovableHandle] = []
|
||||
|
||||
# Learnable gating network
|
||||
self.gate = LearnableGate(in_features, self.num_adapters).to(
|
||||
next(base_model.parameters()).device
|
||||
)
|
||||
|
||||
# Active weights for current forward pass (batch size × num_adapters)
|
||||
self.current_gate_weights: Optional[torch.Tensor] = None
|
||||
|
||||
def compute_fuzzy_priors(self, text: str) -> torch.Tensor:
|
||||
"""Determines static routing priors based on keyword matching in input text."""
|
||||
priors = torch.zeros(self.num_adapters)
|
||||
|
||||
# Heuristics for legal exam domains
|
||||
text_lower = text.lower()
|
||||
has_statute = bool(re.search(r"§\s*\d+", text_lower))
|
||||
has_logic = (
|
||||
"firt" in text_lower or "reasoning" in text_lower or "analyze" in text_lower
|
||||
)
|
||||
has_style = (
|
||||
"gutachtenstil" in text_lower
|
||||
or "gutachten" in text_lower
|
||||
or "klausur" in text_lower
|
||||
)
|
||||
|
||||
for idx, adapter in enumerate(self.adapter_library):
|
||||
score = 0.1 # base prior
|
||||
for tag in adapter.domain_tags:
|
||||
if tag == "statute_recall" and has_statute:
|
||||
score += 0.8
|
||||
elif tag == "logic_reasoning" and has_logic:
|
||||
score += 0.8
|
||||
elif tag == "style_gutachtenstil" and has_style:
|
||||
score += 0.8
|
||||
priors[idx] = score
|
||||
|
||||
# Softmax normalize
|
||||
return torch.softmax(priors, dim=0)
|
||||
|
||||
def set_active_routing(self, fuzzy_priors: Optional[torch.Tensor] = None) -> None:
|
||||
"""Explicitly sets the routing weights for the next forward pass."""
|
||||
self.current_gate_weights = fuzzy_priors
|
||||
|
||||
def register_hooks(self) -> None:
|
||||
"""Attaches forward hooks to linear layers present in the adapter library."""
|
||||
self.unregister_hooks()
|
||||
|
||||
# Find all layers in the base model that have adapters
|
||||
adapter_layers: set[str] = set()
|
||||
for adapter in self.adapter_library:
|
||||
adapter_layers.update(adapter.layers.keys())
|
||||
|
||||
# Bind hooks dynamically
|
||||
for name, module in self.base_model.named_modules():
|
||||
# Check if this specific module has an adapter
|
||||
# We match using suffix to support model wrapping/prefixes
|
||||
matching_adapter_name = None
|
||||
for layer_name in adapter_layers:
|
||||
clean_layer_name = layer_name.replace(".weight", "").replace(
|
||||
".bias", ""
|
||||
)
|
||||
if name.endswith(clean_layer_name) or name == clean_layer_name:
|
||||
matching_adapter_name = layer_name
|
||||
break
|
||||
|
||||
if matching_adapter_name and isinstance(module, nn.Linear):
|
||||
hook = module.register_forward_hook(
|
||||
self._make_hook_fn(matching_adapter_name)
|
||||
)
|
||||
self.hooks.append(hook)
|
||||
|
||||
def unregister_hooks(self) -> None:
|
||||
"""Removes all registered hooks from the base model."""
|
||||
for hook in self.hooks:
|
||||
hook.remove()
|
||||
self.hooks.clear()
|
||||
|
||||
def _make_hook_fn(self, layer_name: str) -> Callable[..., torch.Tensor]:
|
||||
"""Creates the hook function for a specific linear layer."""
|
||||
|
||||
def hook_fn(
|
||||
module: nn.Module,
|
||||
input_tensor: Tuple[torch.Tensor, ...],
|
||||
output_tensor: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x = input_tensor[0] # [batch, seq_len, in_features]
|
||||
|
||||
# Calculate gate weights
|
||||
if self.current_gate_weights is not None:
|
||||
# Use manually set priors (e.g. fuzzy text-based)
|
||||
# Expand to match batch size
|
||||
batch_size = x.shape[0]
|
||||
weights = (
|
||||
self.current_gate_weights.to(x.device)
|
||||
.unsqueeze(0)
|
||||
.expand(batch_size, -1)
|
||||
)
|
||||
else:
|
||||
# Compute dynamically per token via learnable gate
|
||||
# We pool over sequence length or route per token
|
||||
# Let's route token-wise: gate_logits has shape [batch, seq_len, num_adapters]
|
||||
gate_logits = self.gate(x)
|
||||
weights = torch.softmax(gate_logits, dim=-1)
|
||||
|
||||
# Compute combined low-rank contribution
|
||||
# Y_lora = sum_i g_i * (x @ A_i.t()) @ B_i.t()
|
||||
adapter_output = torch.zeros_like(output_tensor)
|
||||
|
||||
for i, adapter in enumerate(self.adapter_library):
|
||||
if layer_name in adapter.layers:
|
||||
lm = adapter.layers[layer_name]
|
||||
# Ensure tensors are on the correct device
|
||||
lora_A = lm.lora_A.to(x.device)
|
||||
lora_B = lm.lora_B.to(x.device)
|
||||
|
||||
# Dynamic scaling: gate_weight for this adapter
|
||||
# weights has shape [batch, num_adapters] or [batch, seq_len, num_adapters]
|
||||
if len(weights.shape) == 3:
|
||||
# Token-level routing: shape [batch, seq_len, 1]
|
||||
g = weights[..., i : i + 1]
|
||||
else:
|
||||
# Batch-level routing: shape [batch, 1, 1]
|
||||
g = weights[:, i].view(-1, 1, 1)
|
||||
|
||||
# Low-rank projection
|
||||
x_proj = torch.matmul(x, lora_A.t())
|
||||
y_proj = torch.matmul(x_proj, lora_B.t())
|
||||
|
||||
# Accumulate scaled delta
|
||||
adapter_output += g * y_proj
|
||||
|
||||
return output_tensor + adapter_output
|
||||
|
||||
return hook_fn
|
||||
Reference in New Issue
Block a user