# Files
# FCES-native/tests/test_phase5_validation.py
#
# 218 lines
# 8.2 KiB
# Python

# -*- coding: utf-8 -*-
"""Unit tests for Phase 5: telemetry_aggregator and run_phase5_validation."""
from __future__ import annotations
import os
import sys
import tempfile
import unittest
import torch
_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(_root, "python"))
from parasitic_qlora import ExpertAdapter, LoRAMatrices # noqa: E402
from telemetry_aggregator import ( # noqa: E402
AdapterLibraryAggregator,
aggregate_from_paths,
)
from run_phase5_validation import ( # noqa: E402
KlausurCase,
CaseResult,
score_output,
aggregate_domain_scores,
VALIDATION_SUITE,
)
def _make_adapter(
    adapter_id: str,
    domain_tags: list[str],
    avg_ev: float = 0.95,
    num_layers: int = 2,
) -> ExpertAdapter:
    """Build a small synthetic ExpertAdapter to use as a test fixture.

    The adapter gets ``num_layers`` layers named ``layer.<i>.weight``. Each
    layer holds random rank-4 LoRA factors for a 16x8 weight and all-ones
    singular values. ``avg_ev`` is used both as every layer's explained
    variance and as the adapter-level average.
    """
    rank = 4
    out_dim, in_dim = 16, 8

    def _fake_layer(layer_name: str) -> LoRAMatrices:
        # Random factors are sufficient: the tests only look at metadata.
        return LoRAMatrices(
            layer_name=layer_name,
            lora_A=torch.randn(rank, in_dim),
            lora_B=torch.randn(out_dim, rank),
            singular_values=torch.ones(rank),
            rank=rank,
            explained_variance=avg_ev,
            original_shape=(out_dim, in_dim),
        )

    layer_names = [f"layer.{idx}.weight" for idx in range(num_layers)]
    return ExpertAdapter(
        adapter_id=adapter_id,
        step=0,
        layers={layer_name: _fake_layer(layer_name) for layer_name in layer_names},
        avg_explained_variance=avg_ev,
        optimizer_type="AdamW",
        extraction_trigger="test",
        domain_tags=domain_tags,
    )
class TestAdapterLibraryAggregator(unittest.TestCase):
    """Tests for AdapterLibraryAggregator and the aggregate_from_paths helper."""

    def _make_library(
        self, n: int, domain: str, id_prefix: str = "a"
    ) -> list[ExpertAdapter]:
        """Return ``n`` adapters tagged ``domain`` with ids ``<id_prefix>_<i>``."""
        library = []
        for idx in range(n):
            library.append(_make_adapter(f"{id_prefix}_{idx}", [domain]))
        return library

    def test_deduplication(self) -> None:
        """Offering the same adapter ids twice must keep a single copy of each."""
        adapters = self._make_library(3, "statute_recall", "dup")
        aggregator = AdapterLibraryAggregator()
        for _ in range(2):  # the second pass only offers already-seen ids
            aggregator.add_library(adapters)
        self.assertEqual(len(aggregator.merged), 3)

    def test_min_ev_filter(self) -> None:
        """An adapter whose explained variance is below min_ev is rejected."""
        high_ev = _make_adapter("good_1", ["logic_reasoning"], avg_ev=0.95)
        low_ev = _make_adapter("bad_1", ["logic_reasoning"], avg_ev=0.70)
        aggregator = AdapterLibraryAggregator(min_explained_variance=0.90)
        accepted, rejected = aggregator.add_library([high_ev, low_ev])
        self.assertEqual(accepted, 1)
        self.assertEqual(rejected, 1)

    def test_rebalance_caps_per_domain(self) -> None:
        """rebalance() leaves at most max_adapters_per_domain per domain."""
        oversized = self._make_library(15, "statute_recall", "s")
        small = self._make_library(3, "logic_reasoning", "l")
        aggregator = AdapterLibraryAggregator(max_adapters_per_domain=5)
        for library in (oversized, small):
            aggregator.add_library(library)
        aggregator.rebalance()
        statute_adapters = [
            adapter
            for adapter in aggregator.merged
            if "statute_recall" in adapter.domain_tags
        ]
        self.assertLessEqual(len(statute_adapters), 5)

    def test_coverage_report_keys(self) -> None:
        """coverage_report() exposes every key the callers rely on."""
        aggregator = AdapterLibraryAggregator()
        aggregator.add_library(self._make_library(2, "statute_recall"))
        report = aggregator.coverage_report()
        required_keys = (
            "total_adapters",
            "domain_coverage",
            "balance_score",
            "avg_explained_variance",
        )
        for key in required_keys:
            self.assertIn(key, report)

    def test_save_and_reload(self) -> None:
        """A file written by save() can be read back by ParasiticQLoRAExtractor."""
        from parasitic_qlora import ParasiticQLoRAExtractor

        aggregator = AdapterLibraryAggregator()
        aggregator.add_library(self._make_library(3, "logic_reasoning", "save"))
        # Only the name is needed; the handle is closed before save() writes.
        with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as tmp_file:
            target = tmp_file.name
        try:
            aggregator.save(target)
            reloaded = ParasiticQLoRAExtractor.load_library(target)
            self.assertEqual(len(reloaded), 3)
        finally:
            os.remove(target)

    def test_load_and_merge_missing_file(self) -> None:
        """load_and_merge() skips a missing path instead of raising."""
        aggregator = AdapterLibraryAggregator()
        missing_paths = ["/nonexistent/path.pt"]
        stats = aggregator.load_and_merge(missing_paths)
        self.assertEqual(stats["files"], 0)

    def test_aggregate_from_paths_two_files(self) -> None:
        """aggregate_from_paths() merges two library files into one output file."""
        lib_a = self._make_library(2, "statute_recall", "aa")
        lib_b = self._make_library(2, "logic_reasoning", "bb")
        with tempfile.TemporaryDirectory() as workdir:
            path_a = os.path.join(workdir, "lib_a.pt")
            path_b = os.path.join(workdir, "lib_b.pt")
            merged_path = os.path.join(workdir, "merged.pt")
            # Write each library in the same torch.save layout that
            # save_library / the aggregator use.
            for adapters, target in zip((lib_a, lib_b), (path_a, path_b)):
                payload = {
                    "config": {
                        "min_rank": 8,
                        "max_rank": 32,
                        "explained_variance_threshold": 0.95,
                    },
                    "adapters": [adapter.to_state_dict() for adapter in adapters],
                    "extraction_stats": {
                        "total_extractions": len(adapters),
                        "total_extraction_time_s": 0.0,
                        "overhead_pct": 0.0,
                    },
                }
                torch.save(payload, target)
            report = aggregate_from_paths([path_a, path_b], merged_path)
            self.assertEqual(report["total_accepted"], 4)
            self.assertTrue(os.path.exists(merged_path))
class TestScoring(unittest.TestCase):
    """Tests for score_output, aggregate_domain_scores and VALIDATION_SUITE."""

    def _make_case(self, keywords: list[str], wrong: list[str]) -> KlausurCase:
        """Return a KlausurCase with fixed metadata and the given match lists."""
        fields = {
            "case_id": "test_case",
            "description": "Test",
            "statutes": "§ 1 BGB",
            "correct_keywords": keywords,
            "wrong_patterns": wrong,
            "domain": "logic_reasoning",
        }
        return KlausurCase(**fields)

    def test_perfect_score(self) -> None:
        """Every keyword present and no wrong pattern gives a score of 1.0."""
        case = self._make_case(["536b", "Kenntnis"], [])
        answer = "§ 536b BGB schließt Kenntnis des Mieters ein."
        outcome = score_output(case, answer)
        self.assertAlmostEqual(outcome.score, 1.0)

    def test_wrong_pattern_zeroes_score(self) -> None:
        """A wrong pattern in the output forces score 0 despite keyword hits."""
        case = self._make_case(["536b"], ["kann mindern"])
        answer = "§ 536b, aber M kann mindern hier."
        outcome = score_output(case, answer)
        self.assertEqual(outcome.score, 0.0)
        self.assertTrue(outcome.wrong_hit)

    def test_partial_keyword_hit(self) -> None:
        """One of three keywords present gives a score of one third."""
        case = self._make_case(["Kenntnis", "vorbehaltlos", "536b"], [])
        answer = "Kenntnis war vorhanden."
        outcome = score_output(case, answer)
        self.assertAlmostEqual(outcome.score, 1 / 3, places=5)

    def test_aggregate_domain_scores(self) -> None:
        """Scores are averaged separately for each domain."""
        rows = [
            ("c1", "statute_recall", "", 2, 2, False, 1.0),
            ("c2", "statute_recall", "", 1, 2, False, 0.5),
            ("c3", "logic_reasoning", "", 0, 1, True, 0.0),
        ]
        per_domain = aggregate_domain_scores([CaseResult(*row) for row in rows])
        expected_means = {"statute_recall": 0.75, "logic_reasoning": 0.0}
        for domain, expected in expected_means.items():
            self.assertAlmostEqual(per_domain[domain], expected)

    def test_validation_suite_completeness(self) -> None:
        """VALIDATION_SUITE has at least two cases for each of the 3 domains."""
        required = ("statute_recall", "logic_reasoning", "style_gutachtenstil")
        covered = set()
        for case in VALIDATION_SUITE:
            covered.add(case.domain)
        for domain in required:
            self.assertIn(domain, covered)
        # Presence alone is not enough: each domain needs two or more cases.
        for domain in required:
            matching = [case for case in VALIDATION_SUITE if case.domain == domain]
            self.assertGreaterEqual(
                len(matching), 2, f"Too few cases for domain {domain}"
            )
# Run the whole suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()