feat: implement parasitic QLoRA adapter extraction and unit tests
This commit is contained in:
109
tests/test_parasitic_qlora.py
Normal file
109
tests/test_parasitic_qlora.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Ensure python directory is in path
|
||||
sys.path.append(
|
||||
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "python")
|
||||
)
|
||||
|
||||
from parasitic_qlora import ParasiticQLoRAExtractor, QLoRAConfig
|
||||
|
||||
|
||||
class SimpleModel(nn.Module): # type: ignore[misc]
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(32, 32, bias=False)
|
||||
self.fc2 = nn.Linear(32, 16, bias=False)
|
||||
|
||||
|
||||
class TestParasiticQLoRA(unittest.TestCase):
|
||||
model: SimpleModel
|
||||
config: QLoRAConfig
|
||||
extractor: ParasiticQLoRAExtractor
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.model = SimpleModel()
|
||||
self.config = QLoRAConfig(
|
||||
min_rank=2,
|
||||
max_rank=8,
|
||||
explained_variance_threshold=0.9,
|
||||
interesting_point_detection=False,
|
||||
extraction_interval=1,
|
||||
)
|
||||
self.extractor = ParasiticQLoRAExtractor(self.config)
|
||||
|
||||
def test_snapshot_and_extraction(self) -> None:
|
||||
# 1. Take base snapshot
|
||||
self.extractor.snapshot_base(self.model)
|
||||
self.assertEqual(len(self.extractor.base_weights), 2)
|
||||
|
||||
# Verify weight hashes are stored
|
||||
for name in ["fc1.weight", "fc2.weight"]:
|
||||
self.assertIn(name, self.extractor.base_weights)
|
||||
self.assertTrue(isinstance(self.extractor.base_weights[name], torch.Tensor))
|
||||
|
||||
# 2. Simulate weight changes (as in fine-tuning)
|
||||
# We add a low-rank delta to fc1.weight (rank 2)
|
||||
u = torch.randn(32, 2)
|
||||
v = torch.randn(2, 32)
|
||||
delta_w1 = torch.matmul(u, v) * 0.1
|
||||
|
||||
with torch.no_grad():
|
||||
self.model.fc1.weight.add_(delta_w1)
|
||||
|
||||
# 3. Perform extraction
|
||||
step = 1
|
||||
metrics = {"loss": 0.5, "step": step}
|
||||
self.extractor.extract_adapters(self.model, step, metrics)
|
||||
|
||||
# Verify adapter is stored in library
|
||||
self.assertEqual(len(self.extractor.adapter_library), 1)
|
||||
adapter = self.extractor.adapter_library[0]
|
||||
self.assertTrue(adapter.adapter_id.startswith("unknown_step1_"))
|
||||
|
||||
self.assertIn("fc1.weight", adapter.layers)
|
||||
|
||||
# Verify shapes of A and B
|
||||
lora_A = adapter.layers["fc1.weight"].lora_A
|
||||
lora_B = adapter.layers["fc1.weight"].lora_B
|
||||
rank = adapter.layers["fc1.weight"].rank
|
||||
|
||||
self.assertEqual(lora_A.shape, (rank, 32))
|
||||
self.assertEqual(lora_B.shape, (32, rank))
|
||||
self.assertGreaterEqual(rank, self.config.min_rank)
|
||||
self.assertLessEqual(rank, self.config.max_rank)
|
||||
|
||||
def test_save_and_load(self) -> None:
|
||||
self.extractor.snapshot_base(self.model)
|
||||
|
||||
# Make a change
|
||||
with torch.no_grad():
|
||||
self.model.fc2.weight.add_(torch.randn(16, 32) * 0.05)
|
||||
|
||||
self.extractor.extract_adapters(self.model, 1, {"loss": 0.2})
|
||||
|
||||
# Save library
|
||||
test_path = "test_adapters.pt"
|
||||
self.extractor.save_library(test_path)
|
||||
self.assertTrue(os.path.exists(test_path))
|
||||
|
||||
# Load in a new extractor
|
||||
new_extractor = ParasiticQLoRAExtractor(self.config)
|
||||
loaded = new_extractor.load_library(test_path)
|
||||
new_extractor.adapter_library = loaded
|
||||
|
||||
self.assertEqual(len(new_extractor.adapter_library), 1)
|
||||
orig_adapter = self.extractor.adapter_library[0]
|
||||
loaded_adapter = new_extractor.adapter_library[0]
|
||||
self.assertEqual(orig_adapter.adapter_id, loaded_adapter.adapter_id)
|
||||
|
||||
# Clean up
|
||||
if os.path.exists(test_path):
|
||||
os.remove(test_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user