"""Expert Manifold Alignment for Parasitic QLoRA. Aligns and profiles extracted LoRA adapters with optimizer update trajectories and functional layer localization to classify them into legal exam domains: - Statute Recall (embed, early layers) - Logical Reasoning (attn, mid-late layers) - Style / Gutachtenstil (mlp, late layers) """ from __future__ import annotations import re from typing import Dict, List, Tuple import torch import torch.nn as nn from parasitic_qlora import ExpertAdapter, LoRAMatrices class ExpertManifoldAligner: """Aligns and profiles extracted LoRA adapters with the FCES expert manifold. Uses layer-wise functional localization (depth and layer type) to categorize adapters into legal reasoning domains (Logic, Statute, Style). Also computes step-wise trajectory alignment with optimizer updates. """ def __init__(self, model: nn.Module) -> None: self.total_layers = self._detect_total_layers(model) self.prev_weights: Dict[str, torch.Tensor] = {} # Prime the tracker with current model weights self.track_step(model) def _detect_total_layers(self, model: nn.Module) -> int: """Detects the number of transformer layers dynamically from parameter names.""" max_layer_idx = 0 found = False for name in model.state_dict().keys(): match = re.search(r"layers?\.(\d+)\.", name.lower()) if match: max_layer_idx = max(max_layer_idx, int(match.group(1))) found = True return max_layer_idx + 1 if found else 12 def track_step(self, model: nn.Module) -> Dict[str, torch.Tensor]: """Calculates step update δW_t = W_t - W_{t-1} for tracked linear layers.""" step_updates: Dict[str, torch.Tensor] = {} for name, param in model.named_parameters(): if not param.requires_grad or len(param.shape) < 2: continue if name in self.prev_weights: # δW = W_t - W_{t-1} step_updates[name] = param.data - self.prev_weights[name] self.prev_weights[name] = param.data.clone().detach() return step_updates def compute_subspace_alignment( self, lora_matrices: LoRAMatrices, step_update: torch.Tensor ) -> float: """Computes cosine similarity between step update and LoRA subspace. Uses O(r d k) trace formulation: trace(B^T * X * A^T) / (||BA||_F * ||X||_F) to avoid large matrix allocations. """ B = lora_matrices.lora_B A = lora_matrices.lora_A X = step_update.to(B.device) # Ensure matching shapes if B.shape[0] != X.shape[0] or A.shape[1] != X.shape[1]: return 0.0 # 1. Compute ||BA||_F using trace((B^T B)(A A^T)) BtB = torch.matmul(B.t(), B) AAt = torch.matmul(A, A.t()) norm_BA_sq = torch.sum(BtB * AAt) norm_BA: float = float(torch.sqrt(norm_BA_sq.clamp(min=1e-10)).item()) # 2. Compute ||X||_F norm_X: float = float(torch.norm(X, p="fro").item()) if norm_X < 1e-10 or norm_BA < 1e-10: return 0.0 # 3. Compute trace(B^T X A^T) = sum( (B^T X) * A ) BtX = torch.matmul(B.t(), X) dot_product: float = float(torch.sum(BtX * A).item()) return dot_product / (norm_BA * norm_X) def analyze_layer_profile(self, name: str) -> Tuple[str, str]: """Categorizes a layer by its type (embed, attn, mlp) and depth (early, mid, late).""" nl = name.lower() # Determine layer type if "embed" in nl or "wte" in nl: l_type = "embed" elif any( x in nl for x in [ "attn", "query", "key", "value", "q_proj", "k_proj", "v_proj", "out_proj", ] ): l_type = "attn" elif any( x in nl for x in [ "mlp", "dense_h_to_4h", "dense_4h_to_h", "gate_proj", "up_proj", "down_proj", ] ): l_type = "mlp" else: l_type = "other" # Determine layer depth match = re.search(r"layers?\.(\d+)\.", nl) if match: idx = int(match.group(1)) early_bound = self.total_layers // 3 late_bound = 2 * (self.total_layers // 3) if idx < early_bound: depth = "early" elif idx < late_bound: depth = "mid" else: depth = "late" else: # Fallback for non-indexed layers if "embed" in nl: depth = "early" elif "lm_head" in nl or "head" in nl: depth = "late" else: depth = "mid" return l_type, depth def profile_adapter(self, adapter: ExpertAdapter) -> Dict[str, float]: """Calculates domain alignment scores for the adapter. Returns similarity coefficients for: - statute_recall (embed/early layers) - logic_reasoning (attn/mid-late layers) - style_gutachtenstil (mlp/late layers) """ energies = {"statute": 0.0, "logic": 0.0, "style": 0.0} total_energy = 0.0 for name, lm in adapter.layers.items(): l_type, depth = self.analyze_layer_profile(name) # Energy of this layer's delta is sum of squares of singular values (Frobenius norm squared) layer_energy = torch.sum(lm.singular_values**2).item() total_energy += layer_energy # Map layers to domain profiles if depth == "early" or l_type == "embed": energies["statute"] += layer_energy * 1.5 energies["logic"] += layer_energy * 0.5 elif depth == "mid": if l_type == "attn": energies["logic"] += layer_energy * 1.5 else: # mlp energies["style"] += layer_energy * 1.2 energies["logic"] += layer_energy * 0.8 else: # late if l_type == "mlp": energies["style"] += layer_energy * 1.8 else: energies["logic"] += layer_energy * 1.2 energies["style"] += layer_energy * 0.5 if total_energy < 1e-10: return { "statute_recall": 0.33, "logic_reasoning": 0.33, "style_gutachtenstil": 0.33, } sum_scores = sum(energies.values()) if sum_scores > 0: scores = {k: v / sum_scores for k, v in energies.items()} else: scores = {"statute": 0.33, "logic": 0.33, "style": 0.33} return { "statute_recall": scores["statute"], "logic_reasoning": scores["logic"], "style_gutachtenstil": scores["style"], } def tag_adapter(self, adapter: ExpertAdapter) -> List[str]: """Profiles the adapter and adds the best domain tags to its domain_tags.""" scores = self.profile_adapter(adapter) # Find dominant domains (above 35% concentration) tags = [] for domain, score in scores.items(): if score >= 0.35: tags.append(domain) if not tags: best_domain = max(scores, key=lambda k: scores[k]) tags.append(best_domain) adapter.domain_tags = tags return tags