# -*- coding: utf-8 -*-
"""Unit tests for Phase 5: telemetry_aggregator and run_phase5_validation."""
from __future__ import annotations

import os
import sys
import tempfile
import unittest

import torch

# Project modules are imported from "<parent of this file's directory>/python",
# so that directory is pushed to the front of sys.path before importing them.
_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(_root, "python"))

from parasitic_qlora import ExpertAdapter, LoRAMatrices  # noqa: E402
from telemetry_aggregator import (  # noqa: E402
    AdapterLibraryAggregator,
    aggregate_from_paths,
)
from run_phase5_validation import (  # noqa: E402
    KlausurCase,
    CaseResult,
    score_output,
    aggregate_domain_scores,
    VALIDATION_SUITE,
)


def _make_adapter(
    adapter_id: str,
    domain_tags: list[str],
    avg_ev: float = 0.95,
    num_layers: int = 2,
) -> ExpertAdapter:
    """Build a small ExpertAdapter fixture with random LoRA factors.

    Args:
        adapter_id: Identifier stored on the adapter.
        domain_tags: Domain tags stored on the adapter.
        avg_ev: Value used both as each layer's ``explained_variance`` and as
            the adapter's ``avg_explained_variance``.
        num_layers: Number of ``layer.<i>.weight`` entries to generate.

    Returns:
        An ExpertAdapter with ``num_layers`` rank-4 layers of shape (16, 8).
    """
    rank = 4
    out_dim, in_dim = 16, 8
    layer_names = [f"layer.{idx}.weight" for idx in range(num_layers)]
    layers = {
        layer_name: LoRAMatrices(
            layer_name=layer_name,
            lora_A=torch.randn(rank, in_dim),
            lora_B=torch.randn(out_dim, rank),
            singular_values=torch.ones(rank),
            rank=rank,
            explained_variance=avg_ev,
            original_shape=(out_dim, in_dim),
        )
        for layer_name in layer_names
    }
    return ExpertAdapter(
        adapter_id=adapter_id,
        step=0,
        layers=layers,
        avg_explained_variance=avg_ev,
        optimizer_type="AdamW",
        extraction_trigger="test",
        domain_tags=domain_tags,
    )


class TestAdapterLibraryAggregator(unittest.TestCase):
    """Tests for AdapterLibraryAggregator and aggregate_from_paths."""

    def _make_library(
        self, n: int, domain: str, id_prefix: str = "a"
    ) -> list[ExpertAdapter]:
        """Return ``n`` single-domain adapters with ids ``<id_prefix>_<index>``."""
        return [
            _make_adapter(f"{id_prefix}_{index}", [domain]) for index in range(n)
        ]

    def test_deduplication(self) -> None:
        """Adding the same library twice must leave only one copy per adapter_id."""
        library = self._make_library(3, "statute_recall", "dup")
        aggregator = AdapterLibraryAggregator()
        # The repeated add is expected to contribute no new adapters.
        for _ in range(2):
            aggregator.add_library(library)
        self.assertEqual(len(aggregator.merged), 3)

    def test_min_ev_filter(self) -> None:
        """An adapter whose explained variance is under the minimum is rejected."""
        above_threshold = _make_adapter("good_1", ["logic_reasoning"], avg_ev=0.95)
        below_threshold = _make_adapter("bad_1", ["logic_reasoning"], avg_ev=0.70)
        aggregator = AdapterLibraryAggregator(min_explained_variance=0.90)
        accepted, rejected = aggregator.add_library(
            [above_threshold, below_threshold]
        )
        self.assertEqual(accepted, 1)
        self.assertEqual(rejected, 1)

    def test_rebalance_caps_per_domain(self) -> None:
        """After rebalance(), a domain holds at most max_adapters_per_domain adapters."""
        aggregator = AdapterLibraryAggregator(max_adapters_per_domain=5)
        oversized = self._make_library(15, "statute_recall", "s")
        undersized = self._make_library(3, "logic_reasoning", "l")
        aggregator.add_library(oversized)
        aggregator.add_library(undersized)
        aggregator.rebalance()
        statute_adapters = [
            adapter
            for adapter in aggregator.merged
            if "statute_recall" in adapter.domain_tags
        ]
        self.assertLessEqual(len(statute_adapters), 5)

    def test_coverage_report_keys(self) -> None:
        """coverage_report() exposes every key this suite relies on."""
        aggregator = AdapterLibraryAggregator()
        aggregator.add_library(self._make_library(2, "statute_recall"))
        report = aggregator.coverage_report()
        expected_keys = (
            "total_adapters",
            "domain_coverage",
            "balance_score",
            "avg_explained_variance",
        )
        for key in expected_keys:
            self.assertIn(key, report)

    def test_save_and_reload(self) -> None:
        """A file written by save() loads via ParasiticQLoRAExtractor.load_library."""
        from parasitic_qlora import ParasiticQLoRAExtractor

        aggregator = AdapterLibraryAggregator()
        aggregator.add_library(self._make_library(3, "logic_reasoning", "save"))

        # Reserve a path on disk; the handle is closed before save() writes to it.
        with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as handle:
            path = handle.name
        try:
            aggregator.save(path)
            reloaded = ParasiticQLoRAExtractor.load_library(path)
            self.assertEqual(len(reloaded), 3)
        finally:
            # delete=False above means cleanup is this test's responsibility.
            os.unlink(path)

    def test_load_and_merge_missing_file(self) -> None:
        """load_and_merge reports zero files for a missing path and does not raise."""
        aggregator = AdapterLibraryAggregator()
        stats = aggregator.load_and_merge(["/nonexistent/path.pt"])
        self.assertEqual(stats["files"], 0)

    def test_aggregate_from_paths_two_files(self) -> None:
        """aggregate_from_paths merges two library files and writes the output."""
        statute_lib = self._make_library(2, "statute_recall", "aa")
        logic_lib = self._make_library(2, "logic_reasoning", "bb")

        with tempfile.TemporaryDirectory() as workdir:
            first_path = os.path.join(workdir, "lib_a.pt")
            second_path = os.path.join(workdir, "lib_b.pt")
            merged_path = os.path.join(workdir, "merged.pt")

            # Write each library with torch.save using the config / adapters /
            # extraction_stats layout (intended to match save_library/aggregator).
            sources = ((statute_lib, first_path), (logic_lib, second_path))
            for adapters, target in sources:
                payload = {
                    "config": {
                        "min_rank": 8,
                        "max_rank": 32,
                        "explained_variance_threshold": 0.95,
                    },
                    "adapters": [adapter.to_state_dict() for adapter in adapters],
                    "extraction_stats": {
                        "total_extractions": len(adapters),
                        "total_extraction_time_s": 0.0,
                        "overhead_pct": 0.0,
                    },
                }
                torch.save(payload, target)

            report = aggregate_from_paths([first_path, second_path], merged_path)
            self.assertEqual(report["total_accepted"], 4)
            self.assertTrue(os.path.exists(merged_path))


class TestScoring(unittest.TestCase):
    """Tests for score_output, aggregate_domain_scores and VALIDATION_SUITE."""

    def _make_case(self, keywords: list[str], wrong: list[str]) -> KlausurCase:
        """Return a logic_reasoning KlausurCase with the given keyword lists."""
        return KlausurCase(
            case_id="test_case",
            description="Test",
            statutes="§ 1 BGB",
            correct_keywords=keywords,
            wrong_patterns=wrong,
            domain="logic_reasoning",
        )

    def test_perfect_score(self) -> None:
        """Every keyword present and no wrong pattern gives a score of 1.0."""
        case = self._make_case(["536b", "Kenntnis"], [])
        output_text = "§ 536b BGB schließt Kenntnis des Mieters ein."
        outcome = score_output(case, output_text)
        self.assertAlmostEqual(outcome.score, 1.0)

    def test_wrong_pattern_zeroes_score(self) -> None:
        """A wrong pattern in the output gives score 0.0 and sets wrong_hit."""
        case = self._make_case(["536b"], ["kann mindern"])
        output_text = "§ 536b, aber M kann mindern hier."
        outcome = score_output(case, output_text)
        self.assertEqual(outcome.score, 0.0)
        self.assertTrue(outcome.wrong_hit)

    def test_partial_keyword_hit(self) -> None:
        """One of three keywords present gives a score of one third."""
        case = self._make_case(["Kenntnis", "vorbehaltlos", "536b"], [])
        output_text = "Kenntnis war vorhanden."
        outcome = score_output(case, output_text)
        self.assertAlmostEqual(outcome.score, 1 / 3, places=5)

    def test_aggregate_domain_scores(self) -> None:
        """Per-domain aggregation yields the mean of that domain's scores."""
        case_results = [
            CaseResult("c1", "statute_recall", "", 2, 2, False, 1.0),
            CaseResult("c2", "statute_recall", "", 1, 2, False, 0.5),
            CaseResult("c3", "logic_reasoning", "", 0, 1, True, 0.0),
        ]
        per_domain = aggregate_domain_scores(case_results)
        self.assertAlmostEqual(per_domain["statute_recall"], 0.75)
        self.assertAlmostEqual(per_domain["logic_reasoning"], 0.0)

    def test_validation_suite_completeness(self) -> None:
        """VALIDATION_SUITE has at least two cases in each of the three domains."""
        required_domains = [
            "statute_recall",
            "logic_reasoning",
            "style_gutachtenstil",
        ]
        domains = {case.domain for case in VALIDATION_SUITE}
        for domain in required_domains:
            self.assertIn(domain, domains)

        # At least 2 cases per domain
        for domain in required_domains:
            count = len([case for case in VALIDATION_SUITE if case.domain == domain])
            self.assertGreaterEqual(count, 2, f"Too few cases for domain {domain}")


if __name__ == "__main__":
    unittest.main()