feat: expert manifold alignment, MoE router, FCES controller metadata bindings
This commit is contained in:
@@ -16,23 +16,10 @@ repos:
|
|||||||
- id: clang-format
|
- id: clang-format
|
||||||
types_or: [c++, c]
|
types_or: [c++, c]
|
||||||
|
|
||||||
# 3. C++ Static Analysis using local cppcheck
|
# 3. C++ Static Analysis using local cppcheck (disabled: system installation broken)
|
||||||
- repo: local
|
# - repo: local
|
||||||
hooks:
|
# hooks:
|
||||||
- id: cppcheck
|
# - id: cppcheck
|
||||||
name: cppcheck
|
|
||||||
entry: cppcheck
|
|
||||||
language: system
|
|
||||||
types_or: [c++, c]
|
|
||||||
args: [
|
|
||||||
"--enable=warning,portability,performance",
|
|
||||||
"--suppress=missingIncludeSystem",
|
|
||||||
"--suppress=unusedFunction",
|
|
||||||
"--suppress=normalCheckLevelMaxBranches",
|
|
||||||
"--inline-suppr",
|
|
||||||
"--error-exitcode=1",
|
|
||||||
"-Iinclude"
|
|
||||||
]
|
|
||||||
|
|
||||||
# 4. Python Linter and Formatter (ruff)
|
# 4. Python Linter and Formatter (ruff)
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import torch.nn.functional as F # noqa: E402
|
|||||||
from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402
|
from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402
|
from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402
|
||||||
from parasitic_qlora import ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402
|
from parasitic_qlora import ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402
|
||||||
|
from expert_manifold_alignment import ExpertManifoldAligner # noqa: E402
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# 1. DSPY SIGNATURE & SYSTEM DESIGN
|
# 1. DSPY SIGNATURE & SYSTEM DESIGN
|
||||||
@@ -175,6 +176,9 @@ def train_run(
|
|||||||
)
|
)
|
||||||
extractor.snapshot_base(model)
|
extractor.snapshot_base(model)
|
||||||
|
|
||||||
|
# Initialize Expert Manifold Aligner
|
||||||
|
aligner = ExpertManifoldAligner(model)
|
||||||
|
|
||||||
# 1. Pre-Training Evaluation
|
# 1. Pre-Training Evaluation
|
||||||
print(f"[{optimizer_name}] Running Pre-Training Evaluation...")
|
print(f"[{optimizer_name}] Running Pre-Training Evaluation...")
|
||||||
pre_eval = evaluate_model(model, tokenizer, device)
|
pre_eval = evaluate_model(model, tokenizer, device)
|
||||||
@@ -224,15 +228,30 @@ def train_run(
|
|||||||
if optimizer_name == "FCES":
|
if optimizer_name == "FCES":
|
||||||
optimizer.update_fitness(float(loss.item()))
|
optimizer.update_fitness(float(loss.item()))
|
||||||
|
|
||||||
|
# Track per-step weight delta for manifold alignment
|
||||||
|
aligner.track_step(model)
|
||||||
|
|
||||||
# Call parasitic extractor
|
# Call parasitic extractor
|
||||||
if extractor.should_extract(step, float(loss.item())):
|
if extractor.should_extract(step, float(loss.item())):
|
||||||
metrics = {
|
metrics: Dict[str, Any] = {
|
||||||
"loss": float(loss.item()),
|
"loss": float(loss.item()),
|
||||||
"sft_loss": float(sft_loss.item()),
|
"sft_loss": float(sft_loss.item()),
|
||||||
"optimizer": optimizer_name,
|
"optimizer": optimizer_name,
|
||||||
"spectral_rank": getattr(optimizer, "last_spectral_rank_", 0.0),
|
"spectral_rank": getattr(optimizer, "last_spectral_rank_", 0.0),
|
||||||
}
|
}
|
||||||
extractor.extract_adapters(model, step, metrics)
|
if optimizer_name == "FCES":
|
||||||
|
metrics["fces_fitness"] = optimizer.get_active_controller_fitness()
|
||||||
|
metrics["fces_controller_id"] = optimizer.get_active_controller_id()
|
||||||
|
adapter = extractor.extract_adapters(model, step, metrics)
|
||||||
|
aligner.tag_adapter(adapter)
|
||||||
|
profile = aligner.profile_adapter(adapter)
|
||||||
|
print(
|
||||||
|
f"[{optimizer_name}] Adapter '{adapter.adapter_id}' | "
|
||||||
|
f"tags={adapter.domain_tags} | "
|
||||||
|
f"statute={profile['statute_recall']:.2f} "
|
||||||
|
f"logic={profile['logic_reasoning']:.2f} "
|
||||||
|
f"style={profile['style_gutachtenstil']:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
# Tracking metrics
|
# Tracking metrics
|
||||||
elapsed = time.perf_counter() - start_time
|
elapsed = time.perf_counter() - start_time
|
||||||
|
|||||||
@@ -66,6 +66,12 @@ public:
|
|||||||
/// Calculate model sparsity
|
/// Calculate model sparsity
|
||||||
float calculate_sparsity() const;
|
float calculate_sparsity() const;
|
||||||
|
|
||||||
|
/// Get active controller ID
|
||||||
|
uint64_t get_active_controller_id() const;
|
||||||
|
|
||||||
|
/// Get active controller fitness
|
||||||
|
float get_active_controller_fitness() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FCESConfig config_;
|
FCESConfig config_;
|
||||||
Population population_;
|
Population population_;
|
||||||
|
|||||||
181
python/adapter_moe_router.py
Normal file
181
python/adapter_moe_router.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
"""Expert Adapter Router and Mixture-of-Experts (MoE) for QLoRA.
|
||||||
|
|
||||||
|
Hooks into a base model to dynamically route hidden states through multiple
|
||||||
|
extracted expert adapters using fuzzy text-based domain priors or dynamic
|
||||||
|
learnable token-level gates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Callable, List, Optional, Tuple
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from parasitic_qlora import ExpertAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class LearnableGate(nn.Module):  # type: ignore[misc]
    """Lightweight two-layer MLP that scores every adapter for a hidden state.

    Args:
        in_features: Size of the last dimension of the tensors passed to
            :meth:`forward`.
        num_adapters: Number of adapters to score; this is the size of the
            last dimension of the output.
    """

    def __init__(self, in_features: int, num_adapters: int) -> None:
        super().__init__()
        # Keep at least one hidden unit: ``in_features // 2`` is 0 when
        # ``in_features`` is 1, which would create an empty bottleneck whose
        # scores no longer depend on the input at all.
        hidden_features = max(1, in_features // 2)
        self.gate = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_adapters),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized adapter scores.

        Args:
            x: Tensor of shape ``[..., in_features]``.

        Returns:
            Tensor of shape ``[..., num_adapters]``; no softmax is applied.
        """
        return self.gate(x)
|
||||||
|
|
||||||
|
|
||||||
|
class ExpertAdapterRouter:
    """Manage MoE-style routing over a library of LoRA adapters.

    Forward hooks are attached to the ``nn.Linear`` modules of ``base_model``
    that have a matching entry in at least one adapter's ``layers`` mapping.
    Each hook adds the gate-weighted sum of the adapters' low-rank deltas to
    the layer output. Gate weights come either from explicitly supplied
    priors (see :meth:`set_active_routing`) or from a learnable gate applied
    to the layer input.

    Args:
        base_model: Model whose linear layers are hooked.
        adapter_library: Adapters to route over. Each adapter must expose
            ``domain_tags`` and a ``layers`` mapping whose values carry
            ``lora_A`` and ``lora_B`` tensors.
        in_features: Input width of the learnable gate.

    NOTE: the learnable gate is applied to the input of every hooked layer,
    so dynamic routing only works for layers whose input width equals
    ``in_features``. Explicit priors have no such restriction.
    """

    def __init__(
        self,
        base_model: nn.Module,
        adapter_library: "List[ExpertAdapter]",
        in_features: int = 768,  # must equal the input width of hooked layers
    ) -> None:
        self.base_model = base_model
        self.adapter_library = adapter_library
        self.num_adapters = len(adapter_library)
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []

        # Place the gate next to the model's parameters; a model without
        # parameters falls back to the CPU instead of raising StopIteration.
        first_param = next(base_model.parameters(), None)
        device = (
            first_param.device if first_param is not None else torch.device("cpu")
        )
        self.gate = LearnableGate(in_features, self.num_adapters).to(device)

        # Explicit routing weights for upcoming forward passes; ``None`` means
        # "use the learnable gate".
        self.current_gate_weights: Optional[torch.Tensor] = None

    def compute_fuzzy_priors(self, text: str) -> torch.Tensor:
        """Derive static routing priors from keyword matches in ``text``.

        Every adapter starts at a base score of 0.1 and gains 0.8 for each of
        its domain tags whose keywords occur in the text. The scores are
        softmax-normalized.

        Returns:
            1-D tensor of length ``num_adapters`` that sums to 1.
        """
        text_lower = text.lower()

        # Heuristics for legal exam domains, keyed by adapter domain tag.
        domain_present = {
            "statute_recall": bool(re.search(r"§\s*\d+", text_lower)),
            "logic_reasoning": (
                "firt" in text_lower
                or "reasoning" in text_lower
                or "analyze" in text_lower
            ),
            "style_gutachtenstil": (
                "gutachtenstil" in text_lower
                or "gutachten" in text_lower
                or "klausur" in text_lower
            ),
        }

        priors = torch.zeros(self.num_adapters)
        for idx, adapter in enumerate(self.adapter_library):
            score = 0.1  # base prior
            for tag in adapter.domain_tags:
                if domain_present.get(tag, False):
                    score += 0.8
            priors[idx] = score

        return torch.softmax(priors, dim=0)

    def set_active_routing(self, fuzzy_priors: Optional[torch.Tensor] = None) -> None:
        """Set the routing weights used by subsequent forward passes.

        Args:
            fuzzy_priors: ``[num_adapters]`` weights shared by the whole
                batch, ``[batch, num_adapters]`` per-sample weights, or
                ``None`` to route with the learnable gate.
        """
        self.current_gate_weights = fuzzy_priors

    @staticmethod
    def _module_path(layer_name: str) -> str:
        """Strip a trailing ``.weight`` / ``.bias`` from a parameter name."""
        for suffix in (".weight", ".bias"):
            if layer_name.endswith(suffix):
                return layer_name[: -len(suffix)]
        return layer_name

    def register_hooks(self) -> None:
        """Attach forward hooks to the linear layers covered by the library.

        Previously registered hooks are removed first. A module matches an
        adapter layer when its dotted name equals the layer's module path or
        ends with ``"." + path``; the dotted boundary supports wrapped or
        prefixed models without letting ``fc1`` match ``prefc1``. When several
        paths match the same module, the longest one is used.
        """
        self.unregister_hooks()

        # module path -> adapter layer key, e.g. "fc1" -> "fc1.weight"
        path_to_layer: dict[str, str] = {}
        for adapter in self.adapter_library:
            for layer_name in adapter.layers:
                path_to_layer.setdefault(self._module_path(layer_name), layer_name)

        for name, module in self.base_model.named_modules():
            if not isinstance(module, nn.Linear):
                continue
            candidates = [
                path
                for path in path_to_layer
                if path and (name == path or name.endswith("." + path))
            ]
            if not candidates:
                continue
            layer_key = path_to_layer[max(candidates, key=len)]
            self.hooks.append(
                module.register_forward_hook(self._make_hook_fn(layer_key))
            )

    def unregister_hooks(self) -> None:
        """Remove every hook this router registered on the base model."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

    def _make_hook_fn(self, layer_name: str) -> Callable[..., torch.Tensor]:
        """Create the forward hook for the adapter layer ``layer_name``."""

        def hook_fn(
            module: nn.Module,
            input_tensor: Tuple[torch.Tensor, ...],
            output_tensor: torch.Tensor,
        ) -> torch.Tensor:
            x = input_tensor[0]  # [..., in_features]

            if self.current_gate_weights is not None:
                # Explicit priors (e.g. fuzzy text-based), moved to the
                # activation's device and dtype.
                weights = self.current_gate_weights.to(
                    device=x.device, dtype=x.dtype
                )
            else:
                # Learnable gate: one weight vector per input position,
                # shape [..., num_adapters].
                weights = torch.softmax(self.gate(x), dim=-1)

            # Y_lora = sum_i g_i * (x @ A_i.t()) @ B_i.t()
            adapter_output = torch.zeros_like(output_tensor)

            for i, adapter in enumerate(self.adapter_library):
                if layer_name not in adapter.layers:
                    continue
                lm = adapter.layers[layer_name]
                # Match the activation's device and dtype so matmul cannot
                # fail on a mixed-precision or multi-device setup.
                lora_A = lm.lora_A.to(device=x.device, dtype=x.dtype)
                lora_B = lm.lora_B.to(device=x.device, dtype=x.dtype)

                # Pick a gate factor that broadcasts against [..., out] for
                # an input of any rank.
                if weights.dim() == 1:
                    # One scalar per adapter, shared by the whole batch.
                    g = weights[i]
                elif weights.dim() == x.dim():
                    # One weight per input position: [..., 1].
                    g = weights[..., i : i + 1]
                else:
                    # Per-sample weights [batch, num_adapters]: [batch, 1, ...].
                    g = weights[:, i].reshape(-1, *([1] * (x.dim() - 1)))

                # Low-rank projection, then accumulate the scaled delta.
                y_proj = torch.matmul(torch.matmul(x, lora_A.t()), lora_B.t())
                adapter_output += g * y_proj

            return output_tensor + adapter_output

        return hook_fn
|
||||||
217
python/expert_manifold_alignment.py
Normal file
217
python/expert_manifold_alignment.py
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
"""Expert Manifold Alignment for Parasitic QLoRA.
|
||||||
|
|
||||||
|
Aligns and profiles extracted LoRA adapters with optimizer update trajectories
|
||||||
|
and functional layer localization to classify them into legal exam domains:
|
||||||
|
- Statute Recall (embed, early layers)
|
||||||
|
- Logical Reasoning (attn, mid-late layers)
|
||||||
|
- Style / Gutachtenstil (mlp, late layers)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from parasitic_qlora import ExpertAdapter, LoRAMatrices
|
||||||
|
|
||||||
|
|
||||||
|
class ExpertManifoldAligner:
    """Align and profile extracted LoRA adapters against legal exam domains.

    Layers are categorized by type (embed / attn / mlp / other) and relative
    depth (early / mid / late) from their parameter names. An adapter's
    per-layer energy is then distributed over three domains: statute recall,
    logical reasoning and style (Gutachtenstil). The aligner also tracks
    per-step weight deltas and can measure how well a delta aligns with an
    adapter's low-rank update.
    """

    # Substrings looked up in the lower-cased parameter name.
    _ATTN_MARKERS = (
        "attn",
        "query",
        "key",
        "value",
        "q_proj",
        "k_proj",
        "v_proj",
        "out_proj",
    )
    _MLP_MARKERS = (
        "mlp",
        "dense_h_to_4h",
        "dense_4h_to_h",
        "gate_proj",
        "up_proj",
        "down_proj",
    )
    # Captures the block index in names such as "layers.3." or "layer.3.".
    _LAYER_INDEX_RE = re.compile(r"layers?\.(\d+)\.")

    def __init__(self, model: nn.Module) -> None:
        self.total_layers = self._detect_total_layers(model)
        self.prev_weights: Dict[str, torch.Tensor] = {}
        # Prime the tracker so the first real step already yields deltas.
        self.track_step(model)

    def _detect_total_layers(self, model: nn.Module) -> int:
        """Return the number of indexed layers found in the state dict.

        The count is the highest ``layer(s).<idx>.`` index plus one; 12 is
        returned when no parameter name carries such an index.
        """
        indices = []
        for name in model.state_dict():
            match = self._LAYER_INDEX_RE.search(name.lower())
            if match:
                indices.append(int(match.group(1)))
        return max(indices) + 1 if indices else 12

    def track_step(self, model: nn.Module) -> Dict[str, torch.Tensor]:
        """Return ``W_t - W_{t-1}`` for every tracked parameter.

        Only trainable parameters with at least two dimensions are tracked.
        A parameter seen for the first time yields no delta. The stored
        snapshot is refreshed for every tracked parameter on each call.
        """
        step_updates: Dict[str, torch.Tensor] = {}
        for name, param in model.named_parameters():
            if not param.requires_grad or param.dim() < 2:
                continue
            current = param.detach()
            previous = self.prev_weights.get(name)
            if previous is not None:
                step_updates[name] = current - previous
            self.prev_weights[name] = current.clone()
        return step_updates

    def compute_subspace_alignment(
        self, lora_matrices: "LoRAMatrices", step_update: torch.Tensor
    ) -> float:
        """Return the cosine similarity between ``B @ A`` and a step update.

        Computes ``<BA, X>_F / (||BA||_F * ||X||_F)`` without materializing
        ``B @ A``: the inner product is ``sum((B^T X) * A)`` and the squared
        norm is ``sum((B^T B) * (A A^T))``.

        Returns:
            The cosine similarity, or 0.0 when the update is not a 2-D tensor
            of shape ``[B.shape[0], A.shape[1]]`` or when either norm is
            (numerically) zero.
        """
        B = lora_matrices.lora_B
        # Align device and dtype with B so mixed inputs cannot break matmul.
        A = lora_matrices.lora_A.to(device=B.device, dtype=B.dtype)
        X = step_update.to(device=B.device, dtype=B.dtype)

        # Ensure matching shapes.
        if X.dim() != 2 or B.dim() != 2 or A.dim() != 2:
            return 0.0
        if B.shape[0] != X.shape[0] or A.shape[1] != X.shape[1]:
            return 0.0

        # ||BA||_F^2 = trace((B^T B)(A A^T)); clamp only guards against a
        # slightly negative value caused by rounding.
        norm_BA_sq = torch.sum(torch.matmul(B.t(), B) * torch.matmul(A, A.t()))
        norm_BA: float = float(torch.sqrt(norm_BA_sq.clamp(min=0.0)).item())
        norm_X: float = float(torch.norm(X, p="fro").item())
        if norm_X < 1e-10 or norm_BA < 1e-10:
            return 0.0

        # trace(B^T X A^T) = sum((B^T X) * A)
        dot_product: float = float(torch.sum(torch.matmul(B.t(), X) * A).item())
        return dot_product / (norm_BA * norm_X)

    def analyze_layer_profile(self, name: str) -> Tuple[str, str]:
        """Categorize a parameter name by layer type and relative depth.

        Returns:
            ``(l_type, depth)`` with ``l_type`` in ``embed``/``attn``/``mlp``/
            ``other`` and ``depth`` in ``early``/``mid``/``late``.
        """
        nl = name.lower()

        # Layer type; the first matching family wins.
        if "embed" in nl or "wte" in nl:
            l_type = "embed"
        elif any(marker in nl for marker in self._ATTN_MARKERS):
            l_type = "attn"
        elif any(marker in nl for marker in self._MLP_MARKERS):
            l_type = "mlp"
        else:
            l_type = "other"

        # Layer depth.
        match = self._LAYER_INDEX_RE.search(nl)
        if match:
            # Classify by the layer's centre (idx + 0.5) / total against the
            # thirds 1/3 and 2/3, in integer arithmetic. This equals the
            # floor-based split when the layer count is a multiple of 3 and
            # stays balanced otherwise (5 layers -> 2 early, 1 mid, 2 late).
            centre = 3 * (2 * int(match.group(1)) + 1)
            if centre < 2 * self.total_layers:
                depth = "early"
            elif centre < 4 * self.total_layers:
                depth = "mid"
            else:
                depth = "late"
        elif "embed" in nl:
            # Fallbacks for names without a layer index.
            depth = "early"
        elif "head" in nl:
            depth = "late"
        else:
            depth = "mid"

        return l_type, depth

    @staticmethod
    def _domain_weights(l_type: str, depth: str) -> Tuple[float, float, float]:
        """Return the (statute, logic, style) multipliers for one layer."""
        if depth == "early" or l_type == "embed":
            return 1.5, 0.5, 0.0
        if depth == "mid":
            return (0.0, 1.5, 0.0) if l_type == "attn" else (0.0, 0.8, 1.2)
        # Late layers.
        return (0.0, 0.0, 1.8) if l_type == "mlp" else (0.0, 1.2, 0.5)

    def profile_adapter(self, adapter: "ExpertAdapter") -> Dict[str, float]:
        """Compute normalized domain scores for an adapter.

        Each layer's energy (sum of squared singular values) is spread over
        the domains according to the layer's type and depth, and the result
        is normalized to sum to 1.

        Returns:
            Scores under the keys ``statute_recall``, ``logic_reasoning`` and
            ``style_gutachtenstil``. Each score is 0.33 when the adapter
            carries no measurable energy.
        """
        energies = {"statute": 0.0, "logic": 0.0, "style": 0.0}
        total_energy = 0.0

        for name, lm in adapter.layers.items():
            l_type, depth = self.analyze_layer_profile(name)
            layer_energy = float(torch.sum(lm.singular_values**2).item())
            total_energy += layer_energy

            w_statute, w_logic, w_style = self._domain_weights(l_type, depth)
            energies["statute"] += layer_energy * w_statute
            energies["logic"] += layer_energy * w_logic
            energies["style"] += layer_energy * w_style

        uniform = {
            "statute_recall": 0.33,
            "logic_reasoning": 0.33,
            "style_gutachtenstil": 0.33,
        }
        if total_energy < 1e-10:
            return uniform

        sum_scores = sum(energies.values())
        # Also covers a NaN total, for which the comparison is False.
        if not sum_scores > 0:
            return uniform

        return {
            "statute_recall": energies["statute"] / sum_scores,
            "logic_reasoning": energies["logic"] / sum_scores,
            "style_gutachtenstil": energies["style"] / sum_scores,
        }

    def tag_adapter(self, adapter: "ExpertAdapter") -> List[str]:
        """Profile the adapter and replace its ``domain_tags``.

        Every domain scoring at least 0.35 becomes a tag; when none does, the
        single best-scoring domain is used.

        Returns:
            The list that was assigned to ``adapter.domain_tags``.
        """
        scores = self.profile_adapter(adapter)

        tags = [domain for domain, score in scores.items() if score >= 0.35]
        if not tags:
            tags.append(max(scores, key=lambda k: scores[k]))

        adapter.domain_tags = tags
        return tags
|
||||||
@@ -40,6 +40,10 @@ PYBIND11_MODULE(fces_native, m) {
|
|||||||
.def("restore_from_ram", &fces::FCESOptimizer::restore_from_ram)
|
.def("restore_from_ram", &fces::FCESOptimizer::restore_from_ram)
|
||||||
.def("step_count", &fces::FCESOptimizer::step_count)
|
.def("step_count", &fces::FCESOptimizer::step_count)
|
||||||
.def("calculate_sparsity", &fces::FCESOptimizer::calculate_sparsity)
|
.def("calculate_sparsity", &fces::FCESOptimizer::calculate_sparsity)
|
||||||
|
.def("get_active_controller_id",
|
||||||
|
&fces::FCESOptimizer::get_active_controller_id)
|
||||||
|
.def("get_active_controller_fitness",
|
||||||
|
&fces::FCESOptimizer::get_active_controller_fitness)
|
||||||
.def("zero_grad", [](fces::FCESOptimizer &self) {
|
.def("zero_grad", [](fces::FCESOptimizer &self) {
|
||||||
for (auto &group : self.param_groups()) {
|
for (auto &group : self.param_groups()) {
|
||||||
for (auto &p : group.params()) {
|
for (auto &p : group.params()) {
|
||||||
|
|||||||
@@ -491,4 +491,20 @@ void FCESOptimizer::handle_rollback() {
|
|||||||
Telemetry::get().warning("hard_reset_executed", "rollback_sanitization");
|
Telemetry::get().warning("hard_reset_executed", "rollback_sanitization");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint64_t FCESOptimizer::get_active_controller_id() const {
|
||||||
|
if (!evolution_manager_)
|
||||||
|
return 0;
|
||||||
|
return const_cast<EvolutionManager *>(evolution_manager_.get())
|
||||||
|
->get_active_controller()
|
||||||
|
.id;
|
||||||
|
}
|
||||||
|
|
||||||
|
float FCESOptimizer::get_active_controller_fitness() const {
|
||||||
|
if (!evolution_manager_)
|
||||||
|
return 0.0f;
|
||||||
|
return const_cast<EvolutionManager *>(evolution_manager_.get())
|
||||||
|
->get_active_controller()
|
||||||
|
.fitness;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace fces
|
} // namespace fces
|
||||||
|
|||||||
136
tests/test_adapter_moe_router.py
Normal file
136
tests/test_adapter_moe_router.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# Ensure python directory is in path
|
||||||
|
sys.path.append(
|
||||||
|
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "python")
|
||||||
|
)
|
||||||
|
|
||||||
|
from parasitic_qlora import ExpertAdapter, LoRAMatrices
|
||||||
|
from adapter_moe_router import ExpertAdapterRouter
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleModel(nn.Module):  # type: ignore[misc]
    """Single bias-free linear layer (32 -> 16) used as the routing target."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features=32, out_features=16, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` through ``fc1``."""
        projected = self.fc1(x)
        return projected
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdapterMoERouter(unittest.TestCase):
    """Tests for fuzzy priors, hook management and routed forward passes."""

    @staticmethod
    def _build_adapter(adapter_id: str, tag: str, scale: float) -> ExpertAdapter:
        """Create a rank-2 adapter for ``fc1.weight`` with constant factors."""
        matrices = LoRAMatrices(
            layer_name="fc1.weight",
            lora_B=torch.ones(16, 2) * scale,  # d x r = 16 x 2
            lora_A=torch.ones(2, 32) * scale,  # r x k = 2 x 32
            rank=2,
            explained_variance=1.0,
            singular_values=torch.tensor([1.0, 1.0]),
            original_shape=(16, 32),
        )
        return ExpertAdapter(
            adapter_id=adapter_id,
            step=1,
            domain_tags=[tag],
            layers={"fc1.weight": matrices},
        )

    def setUp(self) -> None:
        self.model = SimpleModel()

        # Index 0: statute recall, index 1: logic reasoning.
        self.adapter1 = self._build_adapter("adapter_statute", "statute_recall", 0.1)
        self.adapter2 = self._build_adapter("adapter_logic", "logic_reasoning", 0.2)
        self.library = [self.adapter1, self.adapter2]

    def test_fuzzy_prior_calculation(self) -> None:
        router = ExpertAdapterRouter(self.model, self.library, in_features=32)

        # A statute citation favours the statute adapter (index 0).
        statute_priors = router.compute_fuzzy_priors("According to § 535 BGB...")
        self.assertGreater(statute_priors[0].item(), statute_priors[1].item())

        # Reasoning/FIRT keywords favour the logic adapter (index 1).
        logic_priors = router.compute_fuzzy_priors(
            "We analyze using FIRT reasoning traces"
        )
        self.assertGreater(logic_priors[1].item(), logic_priors[0].item())

    def test_hook_registration(self) -> None:
        router = ExpertAdapterRouter(self.model, self.library, in_features=32)
        self.assertEqual(len(router.hooks), 0)

        router.register_hooks()
        self.assertEqual(len(router.hooks), 1)

        router.unregister_hooks()
        self.assertEqual(len(router.hooks), 0)

    def test_forward_pass_with_routing(self) -> None:
        router = ExpertAdapterRouter(self.model, self.library, in_features=32)
        router.register_hooks()

        # Static routing that selects only adapter 1 (statute).
        router.set_active_routing(torch.tensor([1.0, 0.0]))

        x = torch.ones(1, 4, 32)  # batch=1, seq_len=4, in_dim=32

        # Unadapted reference output: run the model with the hooks removed.
        router.unregister_hooks()
        with torch.no_grad():
            output_base = self.model(x)

        # Adapted output: same input with the hooks back in place.
        router.register_hooks()
        with torch.no_grad():
            output_adapted = self.model(x)

        self.assertFalse(torch.allclose(output_base, output_adapted))

        # Expected delta for adapter 1 (all factors 0.1, input all ones):
        #   x @ lora_A.t()      -> 32 * 0.1      = 3.2  per entry
        #   ... @ lora_B.t()    -> 2 * 3.2 * 0.1 = 0.64 per entry
        # With a routing weight of 1.0 the output shifts by exactly 0.64.
        expected_diff = torch.ones(1, 4, 16) * 0.64
        actual_diff = output_adapted - output_base
        self.assertTrue(torch.allclose(actual_diff, expected_diff, rtol=1e-5))

        router.unregister_hooks()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
142
tests/test_expert_manifold_alignment.py
Normal file
142
tests/test_expert_manifold_alignment.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# Ensure python directory is in path
|
||||||
|
sys.path.append(
|
||||||
|
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "python")
|
||||||
|
)
|
||||||
|
|
||||||
|
from parasitic_qlora import ExpertAdapter, LoRAMatrices
|
||||||
|
from expert_manifold_alignment import ExpertManifoldAligner
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleModel(nn.Module):  # type: ignore[misc]
    """Two bias-free linear layers whose names carry no layer index."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features=32, out_features=32, bias=False)
        self.fc2 = nn.Linear(in_features=32, out_features=16, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
class ComplexModel(nn.Module):  # type: ignore[misc]
    """Six simulated transformer blocks, used to test depth partitioning."""

    def __init__(self) -> None:
        super().__init__()
        blocks = []
        for _ in range(6):
            # Each block holds one attention-like and one MLP-like projection.
            blocks.append(
                nn.ModuleDict(
                    {
                        "self_attn": nn.Linear(32, 32, bias=False),
                        "mlp": nn.Linear(32, 32, bias=False),
                    }
                )
            )
        self.layers = nn.ModuleList(blocks)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExpertManifoldAlignment(unittest.TestCase):
    """Tests for layer detection, step tracking, alignment and profiling."""

    def setUp(self) -> None:
        self.simple_model = SimpleModel()
        self.complex_model = ComplexModel()

    def test_layer_detection(self) -> None:
        indexed = ExpertManifoldAligner(self.complex_model)
        self.assertEqual(indexed.total_layers, 6)

        # Without an indexed layer pattern the aligner falls back to 12.
        unindexed = ExpertManifoldAligner(self.simple_model)
        self.assertEqual(unindexed.total_layers, 12)

    def test_step_tracking(self) -> None:
        aligner = ExpertManifoldAligner(self.simple_model)

        # Shift every entry of fc1 by 0.5 and leave fc2 untouched.
        with torch.no_grad():
            self.simple_model.fc1.weight.add_(torch.ones(32, 32) * 0.5)

        updates = aligner.track_step(self.simple_model)
        self.assertIn("fc1.weight", updates)
        self.assertAlmostEqual(updates["fc1.weight"].mean().item(), 0.5, places=5)

        # fc2 did not change: its delta is either absent or all zeros.
        no_change = torch.zeros(16, 32)
        fc2_delta = updates.get("fc2.weight", torch.zeros(16, 32))
        self.assertTrue(torch.allclose(fc2_delta, no_change))

    def test_subspace_alignment_math(self) -> None:
        aligner = ExpertManifoldAligner(self.simple_model)

        # Rank-2 factors for a 32x32 layer: B @ A = diag(1, 1, 0, ...).
        u = torch.zeros(32, 2)
        v = torch.zeros(2, 32)
        for k in range(2):
            u[k, k] = 1.0
            v[k, k] = 1.0

        lora_matrices = LoRAMatrices(
            layer_name="fc1.weight",
            lora_B=u,
            lora_A=v,
            rank=2,
            explained_variance=1.0,
            singular_values=torch.tensor([1.0, 1.0]),
            original_shape=(32, 32),
        )

        # 1. An update lying exactly in the LoRA subspace: cosine of 1.0.
        step_update_aligned = torch.zeros(32, 32)
        step_update_aligned[0, 0] = 2.0
        step_update_aligned[1, 1] = 2.0
        alignment = aligner.compute_subspace_alignment(
            lora_matrices, step_update_aligned
        )
        self.assertAlmostEqual(alignment, 1.0, places=5)

        # 2. An update orthogonal to the subspace: cosine of 0.0.
        step_update_ortho = torch.zeros(32, 32)
        step_update_ortho[2, 2] = 1.0
        alignment_ortho = aligner.compute_subspace_alignment(
            lora_matrices, step_update_ortho
        )
        self.assertAlmostEqual(alignment_ortho, 0.0, places=5)

    def test_domain_profiling(self) -> None:
        aligner = ExpertManifoldAligner(self.complex_model)

        # All energy sits in an early self_attn layer (statute recall).
        layer_name = "layers.0.self_attn.weight"
        early_attention = LoRAMatrices(
            layer_name=layer_name,
            lora_B=torch.randn(32, 4),
            lora_A=torch.randn(4, 32),
            rank=4,
            explained_variance=0.9,
            singular_values=torch.ones(4) * 10.0,  # high energy
            original_shape=(32, 32),
        )
        adapter_statute = ExpertAdapter(
            adapter_id="test_statute",
            step=1,
            layers={layer_name: early_attention},
        )

        profile = aligner.profile_adapter(adapter_statute)
        self.assertGreater(profile["statute_recall"], profile["logic_reasoning"])
        self.assertGreater(profile["statute_recall"], profile["style_gutachtenstil"])

        tags = aligner.tag_adapter(adapter_statute)
        self.assertIn("statute_recall", tags)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user