feat: expert manifold alignment, MoE router, FCES controller metadata bindings

This commit is contained in:
AI-anonymous
2026-05-20 16:07:36 +02:00
parent 7e2e86d98c
commit 663e2fb71d
9 changed files with 727 additions and 19 deletions

View File

@@ -0,0 +1,181 @@
"""Expert Adapter Router and Mixture-of-Experts (MoE) for QLoRA.
Hooks into a base model to dynamically route hidden states through multiple
extracted expert adapters using fuzzy text-based domain priors or dynamic
learnable token-level gates.
"""
from __future__ import annotations
import re
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn as nn
from parasitic_qlora import ExpertAdapter
class LearnableGate(nn.Module): # type: ignore[misc]
"""A lightweight learnable MLP gating network for routing."""
def __init__(self, in_features: int, num_adapters: int) -> None:
super().__init__()
self.gate = nn.Sequential(
nn.Linear(in_features, in_features // 2),
nn.ReLU(),
nn.Linear(in_features // 2, num_adapters),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Input: [..., in_features]
# Output: [..., num_adapters] (unnormalized scores)
return self.gate(x)
class ExpertAdapterRouter:
"""Manages dynamic MoE-style routing over a library of LoRA adapters."""
def __init__(
self,
base_model: nn.Module,
adapter_library: List[ExpertAdapter],
in_features: int = 768, # Match model hidden dim (e.g. Pythia-70m)
) -> None:
self.base_model = base_model
self.adapter_library = adapter_library
self.num_adapters = len(adapter_library)
self.hooks: List[torch.utils.hooks.RemovableHandle] = []
# Learnable gating network
self.gate = LearnableGate(in_features, self.num_adapters).to(
next(base_model.parameters()).device
)
# Active weights for current forward pass (batch size × num_adapters)
self.current_gate_weights: Optional[torch.Tensor] = None
def compute_fuzzy_priors(self, text: str) -> torch.Tensor:
"""Determines static routing priors based on keyword matching in input text."""
priors = torch.zeros(self.num_adapters)
# Heuristics for legal exam domains
text_lower = text.lower()
has_statute = bool(re.search(r"§\s*\d+", text_lower))
has_logic = (
"firt" in text_lower or "reasoning" in text_lower or "analyze" in text_lower
)
has_style = (
"gutachtenstil" in text_lower
or "gutachten" in text_lower
or "klausur" in text_lower
)
for idx, adapter in enumerate(self.adapter_library):
score = 0.1 # base prior
for tag in adapter.domain_tags:
if tag == "statute_recall" and has_statute:
score += 0.8
elif tag == "logic_reasoning" and has_logic:
score += 0.8
elif tag == "style_gutachtenstil" and has_style:
score += 0.8
priors[idx] = score
# Softmax normalize
return torch.softmax(priors, dim=0)
def set_active_routing(self, fuzzy_priors: Optional[torch.Tensor] = None) -> None:
"""Explicitly sets the routing weights for the next forward pass."""
self.current_gate_weights = fuzzy_priors
def register_hooks(self) -> None:
"""Attaches forward hooks to linear layers present in the adapter library."""
self.unregister_hooks()
# Find all layers in the base model that have adapters
adapter_layers: set[str] = set()
for adapter in self.adapter_library:
adapter_layers.update(adapter.layers.keys())
# Bind hooks dynamically
for name, module in self.base_model.named_modules():
# Check if this specific module has an adapter
# We match using suffix to support model wrapping/prefixes
matching_adapter_name = None
for layer_name in adapter_layers:
clean_layer_name = layer_name.replace(".weight", "").replace(
".bias", ""
)
if name.endswith(clean_layer_name) or name == clean_layer_name:
matching_adapter_name = layer_name
break
if matching_adapter_name and isinstance(module, nn.Linear):
hook = module.register_forward_hook(
self._make_hook_fn(matching_adapter_name)
)
self.hooks.append(hook)
def unregister_hooks(self) -> None:
"""Removes all registered hooks from the base model."""
for hook in self.hooks:
hook.remove()
self.hooks.clear()
def _make_hook_fn(self, layer_name: str) -> Callable[..., torch.Tensor]:
"""Creates the hook function for a specific linear layer."""
def hook_fn(
module: nn.Module,
input_tensor: Tuple[torch.Tensor, ...],
output_tensor: torch.Tensor,
) -> torch.Tensor:
x = input_tensor[0] # [batch, seq_len, in_features]
# Calculate gate weights
if self.current_gate_weights is not None:
# Use manually set priors (e.g. fuzzy text-based)
# Expand to match batch size
batch_size = x.shape[0]
weights = (
self.current_gate_weights.to(x.device)
.unsqueeze(0)
.expand(batch_size, -1)
)
else:
# Compute dynamically per token via learnable gate
# We pool over sequence length or route per token
# Let's route token-wise: gate_logits has shape [batch, seq_len, num_adapters]
gate_logits = self.gate(x)
weights = torch.softmax(gate_logits, dim=-1)
# Compute combined low-rank contribution
# Y_lora = sum_i g_i * (x @ A_i.t()) @ B_i.t()
adapter_output = torch.zeros_like(output_tensor)
for i, adapter in enumerate(self.adapter_library):
if layer_name in adapter.layers:
lm = adapter.layers[layer_name]
# Ensure tensors are on the correct device
lora_A = lm.lora_A.to(x.device)
lora_B = lm.lora_B.to(x.device)
# Dynamic scaling: gate_weight for this adapter
# weights has shape [batch, num_adapters] or [batch, seq_len, num_adapters]
if len(weights.shape) == 3:
# Token-level routing: shape [batch, seq_len, 1]
g = weights[..., i : i + 1]
else:
# Batch-level routing: shape [batch, 1, 1]
g = weights[:, i].view(-1, 1, 1)
# Low-rank projection
x_proj = torch.matmul(x, lora_A.t())
y_proj = torch.matmul(x_proj, lora_B.t())
# Accumulate scaled delta
adapter_output += g * y_proj
return output_tensor + adapter_output
return hook_fn