feat: expert manifold alignment, MoE router, FCES controller metadata bindings
This commit is contained in:
217
python/expert_manifold_alignment.py
Normal file
217
python/expert_manifold_alignment.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Expert Manifold Alignment for Parasitic QLoRA.
|
||||
|
||||
Aligns and profiles extracted LoRA adapters with optimizer update trajectories
|
||||
and functional layer localization to classify them into legal exam domains:
|
||||
- Statute Recall (embed, early layers)
|
||||
- Logical Reasoning (attn, mid-late layers)
|
||||
- Style / Gutachtenstil (mlp, late layers)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from parasitic_qlora import ExpertAdapter, LoRAMatrices
|
||||
|
||||
|
||||
class ExpertManifoldAligner:
|
||||
"""Aligns and profiles extracted LoRA adapters with the FCES expert manifold.
|
||||
|
||||
Uses layer-wise functional localization (depth and layer type) to categorize
|
||||
adapters into legal reasoning domains (Logic, Statute, Style). Also computes
|
||||
step-wise trajectory alignment with optimizer updates.
|
||||
"""
|
||||
|
||||
def __init__(self, model: nn.Module) -> None:
|
||||
self.total_layers = self._detect_total_layers(model)
|
||||
self.prev_weights: Dict[str, torch.Tensor] = {}
|
||||
# Prime the tracker with current model weights
|
||||
self.track_step(model)
|
||||
|
||||
def _detect_total_layers(self, model: nn.Module) -> int:
|
||||
"""Detects the number of transformer layers dynamically from parameter names."""
|
||||
max_layer_idx = 0
|
||||
found = False
|
||||
for name in model.state_dict().keys():
|
||||
match = re.search(r"layers?\.(\d+)\.", name.lower())
|
||||
if match:
|
||||
max_layer_idx = max(max_layer_idx, int(match.group(1)))
|
||||
found = True
|
||||
return max_layer_idx + 1 if found else 12
|
||||
|
||||
def track_step(self, model: nn.Module) -> Dict[str, torch.Tensor]:
|
||||
"""Calculates step update δW_t = W_t - W_{t-1} for tracked linear layers."""
|
||||
step_updates: Dict[str, torch.Tensor] = {}
|
||||
for name, param in model.named_parameters():
|
||||
if not param.requires_grad or len(param.shape) < 2:
|
||||
continue
|
||||
if name in self.prev_weights:
|
||||
# δW = W_t - W_{t-1}
|
||||
step_updates[name] = param.data - self.prev_weights[name]
|
||||
self.prev_weights[name] = param.data.clone().detach()
|
||||
return step_updates
|
||||
|
||||
def compute_subspace_alignment(
|
||||
self, lora_matrices: LoRAMatrices, step_update: torch.Tensor
|
||||
) -> float:
|
||||
"""Computes cosine similarity between step update and LoRA subspace.
|
||||
|
||||
Uses O(r d k) trace formulation: trace(B^T * X * A^T) / (||BA||_F * ||X||_F)
|
||||
to avoid large matrix allocations.
|
||||
"""
|
||||
B = lora_matrices.lora_B
|
||||
A = lora_matrices.lora_A
|
||||
X = step_update.to(B.device)
|
||||
|
||||
# Ensure matching shapes
|
||||
if B.shape[0] != X.shape[0] or A.shape[1] != X.shape[1]:
|
||||
return 0.0
|
||||
|
||||
# 1. Compute ||BA||_F using trace((B^T B)(A A^T))
|
||||
BtB = torch.matmul(B.t(), B)
|
||||
AAt = torch.matmul(A, A.t())
|
||||
norm_BA_sq = torch.sum(BtB * AAt)
|
||||
norm_BA: float = float(torch.sqrt(norm_BA_sq.clamp(min=1e-10)).item())
|
||||
|
||||
# 2. Compute ||X||_F
|
||||
norm_X: float = float(torch.norm(X, p="fro").item())
|
||||
if norm_X < 1e-10 or norm_BA < 1e-10:
|
||||
return 0.0
|
||||
|
||||
# 3. Compute trace(B^T X A^T) = sum( (B^T X) * A )
|
||||
BtX = torch.matmul(B.t(), X)
|
||||
dot_product: float = float(torch.sum(BtX * A).item())
|
||||
|
||||
return dot_product / (norm_BA * norm_X)
|
||||
|
||||
def analyze_layer_profile(self, name: str) -> Tuple[str, str]:
|
||||
"""Categorizes a layer by its type (embed, attn, mlp) and depth (early, mid, late)."""
|
||||
nl = name.lower()
|
||||
|
||||
# Determine layer type
|
||||
if "embed" in nl or "wte" in nl:
|
||||
l_type = "embed"
|
||||
elif any(
|
||||
x in nl
|
||||
for x in [
|
||||
"attn",
|
||||
"query",
|
||||
"key",
|
||||
"value",
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"out_proj",
|
||||
]
|
||||
):
|
||||
l_type = "attn"
|
||||
elif any(
|
||||
x in nl
|
||||
for x in [
|
||||
"mlp",
|
||||
"dense_h_to_4h",
|
||||
"dense_4h_to_h",
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
"down_proj",
|
||||
]
|
||||
):
|
||||
l_type = "mlp"
|
||||
else:
|
||||
l_type = "other"
|
||||
|
||||
# Determine layer depth
|
||||
match = re.search(r"layers?\.(\d+)\.", nl)
|
||||
if match:
|
||||
idx = int(match.group(1))
|
||||
early_bound = self.total_layers // 3
|
||||
late_bound = 2 * (self.total_layers // 3)
|
||||
if idx < early_bound:
|
||||
depth = "early"
|
||||
elif idx < late_bound:
|
||||
depth = "mid"
|
||||
else:
|
||||
depth = "late"
|
||||
else:
|
||||
# Fallback for non-indexed layers
|
||||
if "embed" in nl:
|
||||
depth = "early"
|
||||
elif "lm_head" in nl or "head" in nl:
|
||||
depth = "late"
|
||||
else:
|
||||
depth = "mid"
|
||||
|
||||
return l_type, depth
|
||||
|
||||
def profile_adapter(self, adapter: ExpertAdapter) -> Dict[str, float]:
|
||||
"""Calculates domain alignment scores for the adapter.
|
||||
|
||||
Returns similarity coefficients for:
|
||||
- statute_recall (embed/early layers)
|
||||
- logic_reasoning (attn/mid-late layers)
|
||||
- style_gutachtenstil (mlp/late layers)
|
||||
"""
|
||||
energies = {"statute": 0.0, "logic": 0.0, "style": 0.0}
|
||||
total_energy = 0.0
|
||||
|
||||
for name, lm in adapter.layers.items():
|
||||
l_type, depth = self.analyze_layer_profile(name)
|
||||
# Energy of this layer's delta is sum of squares of singular values (Frobenius norm squared)
|
||||
layer_energy = torch.sum(lm.singular_values**2).item()
|
||||
total_energy += layer_energy
|
||||
|
||||
# Map layers to domain profiles
|
||||
if depth == "early" or l_type == "embed":
|
||||
energies["statute"] += layer_energy * 1.5
|
||||
energies["logic"] += layer_energy * 0.5
|
||||
elif depth == "mid":
|
||||
if l_type == "attn":
|
||||
energies["logic"] += layer_energy * 1.5
|
||||
else: # mlp
|
||||
energies["style"] += layer_energy * 1.2
|
||||
energies["logic"] += layer_energy * 0.8
|
||||
else: # late
|
||||
if l_type == "mlp":
|
||||
energies["style"] += layer_energy * 1.8
|
||||
else:
|
||||
energies["logic"] += layer_energy * 1.2
|
||||
energies["style"] += layer_energy * 0.5
|
||||
|
||||
if total_energy < 1e-10:
|
||||
return {
|
||||
"statute_recall": 0.33,
|
||||
"logic_reasoning": 0.33,
|
||||
"style_gutachtenstil": 0.33,
|
||||
}
|
||||
|
||||
sum_scores = sum(energies.values())
|
||||
if sum_scores > 0:
|
||||
scores = {k: v / sum_scores for k, v in energies.items()}
|
||||
else:
|
||||
scores = {"statute": 0.33, "logic": 0.33, "style": 0.33}
|
||||
|
||||
return {
|
||||
"statute_recall": scores["statute"],
|
||||
"logic_reasoning": scores["logic"],
|
||||
"style_gutachtenstil": scores["style"],
|
||||
}
|
||||
|
||||
def tag_adapter(self, adapter: ExpertAdapter) -> List[str]:
|
||||
"""Profiles the adapter and adds the best domain tags to its domain_tags."""
|
||||
scores = self.profile_adapter(adapter)
|
||||
|
||||
# Find dominant domains (above 35% concentration)
|
||||
tags = []
|
||||
for domain, score in scores.items():
|
||||
if score >= 0.35:
|
||||
tags.append(domain)
|
||||
|
||||
if not tags:
|
||||
best_domain = max(scores, key=lambda k: scores[k])
|
||||
tags.append(best_domain)
|
||||
|
||||
adapter.domain_tags = tags
|
||||
return tags
|
||||
Reference in New Issue
Block a user