feat: expert manifold alignment, MoE router, FCES controller metadata bindings

This commit is contained in:
AI-anonymous
2026-05-20 16:07:36 +02:00
parent 7e2e86d98c
commit 663e2fb71d
9 changed files with 727 additions and 19 deletions

View File

@@ -0,0 +1,181 @@
"""Expert Adapter Router and Mixture-of-Experts (MoE) for QLoRA.
Hooks into a base model to dynamically route hidden states through multiple
extracted expert adapters using fuzzy text-based domain priors or dynamic
learnable token-level gates.
"""
from __future__ import annotations
import re
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn as nn
from parasitic_qlora import ExpertAdapter
class LearnableGate(nn.Module):  # type: ignore[misc]
    """Lightweight two-layer MLP that scores every adapter for a feature vector.

    The network is ``Linear -> ReLU -> Linear`` with a hidden width of half the
    input width. The output is left unnormalized; callers apply softmax.

    Args:
        in_features: Size of the last dimension of the tensors passed to
            :meth:`forward`.
        num_adapters: Number of scores produced per input vector.
    """

    def __init__(self, in_features: int, num_adapters: int) -> None:
        super().__init__()
        # Clamp to 1 so that in_features == 1 does not create a zero-width
        # hidden layer, which would make the scores independent of the input.
        hidden_features = max(1, in_features // 2)
        self.gate = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_adapters),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``[..., in_features]`` to unnormalized ``[..., num_adapters]`` scores."""
        return self.gate(x)
class ExpertAdapterRouter:
    """Manages dynamic MoE-style routing over a library of LoRA adapters.

    Forward hooks are attached to the ``nn.Linear`` modules of ``base_model``
    that have an entry in at least one adapter's ``layers`` mapping. Each hook
    adds a gate-weighted sum of the adapters' low-rank updates to the layer
    output. Gate weights are either set explicitly through
    :meth:`set_active_routing` or computed per token by a learnable gate.

    Args:
        base_model: Model whose linear layers are hooked.
        adapter_library: Adapters to mix. Each is expected to expose ``layers``
            (mapping of layer key to an object with ``lora_A`` / ``lora_B``)
            and ``domain_tags``.
        in_features: Input width of the learnable gate.
    """

    # Parameter-name suffixes stripped to turn an adapter key into a module path.
    _PARAM_SUFFIXES = (".weight", ".bias")

    def __init__(
        self,
        base_model: nn.Module,
        adapter_library: List[ExpertAdapter],
        in_features: int = 768,  # Match model hidden dim (e.g. Pythia-70m)
    ) -> None:
        self.base_model = base_model
        self.adapter_library = adapter_library
        self.num_adapters = len(adapter_library)
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []

        # Place the gate next to the model; fall back to CPU when the model has
        # no parameters instead of raising StopIteration.
        first_param = next(base_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        self.gate = LearnableGate(in_features, self.num_adapters).to(device)

        # Explicit routing weights for the next forward pass. Shape is either
        # [num_adapters] (shared by all samples) or [batch, num_adapters].
        # None means "use the learnable gate".
        self.current_gate_weights: Optional[torch.Tensor] = None

    def compute_fuzzy_priors(self, text: str) -> torch.Tensor:
        """Derive static routing priors from keyword matches in ``text``.

        Every adapter starts from a base score of 0.1 and gains 0.8 for each of
        its ``domain_tags`` whose keyword heuristic fires on the lower-cased
        text. The scores are softmax-normalized.

        Returns:
            A 1-D tensor of length ``num_adapters`` that sums to 1.
        """
        text_lower = text.lower()

        # Heuristics for legal exam domains.
        # NOTE(review): "firt" looks like a possible typo (e.g. for "first");
        # kept verbatim because the intended keyword cannot be confirmed here.
        tag_triggered = {
            "statute_recall": bool(re.search(r"§\s*\d+", text_lower)),
            "logic_reasoning": any(
                keyword in text_lower for keyword in ("firt", "reasoning", "analyze")
            ),
            "style_gutachtenstil": any(
                keyword in text_lower
                for keyword in ("gutachtenstil", "gutachten", "klausur")
            ),
        }

        priors = torch.zeros(self.num_adapters)
        for idx, adapter in enumerate(self.adapter_library):
            score = 0.1  # base prior
            for tag in adapter.domain_tags:
                if tag_triggered.get(tag, False):
                    score += 0.8
            priors[idx] = score

        return torch.softmax(priors, dim=0)

    def set_active_routing(self, fuzzy_priors: Optional[torch.Tensor] = None) -> None:
        """Set the explicit routing weights used by subsequent forward passes.

        Passing ``None`` switches the hooks back to the learnable gate.
        """
        self.current_gate_weights = fuzzy_priors

    @classmethod
    def _module_path(cls, layer_key: str) -> str:
        """Strip a trailing ``.weight`` / ``.bias`` from an adapter layer key."""
        for suffix in cls._PARAM_SUFFIXES:
            if layer_key.endswith(suffix):
                return layer_key[: -len(suffix)]
        return layer_key

    def register_hooks(self) -> None:
        """Attach forward hooks to the linear layers named in the adapter library.

        Any previously registered hooks are removed first. A module matches an
        adapter key when its dotted name equals the key's module path or ends
        with it on a path-component boundary, which tolerates wrapper prefixes
        without letting ``xproj`` match ``proj``.
        """
        self.unregister_hooks()

        # Sorted so that the key chosen for a module is deterministic when
        # several keys could match it.
        adapter_keys = sorted(
            {key for adapter in self.adapter_library for key in adapter.layers}
        )
        targets = [(self._module_path(key), key) for key in adapter_keys]

        for name, module in self.base_model.named_modules():
            if not isinstance(module, nn.Linear):
                continue
            for path, key in targets:
                # An empty path would match every module; skip it.
                if path and (name == path or name.endswith("." + path)):
                    self.hooks.append(
                        module.register_forward_hook(self._make_hook_fn(key))
                    )
                    break

    def unregister_hooks(self) -> None:
        """Remove all hooks this router registered on the base model."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

    def _make_hook_fn(self, layer_name: str) -> Callable[..., torch.Tensor]:
        """Build the forward hook for the linear layer stored under ``layer_name``.

        The returned hook computes
        ``output + sum_i g_i * (x @ A_i.T) @ B_i.T`` over every adapter that
        has an entry for ``layer_name``.
        """

        def hook_fn(
            module: nn.Module,
            input_tensor: Tuple[torch.Tensor, ...],
            output_tensor: torch.Tensor,
        ) -> torch.Tensor:
            x = input_tensor[0]  # [..., in_features]

            explicit = self.current_gate_weights
            if explicit is not None:
                # Manually set priors (e.g. fuzzy text-based).
                weights = explicit.to(device=x.device, dtype=output_tensor.dtype)
                token_level = False
            else:
                # Token-wise routing: one softmax over adapters per input vector.
                # NOTE(review): the gate has a single fixed input width, so a
                # hooked layer with a different in_features will fail here.
                weights = torch.softmax(self.gate(x), dim=-1).to(output_tensor.dtype)
                token_level = True

            delta = torch.zeros_like(output_tensor)
            for i, adapter in enumerate(self.adapter_library):
                if layer_name not in adapter.layers:
                    continue
                lm = adapter.layers[layer_name]

                # Match both device and dtype of the activations so adapters
                # stored in another precision do not break the matmul.
                lora_A = lm.lora_A.to(device=x.device, dtype=x.dtype)
                lora_B = lm.lora_B.to(device=x.device, dtype=x.dtype)

                if token_level:
                    # Same leading dims as x, trailing dim of 1.
                    g = weights[..., i : i + 1]
                elif weights.dim() == 1:
                    # One scalar per adapter; broadcasts over any input rank.
                    g = weights[i]
                else:
                    # Per-sample weights: [batch] -> [batch, 1, ..., 1].
                    g = weights[:, i].reshape(-1, *([1] * (x.dim() - 1)))

                # Low-rank projection, then accumulate the scaled delta.
                y_proj = torch.matmul(torch.matmul(x, lora_A.t()), lora_B.t())
                delta = delta + g * y_proj

            return output_tensor + delta

        return hook_fn

View File

@@ -0,0 +1,217 @@
"""Expert Manifold Alignment for Parasitic QLoRA.
Aligns and profiles extracted LoRA adapters with optimizer update trajectories
and functional layer localization to classify them into legal exam domains:
- Statute Recall (embed, early layers)
- Logical Reasoning (attn, mid-late layers)
- Style / Gutachtenstil (mlp, late layers)
"""
from __future__ import annotations
import re
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from parasitic_qlora import ExpertAdapter, LoRAMatrices
class ExpertManifoldAligner:
    """Aligns and profiles extracted LoRA adapters with the FCES expert manifold.

    Uses layer-wise functional localization (layer type and depth) to
    categorize adapters into legal reasoning domains (Logic, Statute, Style).
    Also tracks per-step weight updates and measures their alignment with a
    LoRA subspace.

    Args:
        model: Model whose parameter names determine the layer count and whose
            current weights prime the step tracker.
    """

    # Transformer block index inside a dotted parameter name. "layer(s)" and
    # "block(s)" match anywhere; GPT-2's bare "h" only as a whole path component.
    _LAYER_INDEX_RE = re.compile(r"(?:layers?|blocks?|(?:^|\.)h)\.(\d+)\.")

    _EMBED_KEYS = ("embed", "wte")
    _ATTN_KEYS = (
        "attn",
        "attention",  # e.g. GPT-NeoX "attention.dense"
        "query",
        "key",
        "value",
        "q_proj",
        "k_proj",
        "v_proj",
        "out_proj",
    )
    _MLP_KEYS = (
        "mlp",
        "dense_h_to_4h",
        "dense_4h_to_h",
        "gate_proj",
        "up_proj",
        "down_proj",
    )

    def __init__(self, model: nn.Module) -> None:
        self.total_layers = self._detect_total_layers(model)
        self.prev_weights: Dict[str, torch.Tensor] = {}
        # Prime the tracker with the current model weights.
        self.track_step(model)

    def _detect_total_layers(self, model: nn.Module) -> int:
        """Infer the transformer layer count from state-dict key names.

        Returns the largest layer index found plus one, or 12 when no key
        carries a recognizable layer index.
        """
        indices = [
            int(match.group(1))
            for match in (
                self._LAYER_INDEX_RE.search(name.lower()) for name in model.state_dict()
            )
            if match
        ]
        return max(indices) + 1 if indices else 12

    def track_step(self, model: nn.Module) -> Dict[str, torch.Tensor]:
        """Return ``W_t - W_{t-1}`` for every tracked parameter and store ``W_t``.

        Only trainable parameters with at least two dimensions are tracked.
        The first call for a parameter records its weights and yields no entry.
        """
        step_updates: Dict[str, torch.Tensor] = {}
        with torch.no_grad():
            for name, param in model.named_parameters():
                if not param.requires_grad or param.dim() < 2:
                    continue
                current = param.detach()
                previous = self.prev_weights.get(name)
                if previous is not None:
                    # Tolerate the model having been moved to another device.
                    step_updates[name] = current - previous.to(current.device)
                self.prev_weights[name] = current.clone()
        return step_updates

    def compute_subspace_alignment(
        self, lora_matrices: LoRAMatrices, step_update: torch.Tensor
    ) -> float:
        """Cosine similarity between a step update ``X`` and the LoRA product ``BA``.

        Computed as ``trace(B^T X A^T) / (||BA||_F * ||X||_F)`` without ever
        materializing ``BA``.

        Returns:
            A value in ``[-1, 1]``; ``0.0`` when shapes are incompatible or
            either norm is (near) zero.
        """
        B = lora_matrices.lora_B.detach()
        A = lora_matrices.lora_A.detach()
        X = step_update.detach()

        if B.dim() != 2 or A.dim() != 2 or X.dim() != 2:
            return 0.0
        if (
            B.shape[1] != A.shape[0]
            or B.shape[0] != X.shape[0]
            or A.shape[1] != X.shape[1]
        ):
            return 0.0

        # Work in one dtype so mixed-precision inputs do not break matmul;
        # never accumulate in half precision.
        dtype = torch.promote_types(torch.promote_types(B.dtype, A.dtype), X.dtype)
        if not dtype.is_floating_point or dtype in (torch.float16, torch.bfloat16):
            dtype = torch.float32
        B = B.to(dtype)
        A = A.to(device=B.device, dtype=dtype)
        X = X.to(device=B.device, dtype=dtype)

        # ||BA||_F^2 = trace((B^T B)(A A^T)) = sum((B^T B) * (A A^T)).
        gram_B = torch.matmul(B.t(), B)
        gram_A = torch.matmul(A, A.t())
        norm_BA = float(torch.sqrt(torch.sum(gram_B * gram_A).clamp(min=0.0)).item())
        norm_X = float(torch.norm(X, p="fro").item())
        if norm_X < 1e-10 or norm_BA < 1e-10:
            return 0.0

        # <BA, X>_F = trace(B^T X A^T) = sum((B^T X) * A).
        dot_product = float(torch.sum(torch.matmul(B.t(), X) * A).item())
        # Clamp away rounding error so the result is a valid cosine.
        return max(-1.0, min(1.0, dot_product / (norm_BA * norm_X)))

    def analyze_layer_profile(self, name: str) -> Tuple[str, str]:
        """Classify a layer name by type and depth.

        Returns:
            ``(l_type, depth)`` where ``l_type`` is one of ``embed``, ``attn``,
            ``mlp``, ``other`` and ``depth`` is one of ``early``, ``mid``,
            ``late``.
        """
        nl = name.lower()

        # Layer type: first matching keyword family wins.
        if any(key in nl for key in self._EMBED_KEYS):
            l_type = "embed"
        elif any(key in nl for key in self._ATTN_KEYS):
            l_type = "attn"
        elif any(key in nl for key in self._MLP_KEYS):
            l_type = "mlp"
        else:
            l_type = "other"

        # Layer depth.
        match = self._LAYER_INDEX_RE.search(nl)
        if match:
            idx = int(match.group(1))
            # Compare idx/total against 1/3 and 2/3 without integer division,
            # so layer counts not divisible by three still split evenly.
            if idx * 3 < self.total_layers:
                depth = "early"
            elif idx * 3 < 2 * self.total_layers:
                depth = "mid"
            else:
                depth = "late"
        elif "embed" in nl:
            # Fallbacks for names without a layer index.
            depth = "early"
        elif "head" in nl:  # also covers "lm_head"
            depth = "late"
        else:
            depth = "mid"

        return l_type, depth

    def profile_adapter(self, adapter: ExpertAdapter) -> Dict[str, float]:
        """Compute normalized domain scores for an adapter.

        Each layer's energy (sum of squared ``singular_values``) is distributed
        over the three domains with weights chosen by layer type and depth.

        Returns:
            Scores for ``statute_recall``, ``logic_reasoning`` and
            ``style_gutachtenstil``. They sum to 1, except for an adapter with
            (near) zero total energy, which gets 0.33 for each.
        """
        energies = {"statute": 0.0, "logic": 0.0, "style": 0.0}
        total_energy = 0.0

        for name, lm in adapter.layers.items():
            l_type, depth = self.analyze_layer_profile(name)
            layer_energy = torch.sum(lm.singular_values**2).item()
            total_energy += layer_energy

            if depth == "early" or l_type == "embed":
                energies["statute"] += layer_energy * 1.5
                energies["logic"] += layer_energy * 0.5
            elif depth == "mid":
                if l_type == "attn":
                    energies["logic"] += layer_energy * 1.5
                else:  # mlp and anything unclassified
                    energies["style"] += layer_energy * 1.2
                    energies["logic"] += layer_energy * 0.8
            elif l_type == "mlp":  # late mlp
                energies["style"] += layer_energy * 1.8
            else:  # late, non-mlp
                energies["logic"] += layer_energy * 1.2
                energies["style"] += layer_energy * 0.5

        if total_energy < 1e-10:
            return {
                "statute_recall": 0.33,
                "logic_reasoning": 0.33,
                "style_gutachtenstil": 0.33,
            }

        # Every branch above adds a positive share of the layer energy, so the
        # sum is strictly positive whenever total_energy is.
        sum_scores = sum(energies.values())
        return {
            "statute_recall": energies["statute"] / sum_scores,
            "logic_reasoning": energies["logic"] / sum_scores,
            "style_gutachtenstil": energies["style"] / sum_scores,
        }

    def tag_adapter(self, adapter: ExpertAdapter) -> List[str]:
        """Profile the adapter and replace its ``domain_tags`` with the result.

        Every domain scoring at least 0.35 becomes a tag; when none does, the
        single highest-scoring domain is used.

        Returns:
            The list of tags that was assigned to ``adapter.domain_tags``.
        """
        scores = self.profile_adapter(adapter)

        tags = [domain for domain, score in scores.items() if score >= 0.35]
        if not tags:
            tags.append(max(scores, key=lambda k: scores[k]))

        adapter.domain_tags = tags
        return tags

View File

@@ -40,6 +40,10 @@ PYBIND11_MODULE(fces_native, m) {
.def("restore_from_ram", &fces::FCESOptimizer::restore_from_ram)
.def("step_count", &fces::FCESOptimizer::step_count)
.def("calculate_sparsity", &fces::FCESOptimizer::calculate_sparsity)
.def("get_active_controller_id",
&fces::FCESOptimizer::get_active_controller_id)
.def("get_active_controller_fitness",
&fces::FCESOptimizer::get_active_controller_fitness)
.def("zero_grad", [](fces::FCESOptimizer &self) {
for (auto &group : self.param_groups()) {
for (auto &p : group.params()) {