# File: FCES-native/python/expert_manifold_alignment.py (218 lines, 7.5 KiB, Python)

"""Expert Manifold Alignment for Parasitic QLoRA.
Aligns and profiles extracted LoRA adapters with optimizer update trajectories
and functional layer localization to classify them into legal exam domains:
- Statute Recall (embed, early layers)
- Logical Reasoning (attn, mid-late layers)
- Style / Gutachtenstil (mlp, late layers)
"""
from __future__ import annotations
import re
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from parasitic_qlora import ExpertAdapter, LoRAMatrices
class ExpertManifoldAligner:
    """Profile extracted LoRA adapters against the FCES expert manifold.

    Two independent signals are provided:

    * Functional localization: each adapter layer is classified by type
      (``embed`` / ``attn`` / ``mlp`` / ``other``) and by depth
      (``early`` / ``mid`` / ``late``), and its energy is distributed over
      the three domains ``statute_recall``, ``logic_reasoning`` and
      ``style_gutachtenstil``.
    * Trajectory alignment: per-step weight deltas are tracked and can be
      compared with a LoRA update ``B @ A`` through a cosine similarity.
    """

    # Block index inside a (lower-cased) parameter name, e.g. "layers.7.".
    _LAYER_INDEX_RE = re.compile(r"layers?\.(\d+)\.")
    # Depth assumed when no parameter name carries a block index.
    _DEFAULT_TOTAL_LAYERS = 12
    # Norms / energies below this value are treated as zero.
    _EPS = 1e-10
    # Score reported for every domain when no usable energy is available.
    _UNIFORM_SCORE = 0.33
    # Minimum share of the total score for a domain to become a tag.
    _TAG_THRESHOLD = 0.35
    # Substrings identifying attention projections.
    _ATTN_KEYWORDS = (
        "attn",
        "query",
        "key",
        "value",
        "q_proj",
        "k_proj",
        "v_proj",
        "out_proj",
    )
    # Substrings identifying feed-forward (MLP) projections.
    _MLP_KEYWORDS = (
        "mlp",
        "dense_h_to_4h",
        "dense_4h_to_h",
        "gate_proj",
        "up_proj",
        "down_proj",
    )

    def __init__(self, model: nn.Module) -> None:
        """Detect the model depth and snapshot its current weights.

        Args:
            model: Module whose parameter names are scanned for block
                indices and whose trainable matrices are tracked.
        """
        self.total_layers = self._detect_total_layers(model)
        self.prev_weights: Dict[str, torch.Tensor] = {}
        # Prime the tracker so the first real call to track_step() yields deltas.
        self.track_step(model)

    def _detect_total_layers(self, model: nn.Module) -> int:
        """Return the number of indexed blocks found in the state-dict keys.

        The count is the highest ``layer(s).<idx>.`` index plus one. When no
        key carries such an index, ``_DEFAULT_TOTAL_LAYERS`` (12) is returned.
        """
        matches = (
            self._LAYER_INDEX_RE.search(key.lower()) for key in model.state_dict()
        )
        indices = [int(match.group(1)) for match in matches if match]
        return max(indices) + 1 if indices else self._DEFAULT_TOTAL_LAYERS

    def track_step(self, model: nn.Module) -> Dict[str, torch.Tensor]:
        """Return ``W_t - W_{t-1}`` for every tracked parameter and re-snapshot.

        Only parameters that require gradients and have at least two
        dimensions are tracked. A parameter seen for the first time, or one
        whose shape differs from its previous snapshot, produces no delta for
        this call; its snapshot is still refreshed.

        Args:
            model: Module whose current parameter values are read.

        Returns:
            Mapping from parameter name to its update since the last call.
        """
        step_updates: Dict[str, torch.Tensor] = {}
        for name, param in model.named_parameters():
            if not param.requires_grad or param.dim() < 2:
                continue
            current = param.detach()
            previous = self.prev_weights.get(name)
            # A resized or relocated parameter must not crash the tracker.
            if previous is not None and previous.shape == current.shape:
                step_updates[name] = current - previous.to(current.device)
            self.prev_weights[name] = current.clone()
        return step_updates

    def compute_subspace_alignment(
        self, lora_matrices: "LoRAMatrices", step_update: torch.Tensor
    ) -> float:
        """Return the cosine similarity between ``B @ A`` and ``step_update``.

        The product ``B @ A`` is never materialized:

        * ``<BA, X>_F = trace(B^T X A^T) = sum((B^T X) * A)``
        * ``||BA||_F^2 = trace((B^T B)(A A^T)) = sum((B^T B) * (A A^T))``

        Args:
            lora_matrices: Object exposing ``lora_B`` (d x r) and ``lora_A``
                (r x k) tensors.
            step_update: Weight delta X, expected to be (d x k).

        Returns:
            Cosine similarity in ``[-1, 1]``. ``0.0`` is returned when any
            input is not 2-D, when the shapes are incompatible, or when either
            norm is below ``_EPS``.
        """
        B = lora_matrices.lora_B
        A = lora_matrices.lora_A

        if B.dim() != 2 or A.dim() != 2 or step_update.dim() != 2:
            return 0.0
        if (
            B.shape[1] != A.shape[0]
            or B.shape[0] != step_update.shape[0]
            or A.shape[1] != step_update.shape[1]
        ):
            return 0.0

        # matmul needs one common dtype; half precision is widened to float32
        # so that small squared norms do not underflow.
        dtype = torch.promote_types(
            torch.promote_types(B.dtype, A.dtype), step_update.dtype
        )
        if dtype in (torch.float16, torch.bfloat16):
            dtype = torch.float32

        with torch.no_grad():
            B = B.to(dtype=dtype)
            A = A.to(device=B.device, dtype=dtype)
            X = step_update.to(device=B.device, dtype=dtype)

            # 1. ||BA||_F. Clamp at zero only to absorb negative rounding
            #    noise; a larger floor would inflate the norm of small
            #    adapters and make the zero-norm guard below unreachable.
            norm_ba_sq = torch.sum(torch.matmul(B.t(), B) * torch.matmul(A, A.t()))
            norm_ba = float(torch.sqrt(norm_ba_sq.clamp(min=0.0)).item())

            # 2. ||X||_F
            norm_x = float(torch.norm(X, p="fro").item())
            if norm_x < self._EPS or norm_ba < self._EPS:
                return 0.0

            # 3. <BA, X>_F
            dot_product = float(torch.sum(torch.matmul(B.t(), X) * A).item())

        cosine = dot_product / (norm_ba * norm_x)
        # Rounding can push the ratio marginally outside the valid range.
        if cosine > 1.0:
            return 1.0
        if cosine < -1.0:
            return -1.0
        return cosine

    def analyze_layer_profile(self, name: str) -> Tuple[str, str]:
        """Classify a layer name by type and by depth.

        Args:
            name: Parameter or module name; matching is case-insensitive.

        Returns:
            ``(layer_type, depth)`` where ``layer_type`` is one of ``embed``,
            ``attn``, ``mlp``, ``other`` and ``depth`` is one of ``early``,
            ``mid``, ``late``.
        """
        lowered = name.lower()

        if "embed" in lowered or "wte" in lowered:
            layer_type = "embed"
        elif any(keyword in lowered for keyword in self._ATTN_KEYWORDS):
            layer_type = "attn"
        elif any(keyword in lowered for keyword in self._MLP_KEYWORDS):
            layer_type = "mlp"
        else:
            layer_type = "other"

        match = self._LAYER_INDEX_RE.search(lowered)
        if match:
            idx = int(match.group(1))
            early_bound = self.total_layers // 3
            # Multiply before dividing: 2 * (n // 3) would floor twice and
            # push the remainder layers into the "late" band.
            late_bound = (2 * self.total_layers) // 3
            if idx < early_bound:
                depth = "early"
            elif idx < late_bound:
                depth = "mid"
            else:
                depth = "late"
        elif layer_type == "embed":
            # Un-indexed embeddings (including "wte") sit at the input side.
            depth = "early"
        elif "head" in lowered:
            # Covers "lm_head" and any other un-indexed "*head*" module.
            depth = "late"
        else:
            depth = "mid"

        return layer_type, depth

    def profile_adapter(self, adapter: "ExpertAdapter") -> Dict[str, float]:
        """Compute normalized domain scores for an adapter.

        Each entry of ``adapter.layers`` contributes its energy, the sum of
        its squared ``singular_values``, to the domains according to its
        layer type and depth.

        Args:
            adapter: Object exposing ``layers``, a mapping from layer name to
                an object with a ``singular_values`` tensor.

        Returns:
            Dict with keys ``statute_recall``, ``logic_reasoning`` and
            ``style_gutachtenstil``. The scores sum to 1, except for the
            fallback of ``0.33`` each, used when the total energy is below
            ``_EPS`` or the weighted sum is not positive (e.g. NaN).
        """
        energies = {"statute": 0.0, "logic": 0.0, "style": 0.0}
        total_energy = 0.0

        for name, lm in adapter.layers.items():
            l_type, depth = self.analyze_layer_profile(name)
            # Sum of squared singular values (squared Frobenius norm of the delta).
            layer_energy = float(torch.sum(lm.singular_values**2).item())
            total_energy += layer_energy

            if depth == "early" or l_type == "embed":
                energies["statute"] += layer_energy * 1.5
                energies["logic"] += layer_energy * 0.5
            elif depth == "mid":
                if l_type == "attn":
                    energies["logic"] += layer_energy * 1.5
                else:  # mlp and "other" layers
                    energies["style"] += layer_energy * 1.2
                    energies["logic"] += layer_energy * 0.8
            else:  # late
                if l_type == "mlp":
                    energies["style"] += layer_energy * 1.8
                else:  # attn and "other" layers
                    energies["logic"] += layer_energy * 1.2
                    energies["style"] += layer_energy * 0.5

        sum_scores = sum(energies.values())
        # "not (x > 0)" is deliberate: it also routes NaN sums to the fallback.
        if total_energy < self._EPS or not sum_scores > 0:
            return {
                "statute_recall": self._UNIFORM_SCORE,
                "logic_reasoning": self._UNIFORM_SCORE,
                "style_gutachtenstil": self._UNIFORM_SCORE,
            }

        return {
            "statute_recall": energies["statute"] / sum_scores,
            "logic_reasoning": energies["logic"] / sum_scores,
            "style_gutachtenstil": energies["style"] / sum_scores,
        }

    def tag_adapter(self, adapter: "ExpertAdapter") -> List[str]:
        """Profile the adapter and set its ``domain_tags`` to the dominant domains.

        Every domain scoring at least ``_TAG_THRESHOLD`` (0.35) becomes a tag.
        If none qualifies, the single highest-scoring domain is used.

        Note:
            ``adapter.domain_tags`` is overwritten, not extended.

        Args:
            adapter: Adapter to profile; its ``domain_tags`` attribute is set.

        Returns:
            The list of tags that was assigned to ``adapter.domain_tags``.
        """
        scores = self.profile_adapter(adapter)
        tags = [
            domain for domain, score in scores.items() if score >= self._TAG_THRESHOLD
        ]
        if not tags:
            tags.append(max(scores, key=lambda domain: scores[domain]))
        adapter.domain_tags = tags
        return tags