"""Unit tests for ``ParasiticQLoRAExtractor``: base snapshots, adapter extraction, and library persistence."""

import os
import sys
import tempfile
import unittest

import torch
import torch.nn as nn

# Ensure python directory is in path: the ``python`` directory sits next to
# this file's parent directory, so ``parasitic_qlora`` resolves without install.
sys.path.append(
    os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "python")
)

from parasitic_qlora import ParasiticQLoRAExtractor, QLoRAConfig  # noqa: E402


class SimpleModel(nn.Module):  # type: ignore[misc]
    """Minimal two-layer weight container used as the extraction target.

    No ``forward`` is defined; the tests only read and mutate parameters.
    Both layers are bias-free, so the only parameters are ``fc1.weight``
    (32x32) and ``fc2.weight`` (16x32).
    """

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(32, 32, bias=False)
        self.fc2 = nn.Linear(32, 16, bias=False)


class TestParasiticQLoRA(unittest.TestCase):
    """Checks snapshotting, adapter extraction, and save/load of the adapter library."""

    model: SimpleModel
    config: QLoRAConfig
    extractor: ParasiticQLoRAExtractor

    def setUp(self) -> None:
        # Seed so model initialisation and the simulated weight deltas are
        # reproducible from run to run.
        torch.manual_seed(0)
        self.model = SimpleModel()
        self.config = QLoRAConfig(
            min_rank=2,
            max_rank=8,
            explained_variance_threshold=0.9,
            interesting_point_detection=False,
            extraction_interval=1,
        )
        self.extractor = ParasiticQLoRAExtractor(self.config)

    def test_snapshot_and_extraction(self) -> None:
        """A low-rank weight change yields one adapter with correctly shaped LoRA factors."""
        # 1. Take base snapshot
        self.extractor.snapshot_base(self.model)
        self.assertEqual(len(self.extractor.base_weights), 2)

        # Verify a base tensor is stored for every weight
        for name in ["fc1.weight", "fc2.weight"]:
            self.assertIn(name, self.extractor.base_weights)
            self.assertIsInstance(self.extractor.base_weights[name], torch.Tensor)

        # 2. Simulate weight changes (as in fine-tuning)
        # We add a low-rank delta to fc1.weight (rank 2)
        u = torch.randn(32, 2)
        v = torch.randn(2, 32)
        delta_w1 = torch.matmul(u, v) * 0.1

        with torch.no_grad():
            self.model.fc1.weight.add_(delta_w1)

        # 3. Perform extraction
        step = 1
        metrics = {"loss": 0.5, "step": step}
        self.extractor.extract_adapters(self.model, step, metrics)

        # Verify adapter is stored in library
        self.assertEqual(len(self.extractor.adapter_library), 1)
        adapter = self.extractor.adapter_library[0]
        self.assertTrue(adapter.adapter_id.startswith("unknown_step1_"))
        self.assertIn("fc1.weight", adapter.layers)

        # Verify shapes of A and B: A is (rank, in_features), B is (out_features, rank)
        lora_A = adapter.layers["fc1.weight"].lora_A
        lora_B = adapter.layers["fc1.weight"].lora_B
        rank = adapter.layers["fc1.weight"].rank

        self.assertEqual(lora_A.shape, (rank, 32))
        self.assertEqual(lora_B.shape, (32, rank))
        self.assertGreaterEqual(rank, self.config.min_rank)
        self.assertLessEqual(rank, self.config.max_rank)

    def test_save_and_load(self) -> None:
        """An adapter library survives a save/load round trip into a fresh extractor."""
        self.extractor.snapshot_base(self.model)

        # Make a change
        with torch.no_grad():
            self.model.fc2.weight.add_(torch.randn(16, 32) * 0.05)

        self.extractor.extract_adapters(self.model, 1, {"loss": 0.2})

        # Save into a private temporary directory. It is removed even when an
        # assertion fails, and concurrent runs cannot clobber each other's file.
        with tempfile.TemporaryDirectory() as tmp_dir:
            test_path = os.path.join(tmp_dir, "test_adapters.pt")
            self.extractor.save_library(test_path)
            self.assertTrue(os.path.exists(test_path))

            # Load in a new extractor and install the returned library explicitly
            new_extractor = ParasiticQLoRAExtractor(self.config)
            loaded = new_extractor.load_library(test_path)
            new_extractor.adapter_library = loaded

        self.assertEqual(len(new_extractor.adapter_library), 1)

        orig_adapter = self.extractor.adapter_library[0]
        loaded_adapter = new_extractor.adapter_library[0]
        self.assertEqual(orig_adapter.adapter_id, loaded_adapter.adapter_id)


if __name__ == "__main__":
    unittest.main()