feat: expert manifold alignment, MoE router, FCES controller metadata bindings
This commit is contained in:
136
tests/test_adapter_moe_router.py
Normal file
136
tests/test_adapter_moe_router.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Ensure python directory is in path
|
||||
sys.path.append(
|
||||
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "python")
|
||||
)
|
||||
|
||||
from parasitic_qlora import ExpertAdapter, LoRAMatrices
|
||||
from adapter_moe_router import ExpertAdapterRouter
|
||||
|
||||
|
||||
class SimpleModel(nn.Module): # type: ignore[misc]
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(32, 16, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.fc1(x)
|
||||
|
||||
|
||||
class TestAdapterMoERouter(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.model = SimpleModel()
|
||||
|
||||
# Create dummy expert adapters with domain tags
|
||||
# Adapter 1: Statute Recall (has fc1.weight adapter)
|
||||
self.adapter1 = ExpertAdapter(
|
||||
adapter_id="adapter_statute",
|
||||
step=1,
|
||||
domain_tags=["statute_recall"],
|
||||
layers={
|
||||
"fc1.weight": LoRAMatrices(
|
||||
layer_name="fc1.weight",
|
||||
lora_B=torch.ones(16, 2) * 0.1, # d x r = 16 x 2
|
||||
lora_A=torch.ones(2, 32) * 0.1, # r x k = 2 x 32
|
||||
rank=2,
|
||||
explained_variance=1.0,
|
||||
singular_values=torch.tensor([1.0, 1.0]),
|
||||
original_shape=(16, 32),
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# Adapter 2: Logic (has fc1.weight adapter)
|
||||
self.adapter2 = ExpertAdapter(
|
||||
adapter_id="adapter_logic",
|
||||
step=1,
|
||||
domain_tags=["logic_reasoning"],
|
||||
layers={
|
||||
"fc1.weight": LoRAMatrices(
|
||||
layer_name="fc1.weight",
|
||||
lora_B=torch.ones(16, 2) * 0.2,
|
||||
lora_A=torch.ones(2, 32) * 0.2,
|
||||
rank=2,
|
||||
explained_variance=1.0,
|
||||
singular_values=torch.tensor([1.0, 1.0]),
|
||||
original_shape=(16, 32),
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
self.library = [self.adapter1, self.adapter2]
|
||||
|
||||
def test_fuzzy_prior_calculation(self) -> None:
|
||||
router = ExpertAdapterRouter(self.model, self.library, in_features=32)
|
||||
|
||||
# 1. Text has statutes
|
||||
priors_statute = router.compute_fuzzy_priors("According to § 535 BGB...")
|
||||
# Index 0 is statute recall, index 1 is logic
|
||||
self.assertGreater(priors_statute[0].item(), priors_statute[1].item())
|
||||
|
||||
# 2. Text has reasoning/FIRT
|
||||
priors_logic = router.compute_fuzzy_priors(
|
||||
"We analyze using FIRT reasoning traces"
|
||||
)
|
||||
self.assertGreater(priors_logic[1].item(), priors_logic[0].item())
|
||||
|
||||
def test_hook_registration(self) -> None:
|
||||
router = ExpertAdapterRouter(self.model, self.library, in_features=32)
|
||||
self.assertEqual(len(router.hooks), 0)
|
||||
|
||||
# Register hooks
|
||||
router.register_hooks()
|
||||
self.assertEqual(len(router.hooks), 1)
|
||||
|
||||
# Unregister hooks
|
||||
router.unregister_hooks()
|
||||
self.assertEqual(len(router.hooks), 0)
|
||||
|
||||
def test_forward_pass_with_routing(self) -> None:
|
||||
router = ExpertAdapterRouter(self.model, self.library, in_features=32)
|
||||
router.register_hooks()
|
||||
|
||||
# Mock static active routing to only use adapter 1 (statute)
|
||||
priors = torch.tensor([1.0, 0.0])
|
||||
router.set_active_routing(priors)
|
||||
|
||||
# Input tensor
|
||||
x = torch.ones(1, 4, 32) # batch=1, seq_len=4, in_dim=32
|
||||
|
||||
# 1. Standard forward pass through base model (without hooks)
|
||||
# To get the unadapted output, we can unregister hooks
|
||||
router.unregister_hooks()
|
||||
with torch.no_grad():
|
||||
output_base = self.model(x)
|
||||
|
||||
# 2. Forward pass with active routing
|
||||
router.register_hooks()
|
||||
with torch.no_grad():
|
||||
output_adapted = self.model(x)
|
||||
|
||||
# Check that adapter is applied
|
||||
self.assertFalse(torch.allclose(output_base, output_adapted))
|
||||
|
||||
# Check mathematically:
|
||||
# For adapter 1: lora_B is 16x2 of 0.1, lora_A is 2x32 of 0.1
|
||||
# Input x is all ones of shape [1, 4, 32]
|
||||
# x_proj = x @ lora_A.t() -> shape [1, 4, 2].
|
||||
# Each entry of x_proj is sum_{k=1}^{32} 1.0 * 0.1 = 3.2
|
||||
# y_proj = x_proj @ lora_B.t() -> shape [1, 4, 16].
|
||||
# Each entry of y_proj is sum_{r=1}^2 3.2 * 0.1 = 0.64
|
||||
# Since weight prior is 1.0, adapted output should be output_base + 0.64
|
||||
expected_diff = torch.ones(1, 4, 16) * 0.64
|
||||
self.assertTrue(
|
||||
torch.allclose(output_adapted - output_base, expected_diff, rtol=1e-5)
|
||||
)
|
||||
|
||||
router.unregister_hooks()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
142
tests/test_expert_manifold_alignment.py
Normal file
142
tests/test_expert_manifold_alignment.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Ensure python directory is in path
|
||||
sys.path.append(
|
||||
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "python")
|
||||
)
|
||||
|
||||
from parasitic_qlora import ExpertAdapter, LoRAMatrices
|
||||
from expert_manifold_alignment import ExpertManifoldAligner
|
||||
|
||||
|
||||
class SimpleModel(nn.Module): # type: ignore[misc]
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(32, 32, bias=False)
|
||||
self.fc2 = nn.Linear(32, 16, bias=False)
|
||||
|
||||
|
||||
class ComplexModel(nn.Module): # type: ignore[misc]
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# Simulated transformer blocks to test depth partitioning
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
nn.ModuleDict(
|
||||
{
|
||||
"self_attn": nn.Linear(32, 32, bias=False),
|
||||
"mlp": nn.Linear(32, 32, bias=False),
|
||||
}
|
||||
)
|
||||
for _ in range(6)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TestExpertManifoldAlignment(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.simple_model = SimpleModel()
|
||||
self.complex_model = ComplexModel()
|
||||
|
||||
def test_layer_detection(self) -> None:
|
||||
aligner = ExpertManifoldAligner(self.complex_model)
|
||||
self.assertEqual(aligner.total_layers, 6)
|
||||
|
||||
simple_aligner = ExpertManifoldAligner(self.simple_model)
|
||||
# Should fallback to 12 if no indexed layer pattern matches
|
||||
self.assertEqual(simple_aligner.total_layers, 12)
|
||||
|
||||
def test_step_tracking(self) -> None:
|
||||
aligner = ExpertManifoldAligner(self.simple_model)
|
||||
|
||||
# Apply a modification
|
||||
with torch.no_grad():
|
||||
self.simple_model.fc1.weight.add_(torch.ones(32, 32) * 0.5)
|
||||
|
||||
updates = aligner.track_step(self.simple_model)
|
||||
self.assertIn("fc1.weight", updates)
|
||||
self.assertAlmostEqual(updates["fc1.weight"].mean().item(), 0.5, places=5)
|
||||
# fc2 shouldn't be in updates since it did not change (or it's zero)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
updates.get("fc2.weight", torch.zeros(16, 32)), torch.zeros(16, 32)
|
||||
)
|
||||
)
|
||||
|
||||
def test_subspace_alignment_math(self) -> None:
|
||||
aligner = ExpertManifoldAligner(self.simple_model)
|
||||
|
||||
# Define 2D matrices for LoRA: rank 2, dim 32x32
|
||||
u = torch.zeros(32, 2)
|
||||
u[0, 0] = 1.0
|
||||
u[1, 1] = 1.0
|
||||
|
||||
v = torch.zeros(2, 32)
|
||||
v[0, 0] = 1.0
|
||||
v[1, 1] = 1.0
|
||||
|
||||
# Delta is BA = u v = diag(1, 1, 0, ...)
|
||||
lora_matrices = LoRAMatrices(
|
||||
layer_name="fc1.weight",
|
||||
lora_B=u,
|
||||
lora_A=v,
|
||||
rank=2,
|
||||
explained_variance=1.0,
|
||||
singular_values=torch.tensor([1.0, 1.0]),
|
||||
original_shape=(32, 32),
|
||||
)
|
||||
|
||||
# 1. Step update exactly in the subspace of lora_matrices
|
||||
step_update_aligned = torch.zeros(32, 32)
|
||||
step_update_aligned[0, 0] = 2.0
|
||||
step_update_aligned[1, 1] = 2.0
|
||||
|
||||
alignment = aligner.compute_subspace_alignment(
|
||||
lora_matrices, step_update_aligned
|
||||
)
|
||||
# Cosine similarity should be 1.0 (since the direction is fully aligned)
|
||||
self.assertAlmostEqual(alignment, 1.0, places=5)
|
||||
|
||||
# 2. Step update orthogonal to the subspace
|
||||
step_update_ortho = torch.zeros(32, 32)
|
||||
step_update_ortho[2, 2] = 1.0
|
||||
|
||||
alignment_ortho = aligner.compute_subspace_alignment(
|
||||
lora_matrices, step_update_ortho
|
||||
)
|
||||
self.assertAlmostEqual(alignment_ortho, 0.0, places=5)
|
||||
|
||||
def test_domain_profiling(self) -> None:
|
||||
aligner = ExpertManifoldAligner(self.complex_model)
|
||||
|
||||
# Create dummy adapter with layer concentrated in early self_attn (Statute recall)
|
||||
adapter_statute = ExpertAdapter(
|
||||
adapter_id="test_statute",
|
||||
step=1,
|
||||
layers={
|
||||
"layers.0.self_attn.weight": LoRAMatrices(
|
||||
layer_name="layers.0.self_attn.weight",
|
||||
lora_B=torch.randn(32, 4),
|
||||
lora_A=torch.randn(4, 32),
|
||||
rank=4,
|
||||
explained_variance=0.9,
|
||||
singular_values=torch.ones(4) * 10.0, # high energy
|
||||
original_shape=(32, 32),
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
profile = aligner.profile_adapter(adapter_statute)
|
||||
self.assertGreater(profile["statute_recall"], profile["logic_reasoning"])
|
||||
self.assertGreater(profile["statute_recall"], profile["style_gutachtenstil"])
|
||||
|
||||
tags = aligner.tag_adapter(adapter_statute)
|
||||
self.assertIn("statute_recall", tags)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user