"""Unit tests for ``ExpertAdapterRouter`` (module ``adapter_moe_router``).

Covers fuzzy domain-prior computation from text, forward-hook bookkeeping,
and the numerical effect of statically routing a forward pass through a
single LoRA expert adapter.
"""

import os
import sys
import unittest

import torch
import torch.nn as nn

# Ensure python directory is in path: the project modules live in a sibling
# ``python`` directory of this test file's parent directory.
sys.path.append(
    os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "python")
)

from parasitic_qlora import ExpertAdapter, LoRAMatrices  # noqa: E402
from adapter_moe_router import ExpertAdapterRouter  # noqa: E402


class SimpleModel(nn.Module):  # type: ignore[misc]
    """Minimal one-layer model exposing ``fc1`` (32 -> 16, no bias) for hooking."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(32, 16, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc1(x)


def _make_fc1_lora(value: float) -> LoRAMatrices:
    """Build a rank-2 LoRA factorisation for ``fc1.weight`` filled with ``value``.

    ``lora_B`` is d x r = 16 x 2 and ``lora_A`` is r x k = 2 x 32, matching the
    (16, 32) weight of ``SimpleModel.fc1``.
    """
    return LoRAMatrices(
        layer_name="fc1.weight",
        lora_B=torch.ones(16, 2) * value,  # d x r = 16 x 2
        lora_A=torch.ones(2, 32) * value,  # r x k = 2 x 32
        rank=2,
        explained_variance=1.0,
        singular_values=torch.tensor([1.0, 1.0]),
        original_shape=(16, 32),
    )


class TestAdapterMoERouter(unittest.TestCase):
    """Tests for fuzzy priors, hook lifecycle and routed forward passes."""

    def setUp(self) -> None:
        # Seed so the randomly initialised fc1 weights are reproducible.
        torch.manual_seed(0)
        self.model = SimpleModel()

        # Create dummy expert adapters with domain tags.
        # Adapter 1 (library index 0): statute recall, adapts fc1.weight.
        self.adapter1 = ExpertAdapter(
            adapter_id="adapter_statute",
            step=1,
            domain_tags=["statute_recall"],
            layers={"fc1.weight": _make_fc1_lora(0.1)},
        )
        # Adapter 2 (library index 1): logic reasoning, adapts fc1.weight.
        self.adapter2 = ExpertAdapter(
            adapter_id="adapter_logic",
            step=1,
            domain_tags=["logic_reasoning"],
            layers={"fc1.weight": _make_fc1_lora(0.2)},
        )
        self.library = [self.adapter1, self.adapter2]

    def test_fuzzy_prior_calculation(self) -> None:
        router = ExpertAdapterRouter(self.model, self.library, in_features=32)

        # 1. Text cites a statute (German civil code paragraph).
        priors_statute = router.compute_fuzzy_priors("According to § 535 BGB...")
        # Index 0 is statute recall, index 1 is logic.
        self.assertGreater(priors_statute[0].item(), priors_statute[1].item())

        # 2. Text has reasoning/FIRT.
        priors_logic = router.compute_fuzzy_priors(
            "We analyze using FIRT reasoning traces"
        )
        self.assertGreater(priors_logic[1].item(), priors_logic[0].item())

    def test_hook_registration(self) -> None:
        router = ExpertAdapterRouter(self.model, self.library, in_features=32)
        self.assertEqual(len(router.hooks), 0)

        # Register hooks: one adapted layer (fc1.weight) -> one hook.
        router.register_hooks()
        self.assertEqual(len(router.hooks), 1)

        # Unregister hooks.
        router.unregister_hooks()
        self.assertEqual(len(router.hooks), 0)

    def test_forward_pass_with_routing(self) -> None:
        router = ExpertAdapterRouter(self.model, self.library, in_features=32)
        x = torch.ones(1, 4, 32)  # batch=1, seq_len=4, in_dim=32

        # 1. Baseline: no hooks are registered yet, so this is the unadapted
        #    model output.
        with torch.no_grad():
            output_base = self.model(x)

        # 2. Register hooks and statically route only to adapter 1 (statute).
        router.register_hooks()
        # Remove hooks even if an assertion below fails.
        self.addCleanup(router.unregister_hooks)
        router.set_active_routing(torch.tensor([1.0, 0.0]))
        with torch.no_grad():
            output_adapted = self.model(x)

        # Check that the adapter is applied.
        self.assertEqual(output_adapted.shape, output_base.shape)
        self.assertFalse(torch.allclose(output_base, output_adapted))

        # Check mathematically:
        # For adapter 1: lora_B is 16x2 of 0.1, lora_A is 2x32 of 0.1.
        # Input x is all ones of shape [1, 4, 32].
        # x_proj = x @ lora_A.t() -> shape [1, 4, 2];
        #   each entry is sum_{k=1}^{32} 1.0 * 0.1 = 3.2.
        # y_proj = x_proj @ lora_B.t() -> shape [1, 4, 16];
        #   each entry is sum_{r=1}^{2} 3.2 * 0.1 = 0.64.
        # With routing weight 1.0 the adapted output should be
        # output_base + 0.64. This assumes the router adds the raw LoRA
        # delta scaled only by the routing prior (no extra alpha/r factor).
        expected_diff = torch.ones(1, 4, 16) * 0.64
        self.assertTrue(
            torch.allclose(
                output_adapted - output_base, expected_diff, rtol=1e-5, atol=1e-6
            )
        )


if __name__ == "__main__":
    unittest.main()