"""Unit tests for ParasiticQLoRAExtractor: snapshot, extraction, save/load."""
import os
import sys
import tempfile
import unittest

import torch
import torch.nn as nn
|
|
|
|
# Ensure the "python" directory that sits one level above this file's
# directory is importable, so the project import below resolves.
_PYTHON_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "python",
)
# Guard against adding the same entry again if this module is re-imported.
if _PYTHON_DIR not in sys.path:
    sys.path.append(_PYTHON_DIR)
|
|
|
|
from parasitic_qlora import ParasiticQLoRAExtractor, QLoRAConfig
|
|
|
|
|
|
class SimpleModel(nn.Module):  # type: ignore[misc]
    """Two bias-free linear layers used as a small fixture model.

    The module defines no ``forward``; the tests only read and modify the
    layers' weight tensors.
    """

    def __init__(self) -> None:
        """Create ``fc1`` (32 -> 32) and ``fc2`` (32 -> 16), both without bias."""
        super().__init__()
        self.fc1 = nn.Linear(32, 32, bias=False)
        self.fc2 = nn.Linear(32, 16, bias=False)
|
|
|
|
|
|
class TestParasiticQLoRA(unittest.TestCase):
    """Tests for snapshotting, adapter extraction and library save/load."""

    model: SimpleModel
    config: QLoRAConfig
    extractor: ParasiticQLoRAExtractor

    def setUp(self) -> None:
        """Create a fresh model, config and extractor before each test."""
        # Seed torch so the random initial weights and the random deltas
        # applied in the tests are the same on every run.
        torch.manual_seed(0)
        self.model = SimpleModel()
        self.config = QLoRAConfig(
            min_rank=2,
            max_rank=8,
            explained_variance_threshold=0.9,
            interesting_point_detection=False,
            extraction_interval=1,
        )
        self.extractor = ParasiticQLoRAExtractor(self.config)

    def test_snapshot_and_extraction(self) -> None:
        """Snapshot the model, perturb ``fc1.weight`` and extract one adapter."""
        # 1. Take base snapshot
        self.extractor.snapshot_base(self.model)
        self.assertEqual(len(self.extractor.base_weights), 2)

        # Verify a tensor is stored for each layer's weight
        for name in ("fc1.weight", "fc2.weight"):
            with self.subTest(name=name):
                self.assertIn(name, self.extractor.base_weights)
                self.assertIsInstance(
                    self.extractor.base_weights[name], torch.Tensor
                )

        # 2. Simulate weight changes (as in fine-tuning):
        #    add a (32 x 2) @ (2 x 32) product, scaled by 0.1, to fc1.weight.
        u = torch.randn(32, 2)
        v = torch.randn(2, 32)
        delta_w1 = torch.matmul(u, v) * 0.1

        with torch.no_grad():
            self.model.fc1.weight.add_(delta_w1)

        # 3. Perform extraction
        step = 1
        metrics = {"loss": 0.5, "step": step}
        self.extractor.extract_adapters(self.model, step, metrics)

        # Verify adapter is stored in library
        self.assertEqual(len(self.extractor.adapter_library), 1)
        adapter = self.extractor.adapter_library[0]
        self.assertTrue(adapter.adapter_id.startswith("unknown_step1_"))

        self.assertIn("fc1.weight", adapter.layers)

        # Verify shapes of A and B against the reported rank
        fc1_layer = adapter.layers["fc1.weight"]
        rank = fc1_layer.rank

        self.assertEqual(fc1_layer.lora_A.shape, (rank, 32))
        self.assertEqual(fc1_layer.lora_B.shape, (32, rank))
        self.assertGreaterEqual(rank, self.config.min_rank)
        self.assertLessEqual(rank, self.config.max_rank)

    def test_save_and_load(self) -> None:
        """Save the adapter library, reload it and compare adapter ids."""
        self.extractor.snapshot_base(self.model)

        # Make a change
        with torch.no_grad():
            self.model.fc2.weight.add_(torch.randn(16, 32) * 0.05)

        self.extractor.extract_adapters(self.model, 1, {"loss": 0.2})

        # Write into a temporary directory so the file is removed even when
        # an assertion fails, and nothing is left in the working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            test_path = os.path.join(tmp_dir, "test_adapters.pt")

            # Save library
            self.extractor.save_library(test_path)
            self.assertTrue(os.path.exists(test_path))

            # Load in a new extractor
            new_extractor = ParasiticQLoRAExtractor(self.config)
            loaded = new_extractor.load_library(test_path)
            # NOTE(review): the return value is assigned explicitly, which
            # suggests load_library does not populate adapter_library itself
            # -- confirm against the extractor implementation.
            new_extractor.adapter_library = loaded

        self.assertEqual(len(new_extractor.adapter_library), 1)
        orig_adapter = self.extractor.adapter_library[0]
        loaded_adapter = new_extractor.adapter_library[0]
        self.assertEqual(orig_adapter.adapter_id, loaded_adapter.adapter_id)
|
|
|
|
|
|
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
|