143 lines
4.7 KiB
Python
143 lines
4.7 KiB
Python
import os
|
|
import sys
|
|
import unittest
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
# Make the ``python`` directory that sits next to this test file's parent
# directory importable, so the project modules imported below
# (parasitic_qlora, expert_manifold_alignment) resolve without installation.
sys.path.append(
    os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        "python",
    )
)
|
|
|
|
from parasitic_qlora import ExpertAdapter, LoRAMatrices
|
|
from expert_manifold_alignment import ExpertManifoldAligner
|
|
|
|
|
|
class SimpleModel(nn.Module):  # type: ignore[misc]
    """Minimal fixture with two bias-free linear layers.

    Parameters are ``fc1.weight`` (32x32) and ``fc2.weight`` (16x32); none of
    them carries an indexed ``layers.<i>`` style name. No ``forward`` is
    defined because only the parameters are inspected.
    """

    def __init__(self) -> None:
        super().__init__()
        in_features = 32
        self.fc1 = nn.Linear(in_features, in_features, bias=False)
        self.fc2 = nn.Linear(in_features, 16, bias=False)
|
|
|
|
|
|
class ComplexModel(nn.Module):  # type: ignore[misc]
    """Fixture mimicking a six-block transformer stack for depth partitioning.

    ``self.layers`` is an ``nn.ModuleList`` of six ``nn.ModuleDict`` blocks,
    each holding a bias-free 32x32 ``self_attn`` linear and a bias-free 32x32
    ``mlp`` linear, giving parameter names such as
    ``layers.0.self_attn.weight`` and ``layers.5.mlp.weight``.
    """

    def __init__(self) -> None:
        super().__init__()
        width = 32
        blocks = []
        for _ in range(6):
            blocks.append(
                nn.ModuleDict(
                    {
                        "self_attn": nn.Linear(width, width, bias=False),
                        "mlp": nn.Linear(width, width, bias=False),
                    }
                )
            )
        self.layers = nn.ModuleList(blocks)
|
|
|
|
|
|
class TestExpertManifoldAlignment(unittest.TestCase):
    """Unit tests for ``ExpertManifoldAligner``.

    Covers layer-count detection, tracking of per-step weight updates, the
    subspace-alignment score between a LoRA delta and a step update, and
    adapter domain profiling / tagging.
    """

    def setUp(self) -> None:
        # Fresh fixtures per test: test_step_tracking mutates fc1 in place.
        self.simple_model = SimpleModel()
        self.complex_model = ComplexModel()

    def test_layer_detection(self) -> None:
        """total_layers counts indexed blocks, else falls back to 12."""
        aligner = ExpertManifoldAligner(self.complex_model)
        # ComplexModel has six indexed blocks: layers.0 .. layers.5.
        self.assertEqual(aligner.total_layers, 6)

        simple_aligner = ExpertManifoldAligner(self.simple_model)
        # Should fall back to 12 if no indexed layer pattern matches.
        self.assertEqual(simple_aligner.total_layers, 12)

    def test_step_tracking(self) -> None:
        """A +0.5 in-place change to fc1 shows up in track_step; fc2 does not."""
        aligner = ExpertManifoldAligner(self.simple_model)

        # Apply a uniform +0.5 modification to fc1 only.
        with torch.no_grad():
            self.simple_model.fc1.weight.add_(torch.ones(32, 32) * 0.5)

        updates = aligner.track_step(self.simple_model)
        self.assertIn("fc1.weight", updates)
        self.assertAlmostEqual(updates["fc1.weight"].mean().item(), 0.5, places=5)
        # fc2 did not change, so it must be either absent or all zeros.
        self.assertTrue(
            torch.allclose(
                updates.get("fc2.weight", torch.zeros(16, 32)), torch.zeros(16, 32)
            )
        )

    def test_subspace_alignment_math(self) -> None:
        """Alignment is 1.0 for an in-subspace update and 0.0 for an orthogonal one."""
        aligner = ExpertManifoldAligner(self.simple_model)

        # Rank-2 LoRA factors for a 32x32 weight: B is (32, 2), A is (2, 32).
        u = torch.zeros(32, 2)
        u[0, 0] = 1.0
        u[1, 1] = 1.0

        v = torch.zeros(2, 32)
        v[0, 0] = 1.0
        v[1, 1] = 1.0

        # Delta is B @ A = u @ v = diag(1, 1, 0, ..., 0).
        lora_matrices = LoRAMatrices(
            layer_name="fc1.weight",
            lora_B=u,
            lora_A=v,
            rank=2,
            explained_variance=1.0,
            singular_values=torch.tensor([1.0, 1.0]),
            original_shape=(32, 32),
        )

        # 1. Step update lying exactly in the subspace of lora_matrices
        #    (the LoRA delta scaled by 2).
        step_update_aligned = torch.zeros(32, 32)
        step_update_aligned[0, 0] = 2.0
        step_update_aligned[1, 1] = 2.0

        alignment = aligner.compute_subspace_alignment(
            lora_matrices, step_update_aligned
        )
        # Cosine similarity should be 1.0 since the direction is fully aligned.
        self.assertAlmostEqual(alignment, 1.0, places=5)

        # 2. Step update orthogonal to the subspace (touches only index [2, 2]).
        step_update_ortho = torch.zeros(32, 32)
        step_update_ortho[2, 2] = 1.0

        alignment_ortho = aligner.compute_subspace_alignment(
            lora_matrices, step_update_ortho
        )
        self.assertAlmostEqual(alignment_ortho, 0.0, places=5)

    def test_domain_profiling(self) -> None:
        """An adapter living only in layers.0.self_attn profiles/tags as statute_recall."""
        aligner = ExpertManifoldAligner(self.complex_model)

        # Seeded local generator: the random LoRA factors are identical on
        # every run (no flaky profile scores), and the global torch RNG used
        # by other tests is left untouched.
        gen = torch.Generator().manual_seed(0)

        # Dummy adapter with its only layer in the first block's self_attn
        # (the statute-recall case this test expects), with high singular values.
        adapter_statute = ExpertAdapter(
            adapter_id="test_statute",
            step=1,
            layers={
                "layers.0.self_attn.weight": LoRAMatrices(
                    layer_name="layers.0.self_attn.weight",
                    lora_B=torch.randn(32, 4, generator=gen),
                    lora_A=torch.randn(4, 32, generator=gen),
                    rank=4,
                    explained_variance=0.9,
                    singular_values=torch.ones(4) * 10.0,  # high energy
                    original_shape=(32, 32),
                )
            },
        )

        profile = aligner.profile_adapter(adapter_statute)
        self.assertGreater(profile["statute_recall"], profile["logic_reasoning"])
        self.assertGreater(profile["statute_recall"], profile["style_gutachtenstil"])

        tags = aligner.tag_adapter(adapter_statute)
        self.assertIn("statute_recall", tags)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: collect and run the TestCases defined in this module.
    unittest.main(module="__main__")
|