218 lines
8.2 KiB
Python
218 lines
8.2 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Unit tests for Phase 5: telemetry_aggregator and run_phase5_validation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
# Make the project's ``python/`` directory importable. It sits inside the
# parent of the directory that contains this test file.
_test_dir = os.path.dirname(os.path.abspath(__file__))
_root = os.path.dirname(_test_dir)
sys.path.insert(0, os.path.join(_root, "python"))
|
|
|
|
from parasitic_qlora import ExpertAdapter, LoRAMatrices # noqa: E402
|
|
from telemetry_aggregator import ( # noqa: E402
|
|
AdapterLibraryAggregator,
|
|
aggregate_from_paths,
|
|
)
|
|
from run_phase5_validation import ( # noqa: E402
|
|
KlausurCase,
|
|
CaseResult,
|
|
score_output,
|
|
aggregate_domain_scores,
|
|
VALIDATION_SUITE,
|
|
)
|
|
|
|
|
|
def _make_adapter(
    adapter_id: str,
    domain_tags: list[str],
    avg_ev: float = 0.95,
    num_layers: int = 2,
    rank: int = 4,
    out_features: int = 16,
    in_features: int = 8,
) -> ExpertAdapter:
    """Create a minimal ExpertAdapter for testing.

    Each of the ``num_layers`` synthetic layers gets random LoRA factors:
    ``lora_A`` has shape ``(rank, in_features)`` and ``lora_B`` has shape
    ``(out_features, rank)``. Their product therefore matches the recorded
    ``original_shape`` of ``(out_features, in_features)``.

    Args:
        adapter_id: Identifier stored on the adapter (deduplication key in
            the aggregator tests).
        domain_tags: Domain labels for the adapter. The list is copied, so
            the caller's list is never aliased.
        avg_ev: Explained variance assigned to every layer and to the adapter.
        num_layers: Number of layers, named ``layer.{i}.weight``.
        rank: LoRA rank of every layer (default 4).
        out_features: Row count of each layer's original weight (default 16).
        in_features: Column count of each layer's original weight (default 8).

    Returns:
        An ExpertAdapter at step 0 with optimizer "AdamW" and trigger "test".
    """
    layers: dict[str, LoRAMatrices] = {}
    for i in range(num_layers):
        name = f"layer.{i}.weight"
        layers[name] = LoRAMatrices(
            layer_name=name,
            lora_A=torch.randn(rank, in_features),
            lora_B=torch.randn(out_features, rank),
            singular_values=torch.ones(rank),
            rank=rank,
            explained_variance=avg_ev,
            original_shape=(out_features, in_features),
        )
    return ExpertAdapter(
        adapter_id=adapter_id,
        step=0,
        layers=layers,
        avg_explained_variance=avg_ev,
        optimizer_type="AdamW",
        extraction_trigger="test",
        # Copy so that mutating the adapter's tags cannot leak back to the caller.
        domain_tags=list(domain_tags),
    )
|
|
|
|
|
|
class TestAdapterLibraryAggregator(unittest.TestCase):
    """Tests for AdapterLibraryAggregator and aggregate_from_paths."""

    def setUp(self) -> None:
        # _make_adapter fills the LoRA factors with torch.randn. Seeding makes
        # every run of these tests see identical tensors.
        torch.manual_seed(0)

    def _make_library(
        self, n: int, domain: str, id_prefix: str = "a"
    ) -> list[ExpertAdapter]:
        """Build ``n`` adapters tagged ``domain`` with ids ``{id_prefix}_{i}``."""
        return [_make_adapter(f"{id_prefix}_{i}", [domain]) for i in range(n)]

    def test_deduplication(self) -> None:
        """Adding the same adapter_id twice should only keep one copy."""
        lib = self._make_library(3, "statute_recall", "dup")
        agg = AdapterLibraryAggregator()
        agg.add_library(lib)
        # Second add: every adapter_id is already merged, so nothing may be accepted.
        accepted, _ = agg.add_library(lib)
        self.assertEqual(accepted, 0)
        self.assertEqual(len(agg.merged), 3)

    def test_min_ev_filter(self) -> None:
        """Adapters below min_ev should be rejected."""
        good = _make_adapter("good_1", ["logic_reasoning"], avg_ev=0.95)
        bad = _make_adapter("bad_1", ["logic_reasoning"], avg_ev=0.70)
        agg = AdapterLibraryAggregator(min_explained_variance=0.90)
        accepted, rejected = agg.add_library([good, bad])
        self.assertEqual(accepted, 1)
        self.assertEqual(rejected, 1)
        # The counts alone do not prove the right adapter survived.
        self.assertEqual([a.adapter_id for a in agg.merged], ["good_1"])

    def test_rebalance_caps_per_domain(self) -> None:
        """rebalance() should cap adapters per domain at max_adapters_per_domain."""
        statute_lib = self._make_library(15, "statute_recall", "s")
        logic_lib = self._make_library(3, "logic_reasoning", "l")
        agg = AdapterLibraryAggregator(max_adapters_per_domain=5)
        agg.add_library(statute_lib)
        agg.add_library(logic_lib)
        agg.rebalance()
        statute_count = sum(1 for a in agg.merged if "statute_recall" in a.domain_tags)
        self.assertLessEqual(statute_count, 5)
        # A cap limits the domain; it must not empty it.
        self.assertGreater(statute_count, 0)

    def test_coverage_report_keys(self) -> None:
        """coverage_report() must return all required keys."""
        lib = self._make_library(2, "statute_recall")
        agg = AdapterLibraryAggregator()
        agg.add_library(lib)
        report = agg.coverage_report()
        for key in (
            "total_adapters",
            "domain_coverage",
            "balance_score",
            "avg_explained_variance",
        ):
            self.assertIn(key, report)

    def test_save_and_reload(self) -> None:
        """save() must produce a file loadable by ParasiticQLoRAExtractor."""
        from parasitic_qlora import ParasiticQLoRAExtractor

        lib = self._make_library(3, "logic_reasoning", "save")
        agg = AdapterLibraryAggregator()
        agg.add_library(lib)
        # The temporary directory is removed even if save() or an assertion
        # fails. It also avoids pre-creating an empty file for save() to overwrite.
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "library.pt")
            agg.save(path)
            loaded = ParasiticQLoRAExtractor.load_library(path)
            self.assertEqual(len(loaded), 3)

    def test_load_and_merge_missing_file(self) -> None:
        """load_and_merge should skip missing files without crashing."""
        agg = AdapterLibraryAggregator()
        with tempfile.TemporaryDirectory() as tmpdir:
            # A path inside a freshly created temp dir is guaranteed not to
            # exist, unlike a hard-coded absolute path.
            missing = os.path.join(tmpdir, "missing.pt")
            stats = agg.load_and_merge([missing])
        self.assertEqual(stats["files"], 0)
        self.assertEqual(len(agg.merged), 0)

    def test_aggregate_from_paths_two_files(self) -> None:
        """aggregate_from_paths integrates two library files correctly."""
        lib_a = self._make_library(2, "statute_recall", "aa")
        lib_b = self._make_library(2, "logic_reasoning", "bb")

        with tempfile.TemporaryDirectory() as tmpdir:
            path_a = os.path.join(tmpdir, "lib_a.pt")
            path_b = os.path.join(tmpdir, "lib_b.pt")
            out_path = os.path.join(tmpdir, "merged.pt")

            # Save using the same torch.save format as save_library/aggregator
            for lib, path in [(lib_a, path_a), (lib_b, path_b)]:
                state = {
                    "config": {
                        "min_rank": 8,
                        "max_rank": 32,
                        "explained_variance_threshold": 0.95,
                    },
                    "adapters": [a.to_state_dict() for a in lib],
                    "extraction_stats": {
                        "total_extractions": len(lib),
                        "total_extraction_time_s": 0.0,
                        "overhead_pct": 0.0,
                    },
                }
                torch.save(state, path)

            report = aggregate_from_paths([path_a, path_b], out_path)
            self.assertEqual(report["total_accepted"], 4)
            self.assertTrue(os.path.exists(out_path))
|
|
|
|
|
|
class TestScoring(unittest.TestCase):
    """Tests for score_output, aggregate_domain_scores and VALIDATION_SUITE."""

    def _make_case(self, keywords: list[str], wrong: list[str]) -> KlausurCase:
        """Build a logic_reasoning KlausurCase with the given keywords and wrong patterns."""
        return KlausurCase(
            case_id="test_case",
            description="Test",
            statutes="§ 1 BGB",
            correct_keywords=keywords,
            wrong_patterns=wrong,
            domain="logic_reasoning",
        )

    def test_perfect_score(self) -> None:
        """All keywords present, no wrong pattern → score=1.0."""
        case = self._make_case(["536b", "Kenntnis"], [])
        result = score_output(case, "§ 536b BGB schließt Kenntnis des Mieters ein.")
        self.assertAlmostEqual(result.score, 1.0)
        # There are no wrong patterns, so none can have been hit.
        self.assertFalse(result.wrong_hit)

    def test_wrong_pattern_zeroes_score(self) -> None:
        """Wrong pattern present → score=0 regardless of keywords."""
        case = self._make_case(["536b"], ["kann mindern"])
        result = score_output(case, "§ 536b, aber M kann mindern hier.")
        self.assertEqual(result.score, 0.0)
        self.assertTrue(result.wrong_hit)

    def test_partial_keyword_hit(self) -> None:
        """Only some keywords present → fractional score."""
        case = self._make_case(["Kenntnis", "vorbehaltlos", "536b"], [])
        result = score_output(case, "Kenntnis war vorhanden.")
        self.assertAlmostEqual(result.score, 1 / 3, places=5)

    def test_aggregate_domain_scores(self) -> None:
        """Domain aggregation averages scores correctly."""
        results = [
            CaseResult("c1", "statute_recall", "", 2, 2, False, 1.0),
            CaseResult("c2", "statute_recall", "", 1, 2, False, 0.5),
            CaseResult("c3", "logic_reasoning", "", 0, 1, True, 0.0),
        ]
        agg = aggregate_domain_scores(results)
        self.assertAlmostEqual(agg["statute_recall"], 0.75)
        self.assertAlmostEqual(agg["logic_reasoning"], 0.0)

    def test_validation_suite_completeness(self) -> None:
        """VALIDATION_SUITE covers all 3 domains with at least 2 cases each."""
        from collections import Counter

        required = ("statute_recall", "logic_reasoning", "style_gutachtenstil")
        counts = Counter(c.domain for c in VALIDATION_SUITE)
        for domain in required:
            # subTest reports every under-covered domain, not just the first.
            with self.subTest(domain=domain):
                self.assertIn(domain, counts)
                self.assertGreaterEqual(
                    counts[domain], 2, f"Too few cases for domain {domain}"
                )
|
|
|
|
|
|
if __name__ == "__main__":
    # When run as a script, load and run every TestCase defined in this module.
    unittest.main(module="__main__")
|