feat: complete phase 5 - Klausuren validation with MoE routing, dynamic dimensional matching, and full mypy type safety
This commit is contained in:
217
tests/test_phase5_validation.py
Normal file
217
tests/test_phase5_validation.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Unit tests for Phase 5: telemetry_aggregator and run_phase5_validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, os.path.join(_root, "python"))
|
||||
|
||||
from parasitic_qlora import ExpertAdapter, LoRAMatrices # noqa: E402
|
||||
from telemetry_aggregator import ( # noqa: E402
|
||||
AdapterLibraryAggregator,
|
||||
aggregate_from_paths,
|
||||
)
|
||||
from run_phase5_validation import ( # noqa: E402
|
||||
KlausurCase,
|
||||
CaseResult,
|
||||
score_output,
|
||||
aggregate_domain_scores,
|
||||
VALIDATION_SUITE,
|
||||
)
|
||||
|
||||
|
||||
def _make_adapter(
|
||||
adapter_id: str,
|
||||
domain_tags: list[str],
|
||||
avg_ev: float = 0.95,
|
||||
num_layers: int = 2,
|
||||
) -> ExpertAdapter:
|
||||
"""Creates a minimal ExpertAdapter for testing."""
|
||||
layers = {}
|
||||
for i in range(num_layers):
|
||||
name = f"layer.{i}.weight"
|
||||
r = 4
|
||||
d, k = 16, 8
|
||||
layers[name] = LoRAMatrices(
|
||||
layer_name=name,
|
||||
lora_A=torch.randn(r, k),
|
||||
lora_B=torch.randn(d, r),
|
||||
singular_values=torch.ones(r),
|
||||
rank=r,
|
||||
explained_variance=avg_ev,
|
||||
original_shape=(d, k),
|
||||
)
|
||||
adapter = ExpertAdapter(
|
||||
adapter_id=adapter_id,
|
||||
step=0,
|
||||
layers=layers,
|
||||
avg_explained_variance=avg_ev,
|
||||
optimizer_type="AdamW",
|
||||
extraction_trigger="test",
|
||||
domain_tags=domain_tags,
|
||||
)
|
||||
return adapter
|
||||
|
||||
|
||||
class TestAdapterLibraryAggregator(unittest.TestCase):
|
||||
def _make_library(
|
||||
self, n: int, domain: str, id_prefix: str = "a"
|
||||
) -> list[ExpertAdapter]:
|
||||
return [_make_adapter(f"{id_prefix}_{i}", [domain]) for i in range(n)]
|
||||
|
||||
def test_deduplication(self) -> None:
|
||||
"""Adding the same adapter_id twice should only keep one copy."""
|
||||
lib = self._make_library(3, "statute_recall", "dup")
|
||||
agg = AdapterLibraryAggregator()
|
||||
agg.add_library(lib)
|
||||
agg.add_library(lib) # second add should be all rejected
|
||||
self.assertEqual(len(agg.merged), 3)
|
||||
|
||||
def test_min_ev_filter(self) -> None:
|
||||
"""Adapters below min_ev should be rejected."""
|
||||
good = _make_adapter("good_1", ["logic_reasoning"], avg_ev=0.95)
|
||||
bad = _make_adapter("bad_1", ["logic_reasoning"], avg_ev=0.70)
|
||||
agg = AdapterLibraryAggregator(min_explained_variance=0.90)
|
||||
accepted, rejected = agg.add_library([good, bad])
|
||||
self.assertEqual(accepted, 1)
|
||||
self.assertEqual(rejected, 1)
|
||||
|
||||
def test_rebalance_caps_per_domain(self) -> None:
|
||||
"""rebalance() should cap adapters per domain at max_adapters_per_domain."""
|
||||
statute_lib = self._make_library(15, "statute_recall", "s")
|
||||
logic_lib = self._make_library(3, "logic_reasoning", "l")
|
||||
agg = AdapterLibraryAggregator(max_adapters_per_domain=5)
|
||||
agg.add_library(statute_lib)
|
||||
agg.add_library(logic_lib)
|
||||
agg.rebalance()
|
||||
statute_count = sum(1 for a in agg.merged if "statute_recall" in a.domain_tags)
|
||||
self.assertLessEqual(statute_count, 5)
|
||||
|
||||
def test_coverage_report_keys(self) -> None:
|
||||
"""coverage_report() must return all required keys."""
|
||||
lib = self._make_library(2, "statute_recall")
|
||||
agg = AdapterLibraryAggregator()
|
||||
agg.add_library(lib)
|
||||
report = agg.coverage_report()
|
||||
self.assertIn("total_adapters", report)
|
||||
self.assertIn("domain_coverage", report)
|
||||
self.assertIn("balance_score", report)
|
||||
self.assertIn("avg_explained_variance", report)
|
||||
|
||||
def test_save_and_reload(self) -> None:
|
||||
"""save() must produce a file loadable by ParasiticQLoRAExtractor."""
|
||||
from parasitic_qlora import ParasiticQLoRAExtractor
|
||||
|
||||
lib = self._make_library(3, "logic_reasoning", "save")
|
||||
agg = AdapterLibraryAggregator()
|
||||
agg.add_library(lib)
|
||||
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
|
||||
path = f.name
|
||||
try:
|
||||
agg.save(path)
|
||||
loaded = ParasiticQLoRAExtractor.load_library(path)
|
||||
self.assertEqual(len(loaded), 3)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_load_and_merge_missing_file(self) -> None:
|
||||
"""load_and_merge should skip missing files without crashing."""
|
||||
agg = AdapterLibraryAggregator()
|
||||
stats = agg.load_and_merge(["/nonexistent/path.pt"])
|
||||
self.assertEqual(stats["files"], 0)
|
||||
|
||||
def test_aggregate_from_paths_two_files(self) -> None:
|
||||
"""aggregate_from_paths integrates two library files correctly."""
|
||||
|
||||
lib_a = self._make_library(2, "statute_recall", "aa")
|
||||
lib_b = self._make_library(2, "logic_reasoning", "bb")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path_a = os.path.join(tmpdir, "lib_a.pt")
|
||||
path_b = os.path.join(tmpdir, "lib_b.pt")
|
||||
out_path = os.path.join(tmpdir, "merged.pt")
|
||||
|
||||
# Save using the same torch.save format as save_library/aggregator
|
||||
for lib, path in [(lib_a, path_a), (lib_b, path_b)]:
|
||||
state = {
|
||||
"config": {
|
||||
"min_rank": 8,
|
||||
"max_rank": 32,
|
||||
"explained_variance_threshold": 0.95,
|
||||
},
|
||||
"adapters": [a.to_state_dict() for a in lib],
|
||||
"extraction_stats": {
|
||||
"total_extractions": len(lib),
|
||||
"total_extraction_time_s": 0.0,
|
||||
"overhead_pct": 0.0,
|
||||
},
|
||||
}
|
||||
torch.save(state, path)
|
||||
|
||||
report = aggregate_from_paths([path_a, path_b], out_path)
|
||||
self.assertEqual(report["total_accepted"], 4)
|
||||
self.assertTrue(os.path.exists(out_path))
|
||||
|
||||
|
||||
class TestScoring(unittest.TestCase):
|
||||
def _make_case(self, keywords: list[str], wrong: list[str]) -> KlausurCase:
|
||||
return KlausurCase(
|
||||
case_id="test_case",
|
||||
description="Test",
|
||||
statutes="§ 1 BGB",
|
||||
correct_keywords=keywords,
|
||||
wrong_patterns=wrong,
|
||||
domain="logic_reasoning",
|
||||
)
|
||||
|
||||
def test_perfect_score(self) -> None:
|
||||
"""All keywords present, no wrong pattern → score=1.0."""
|
||||
case = self._make_case(["536b", "Kenntnis"], [])
|
||||
result = score_output(case, "§ 536b BGB schließt Kenntnis des Mieters ein.")
|
||||
self.assertAlmostEqual(result.score, 1.0)
|
||||
|
||||
def test_wrong_pattern_zeroes_score(self) -> None:
|
||||
"""Wrong pattern present → score=0 regardless of keywords."""
|
||||
case = self._make_case(["536b"], ["kann mindern"])
|
||||
result = score_output(case, "§ 536b, aber M kann mindern hier.")
|
||||
self.assertEqual(result.score, 0.0)
|
||||
self.assertTrue(result.wrong_hit)
|
||||
|
||||
def test_partial_keyword_hit(self) -> None:
|
||||
"""Only some keywords present → fractional score."""
|
||||
case = self._make_case(["Kenntnis", "vorbehaltlos", "536b"], [])
|
||||
result = score_output(case, "Kenntnis war vorhanden.")
|
||||
self.assertAlmostEqual(result.score, 1 / 3, places=5)
|
||||
|
||||
def test_aggregate_domain_scores(self) -> None:
|
||||
"""Domain aggregation averages scores correctly."""
|
||||
results = [
|
||||
CaseResult("c1", "statute_recall", "", 2, 2, False, 1.0),
|
||||
CaseResult("c2", "statute_recall", "", 1, 2, False, 0.5),
|
||||
CaseResult("c3", "logic_reasoning", "", 0, 1, True, 0.0),
|
||||
]
|
||||
agg = aggregate_domain_scores(results)
|
||||
self.assertAlmostEqual(agg["statute_recall"], 0.75)
|
||||
self.assertAlmostEqual(agg["logic_reasoning"], 0.0)
|
||||
|
||||
def test_validation_suite_completeness(self) -> None:
|
||||
"""VALIDATION_SUITE covers all 3 domains."""
|
||||
domains = {c.domain for c in VALIDATION_SUITE}
|
||||
self.assertIn("statute_recall", domains)
|
||||
self.assertIn("logic_reasoning", domains)
|
||||
self.assertIn("style_gutachtenstil", domains)
|
||||
# At least 2 cases per domain
|
||||
for domain in ["statute_recall", "logic_reasoning", "style_gutachtenstil"]:
|
||||
count = sum(1 for c in VALIDATION_SUITE if c.domain == domain)
|
||||
self.assertGreaterEqual(count, 2, f"Too few cases for domain {domain}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user