Files
FCES-native/tests/test_adapter_moe_router.py

137 lines
4.6 KiB
Python

import os
import sys
import unittest
import torch
import torch.nn as nn
# Ensure python directory is in path
sys.path.append(
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "python")
)
from parasitic_qlora import ExpertAdapter, LoRAMatrices
from adapter_moe_router import ExpertAdapterRouter
class SimpleModel(nn.Module):  # type: ignore[misc]
    """Minimal one-layer network used as the routing target in these tests.

    It holds a single bias-free linear projection named ``fc1`` that maps
    32 input features to 16 output features. The adapters built in this
    module are keyed on the parameter name ``fc1.weight``, so the attribute
    name must stay ``fc1``.
    """

    def __init__(self) -> None:
        super().__init__()
        # No bias: the layer is a pure matrix product x @ W^T, which keeps the
        # hand-derived LoRA deltas in the tests easy to check.
        self.fc1 = nn.Linear(in_features=32, out_features=16, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` (last dimension 32) through ``fc1`` to 16 features."""
        projected = self.fc1(x)
        return projected
class TestAdapterMoERouter(unittest.TestCase):
    """Tests for ``ExpertAdapterRouter``: fuzzy priors, hook lifecycle, routing math."""

    def setUp(self) -> None:
        """Create a fresh model and a two-adapter library before every test.

        Both adapters carry rank-2 LoRA factors for ``fc1.weight`` filled with
        a single constant (0.1 for the statute adapter, 0.2 for the logic
        adapter), so their contribution to the output can be derived by hand.
        """
        self.model = SimpleModel()
        # Adapter 1: statute recall (adapts fc1.weight).
        self.adapter1 = ExpertAdapter(
            adapter_id="adapter_statute",
            step=1,
            domain_tags=["statute_recall"],
            layers={
                "fc1.weight": LoRAMatrices(
                    layer_name="fc1.weight",
                    lora_B=torch.ones(16, 2) * 0.1,  # d x r = 16 x 2
                    lora_A=torch.ones(2, 32) * 0.1,  # r x k = 2 x 32
                    rank=2,
                    explained_variance=1.0,
                    singular_values=torch.tensor([1.0, 1.0]),
                    original_shape=(16, 32),
                )
            },
        )
        # Adapter 2: logic reasoning (adapts fc1.weight).
        self.adapter2 = ExpertAdapter(
            adapter_id="adapter_logic",
            step=1,
            domain_tags=["logic_reasoning"],
            layers={
                "fc1.weight": LoRAMatrices(
                    layer_name="fc1.weight",
                    lora_B=torch.ones(16, 2) * 0.2,
                    lora_A=torch.ones(2, 32) * 0.2,
                    rank=2,
                    explained_variance=1.0,
                    singular_values=torch.tensor([1.0, 1.0]),
                    original_shape=(16, 32),
                )
            },
        )
        # The tests below assume the router's prior vectors follow this order:
        # index 0 -> statute adapter, index 1 -> logic adapter.
        self.library = [self.adapter1, self.adapter2]

    def test_fuzzy_prior_calculation(self) -> None:
        """Domain-specific text should raise the prior of the matching adapter.

        The router's keyword matching is not visible here. These inputs assume
        that a "§ ... BGB" citation triggers ``statute_recall`` and that
        "FIRT reasoning" triggers ``logic_reasoning``.
        """
        router = ExpertAdapterRouter(self.model, self.library, in_features=32)

        # 1. Text citing a statute favours the statute adapter (index 0).
        priors_statute = router.compute_fuzzy_priors("According to § 535 BGB...")
        self.assertGreater(priors_statute[0].item(), priors_statute[1].item())

        # 2. Text about reasoning / FIRT favours the logic adapter (index 1).
        priors_logic = router.compute_fuzzy_priors(
            "We analyze using FIRT reasoning traces"
        )
        self.assertGreater(priors_logic[1].item(), priors_logic[0].item())

    def test_hook_registration(self) -> None:
        """register_hooks attaches hooks and unregister_hooks removes them all.

        Both adapters target only ``fc1.weight`` and exactly one hook is
        expected, presumably one hook per adapted layer.
        """
        router = ExpertAdapterRouter(self.model, self.library, in_features=32)
        self.assertEqual(len(router.hooks), 0)

        router.register_hooks()
        self.assertEqual(len(router.hooks), 1)

        router.unregister_hooks()
        self.assertEqual(len(router.hooks), 0)

    def test_forward_pass_with_routing(self) -> None:
        """Routing all weight to adapter 1 adds exactly its LoRA delta to the output."""
        router = ExpertAdapterRouter(self.model, self.library, in_features=32)
        x = torch.ones(1, 4, 32)  # batch=1, seq_len=4, in_dim=32

        # 1. Base output. No hooks are registered yet, so this is the
        #    unadapted model.
        with torch.no_grad():
            output_base = self.model(x)

        # 2. Adapted output. Force static routing so only adapter 1 (statute)
        #    contributes. try/finally guarantees the hooks come off the model
        #    even if the forward pass raises.
        router.register_hooks()
        try:
            router.set_active_routing(torch.tensor([1.0, 0.0]))
            with torch.no_grad():
                output_adapted = self.model(x)
        finally:
            router.unregister_hooks()

        # 3. Once the hooks are removed, the model must behave as before.
        with torch.no_grad():
            output_after = self.model(x)
        torch.testing.assert_close(output_after, output_base)

        # The adapter must change the output at all.
        self.assertFalse(torch.allclose(output_base, output_adapted))

        # Hand-derived expected delta for adapter 1 (lora_A: 2x32 of 0.1,
        # lora_B: 16x2 of 0.1) on an all-ones input:
        #   x @ lora_A.T      -> every entry 32 * 1.0 * 0.1 = 3.2   shape [1, 4, 2]
        #   (...) @ lora_B.T  -> every entry  2 * 3.2 * 0.1 = 0.64  shape [1, 4, 16]
        # Adapter 2 has prior 0.0, so the whole difference comes from adapter 1.
        # NOTE(review): this assumes the router applies no extra LoRA scaling
        # (e.g. alpha / rank) on top of B @ A.
        expected_diff = torch.ones(1, 4, 16) * 0.64
        # assert_close, unlike allclose, does not broadcast. It also checks
        # shape and dtype, and reports the mismatch on failure.
        torch.testing.assert_close(
            output_adapted - output_base, expected_diff, rtol=1e-5, atol=1e-8
        )
if __name__ == "__main__":
    # Allow running this file directly instead of through a test runner.
    unittest.main(module="__main__")