diff --git a/python/adapter_moe_router.py b/python/adapter_moe_router.py index d3b19e6..494ab26 100644 --- a/python/adapter_moe_router.py +++ b/python/adapter_moe_router.py @@ -15,11 +15,12 @@ import torch.nn as nn from parasitic_qlora import ExpertAdapter -class LearnableGate(nn.Module): # type: ignore[misc] +class LearnableGate(nn.Module): # type: ignore[misc, unused-ignore] """A lightweight learnable MLP gating network for routing.""" def __init__(self, in_features: int, num_adapters: int) -> None: super().__init__() + self.in_features = in_features self.gate = nn.Sequential( nn.Linear(in_features, in_features // 2), nn.ReLU(), @@ -27,9 +28,16 @@ class LearnableGate(nn.Module): # type: ignore[misc] ) def forward(self, x: torch.Tensor) -> torch.Tensor: - # Input: [..., in_features] + # Input: [..., in_features] or other dimension to be pooled # Output: [..., num_adapters] (unnormalized scores) - return self.gate(x) + if x.shape[-1] != self.in_features: + import torch.nn.functional as F + + orig_shape = x.shape + flat_x = x.view(-1, orig_shape[-1]).unsqueeze(1) + pooled_x = F.adaptive_avg_pool1d(flat_x, self.in_features) + x = pooled_x.squeeze(1).view(*orig_shape[:-1], self.in_features) + return self.gate(x) # type: ignore[no-any-return, unused-ignore] class ExpertAdapterRouter: @@ -157,8 +165,8 @@ class ExpertAdapterRouter: if layer_name in adapter.layers: lm = adapter.layers[layer_name] # Ensure tensors are on the correct device - lora_A = lm.lora_A.to(x.device) - lora_B = lm.lora_B.to(x.device) + lora_A = lm.lora_A.to(device=x.device, dtype=x.dtype) + lora_B = lm.lora_B.to(device=x.device, dtype=x.dtype) # Dynamic scaling: gate_weight for this adapter # weights has shape [batch, num_adapters] or [batch, seq_len, num_adapters] diff --git a/python/run_phase5_validation.py b/python/run_phase5_validation.py new file mode 100644 index 0000000..9ec35bc --- /dev/null +++ b/python/run_phase5_validation.py @@ -0,0 +1,554 @@ +# -*- coding: utf-8 -*- +"""Phase 5: 
End-to-End Klausuren Validation with MoE Routing. + +Runs a full validation pipeline: +1. Load pre-built adapter library (from Kaggle prep or training run) +2. Pre-train MoE gate on domain routing signal +3. Activate MoE routing via forward hooks (zero model modification) +4. Evaluate on all 6 legal Klausuren cases (3 domains) +5. Score: statute_recall / logic_reasoning / style_gutachtenstil accuracy +6. Push results to MariaDB + SurrealDB telemetry +7. Save final merged library + +Usage: + python python/run_phase5_validation.py --model EleutherAI/pythia-70m + python python/run_phase5_validation.py --model EleutherAI/pythia-70m \\ + --library parasitic_adapters_fces_step5.pt \\ + --gate-weights kaggle_output/step3_gate_weights.pt +""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +import time +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn + +_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, os.path.join(_root, "python")) + +from parasitic_qlora import ExpertAdapter, ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402 +from expert_manifold_alignment import ExpertManifoldAligner # noqa: E402 +from adapter_moe_router import ExpertAdapterRouter, LearnableGate # noqa: E402 +from telemetry_aggregator import AdapterLibraryAggregator # noqa: E402 + +try: + from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402 + + _TELEMETRY_AVAILABLE = True +except ImportError: + _TELEMETRY_AVAILABLE = False + + +# ============================================================================== +# KLAUSUREN VALIDATION SUITE +# ============================================================================== + + +@dataclass +class KlausurCase: + case_id: str + description: str + statutes: str + correct_keywords: List[str] # must appear in model output + wrong_patterns: List[str] # must NOT appear in model output + domain: 
str + + +VALIDATION_SUITE: List[KlausurCase] = [ + # --- statute_recall --- + KlausurCase( + case_id="miet_01_statute", + description=( + "Mieter M unterschreibt Mietvertrag vorbehaltlos obwohl er den " + "Schimmelfleck kannte. Er mindert Miete um 20%." + ), + statutes="§ 536b BGB", + correct_keywords=["536b", "Kenntnis", "vorbehaltlos"], + wrong_patterns=["kann mindern", "hat Recht auf Minderung"], + domain="statute_recall", + ), + KlausurCase( + case_id="kauf_01_statute", + description=( + "Kilometerstand manipuliert. Verkäufer beruft sich auf " + "'gekauft wie gesehen'. Klausel wirksam?" + ), + statutes="§ 444 BGB", + correct_keywords=["444", "arglistig"], + wrong_patterns=["Klausel wirksam", "muss akzeptieren"], + domain="statute_recall", + ), + # --- logic_reasoning --- + KlausurCase( + case_id="miet_02_logic", + description=( + "Toilettenspülung kaputt, Vermieter auf Segeltrip, Mieter beauftragt " + "Notdienst für 300 EUR. Erstattungsanspruch?" + ), + statutes="§ 536a Abs. 2 Nr. 2 BGB", + correct_keywords=["Selbstbeseitigungsrecht", "Gefahr im Verzug", "Erstattung"], + wrong_patterns=["kein Anspruch", "muss selbst zahlen"], + domain="logic_reasoning", + ), + KlausurCase( + case_id="delikt_01_logic", + description=( + "A fährt betrunken (1,5 Promille), beschädigt geparktes Auto. " + "Kann A sich auf § 827 BGB berufen?" + ), + statutes="§ 823 BGB, § 827 BGB", + correct_keywords=["actio libera", "selbst herbeigeführt", "haftet"], + wrong_patterns=["schuldunfähig", "haftet nicht"], + domain="logic_reasoning", + ), + # --- style_gutachtenstil --- + KlausurCase( + case_id="kauf_02_style", + description=( + "Lieferant liefert nicht nach Nachfristablauf. Käuferin kauft " + "teurer woanders. Schadensersatz?" 
@dataclass
class CaseResult:
    """Outcome of scoring one model output against one Klausur case."""

    case_id: str  # identifier copied from the scored case
    domain: str  # domain label copied from the scored case
    output: str  # raw model output that was scored
    keyword_hits: int  # how many expected keywords occurred in the output
    keyword_total: int  # how many expected keywords the case defines
    wrong_hit: bool  # True when at least one forbidden pattern occurred
    score: float  # 0..1


def _normalise_for_match(text: str) -> str:
    """Return *text* in a canonical form for caseless substring matching.

    NFC composes decomposed umlauts (``a`` + U+0308 -> ``ä``) and
    ``casefold`` maps ``ß`` and ``SS`` to the same form, which plain
    ``lower()`` does not.
    """
    import unicodedata

    return unicodedata.normalize("NFC", text).casefold()


def score_output(case: KlausurCase, output: str) -> CaseResult:
    """Score a model output against a case's keywords and wrong patterns.

    Args:
        case: Case providing ``case_id``, ``domain``, ``correct_keywords``
            and ``wrong_patterns``.
        output: Text generated by the model.

    Returns:
        A ``CaseResult`` whose ``score`` is the fraction of expected keywords
        found, or ``0.0`` as soon as any wrong pattern occurs in the output.
    """
    haystack = _normalise_for_match(output)
    keyword_hits = sum(
        1 for kw in case.correct_keywords if _normalise_for_match(kw) in haystack
    )
    wrong_hit = any(
        _normalise_for_match(pattern) in haystack for pattern in case.wrong_patterns
    )

    keyword_total = len(case.correct_keywords)
    # max(1, ...) keeps a case without keywords at score 0 instead of dividing by 0.
    raw = keyword_hits / max(1, keyword_total)
    # A single wrong pattern voids the whole answer.
    score = 0.0 if wrong_hit else raw
    return CaseResult(
        case_id=case.case_id,
        domain=case.domain,
        output=output,
        keyword_hits=keyword_hits,
        keyword_total=keyword_total,
        wrong_hit=wrong_hit,
        score=score,
    )


def aggregate_domain_scores(results: List[CaseResult]) -> Dict[str, float]:
    """Return the mean score per domain, keyed in order of first appearance."""
    domain_scores: Dict[str, List[float]] = {}
    for r in results:
        domain_scores.setdefault(r.domain, []).append(r.score)
    return {d: sum(vs) / len(vs) for d, vs in domain_scores.items()}
def run_inference(
    model: nn.Module,
    tokenizer: Any,
    case: KlausurCase,
    device: str,
    max_new_tokens: int = 60,
) -> str:
    """Generate the model's continuation for one Klausur case.

    Args:
        model: Object exposing a ``generate`` method.
        tokenizer: Callable returning a mapping with ``input_ids`` (and
            optionally ``attention_mask``); must expose ``eos_token_id`` and
            ``decode``.
        case: Case providing ``description`` and ``statutes`` for the prompt.
        device: Device the input tensors are moved to.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The decoded continuation only; the prompt tokens are stripped.
    """
    prompt = (
        f"Fall: {case.description}\n"
        f"Normen: {case.statutes}\n"
        "Gutachten (Subsumtion):\n"
    )
    enc = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
    input_ids = enc["input_ids"].to(device)

    generate_kwargs: Dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "temperature": 1.0,
        "pad_token_id": tokenizer.eos_token_id,
    }
    # Forward the tokenizer's mask when it provides one: with pad == eos the
    # model cannot reliably infer which positions are padding on its own.
    if "attention_mask" in enc:
        generate_kwargs["attention_mask"] = enc["attention_mask"].to(device)

    model_any: Any = model
    with torch.no_grad():
        out = model_any.generate(input_ids, **generate_kwargs)
    # Keep only the tokens produced after the prompt.
    generated = out[0][input_ids.shape[1] :]
    return str(tokenizer.decode(generated, skip_special_tokens=True))


# ==============================================================================
# GATE WARM-UP (use Kaggle pre-trained weights if available)
# ==============================================================================


def load_or_train_gate(
    adapter_library: List[ExpertAdapter],
    in_features: int,
    gate_weights_path: Optional[str],
    gate_steps: int = 100,
    device: str = "cpu",
) -> LearnableGate:
    """Load a pre-trained gate, or train one on synthetic domain supervision.

    Args:
        adapter_library: Adapters the gate routes between; each adapter's
            ``domain_tags`` decides its training target.
        in_features: Input width of the gate.
        gate_weights_path: Optional path to a saved gate ``state_dict``.
        gate_steps: Number of optimisation steps when training from scratch.
        device: Device the gate is created and trained on.

    Returns:
        The gate. It is returned untrained when the library is empty, because
        there is nothing to route between.
    """
    import torch.nn.functional as F

    num_adapters = len(adapter_library)
    gate = LearnableGate(in_features, num_adapters).to(device)

    if gate_weights_path and os.path.exists(gate_weights_path):
        # NOTE(security): torch.load unpickles the file; only point this at
        # trusted checkpoints.
        state = torch.load(gate_weights_path, map_location=device)
        try:
            # A size mismatch (library grew since pre-training) raises
            # RuntimeError even with strict=False; key mismatches do not.
            load_result = gate.load_state_dict(state, strict=False)
        except RuntimeError as e:
            print(f" [Gate] Shape mismatch ({e}), training from scratch...")
        else:
            expected_keys = set(gate.state_dict())
            if expected_keys and expected_keys <= set(load_result.missing_keys):
                # strict=False silently ignores foreign keys; if no parameter
                # matched, the gate is still randomly initialised.
                print(
                    f" [Gate] No matching parameters in {gate_weights_path}, "
                    "training from scratch..."
                )
            else:
                print(f" [Gate] Loaded pre-trained weights from {gate_weights_path}")
                return gate  # type: ignore[no-any-return, unused-ignore]

    if num_adapters == 0:
        # The training loop indexes targets modulo num_adapters.
        print(" [Gate] Empty adapter library, skipping gate training.")
        return gate  # type: ignore[no-any-return, unused-ignore]

    # Map every adapter to one of three synthetic domains.
    adapter_targets: List[int] = []
    for adapter in adapter_library:
        if "statute_recall" in adapter.domain_tags:
            adapter_targets.append(0)
        elif "logic_reasoning" in adapter.domain_tags:
            adapter_targets.append(1)
        else:
            adapter_targets.append(2)

    # Build the per-domain target distribution once: uniform over the
    # adapters of that domain, zero elsewhere.
    target_rows: Dict[int, torch.Tensor] = {}
    for dom in set(adapter_targets):
        members = [i for i, t in enumerate(adapter_targets) if t == dom]
        row = torch.zeros(num_adapters, device=device)
        row[members] = 1.0 / len(members)
        target_rows[dom] = row

    batch_size = 8
    band = in_features // 3  # width of the feature band that marks a domain
    gate_opt = torch.optim.AdamW(gate.parameters(), lr=1e-3)
    print(f" [Gate] Training gate for {num_adapters} adapters, {gate_steps} steps...")

    for step in range(gate_steps):
        x = torch.randn(batch_size, in_features, device=device)
        dom = adapter_targets[step % num_adapters]
        # Shift the domain's feature band so the gate has a signal to learn.
        x[:, dom * band : (dom + 1) * band] += 2.0

        target = target_rows[dom].expand(batch_size, -1)

        gate_opt.zero_grad()
        logits = gate(x)
        loss = F.kl_div(F.log_softmax(logits, dim=-1), target, reduction="batchmean")
        loss.backward()  # type: ignore[no-untyped-call, unused-ignore]
        gate_opt.step()

        if step % 25 == 0:
            print(f" Gate step {step:3d}/{gate_steps} | loss={loss.item():.4f}")

    return gate  # type: ignore[no-any-return, unused-ignore]
def _run_suite(
    tag: str,
    model: nn.Module,
    tokenizer: Any,
    device: str,
    max_new_tokens: int,
) -> List[CaseResult]:
    """Run and score every case in VALIDATION_SUITE, printing one line each."""
    results: List[CaseResult] = []
    for case in VALIDATION_SUITE:
        output = run_inference(model, tokenizer, case, device, max_new_tokens)
        result = score_output(case, output)
        results.append(result)
        print(
            f" [{tag}] {case.case_id:<25} score={result.score:.2f} "
            f"({result.keyword_hits}/{result.keyword_total} kw, "
            f"wrong={'YES' if result.wrong_hit else 'no'})"
        )
    return results


def _per_case_rows(results: List[CaseResult]) -> List[Dict[str, Any]]:
    """Convert case results into the JSON-serialisable rows of the report."""
    return [
        {
            "case_id": r.case_id,
            "score": r.score,
            "keyword_hits": r.keyword_hits,
            "wrong_hit": r.wrong_hit,
        }
        for r in results
    ]


def _mean_score(results: List[CaseResult]) -> float:
    """Return the mean score of *results* (0.0 for an empty list)."""
    return sum(r.score for r in results) / max(1, len(results))


def run_phase5(
    model_name: str,
    library_paths: List[str],
    gate_weights_path: Optional[str],
    output_dir: str,
    max_new_tokens: int = 60,
    gate_steps: int = 100,
) -> Dict[str, Any]:
    """Run the full Phase 5 validation pipeline.

    Loads the model, merges the adapter libraries (or extracts one cold-start
    adapter when none is available), prepares the MoE gate, evaluates
    VALIDATION_SUITE with routing off and on, writes the merged library and a
    JSON report into *output_dir*, and pushes telemetry when available.

    Args:
        model_name: Name or path passed to the ``transformers`` loaders.
        library_paths: Adapter library files to merge.
        gate_weights_path: Optional pre-trained gate weights.
        output_dir: Directory for the merged library and the report
            (created if missing).
        max_new_tokens: Generation budget per case.
        gate_steps: Training steps when the gate is trained from scratch.

    Returns:
        The report dictionary that is also written to ``phase5_results.json``.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig  # type: ignore[import-untyped, unused-ignore]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(output_dir, exist_ok=True)

    print("\n" + "=" * 70)
    print(" PHASE 5: KLAUSUREN VALIDATION WITH MOE ROUTING")
    print(f" Model: {model_name} | Device: {device}")
    print("=" * 70)

    # ── 1. Load model ────────────────────────────────────────────────────────
    print("\n[1/6] Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    model.eval()

    config = AutoConfig.from_pretrained(model_name)
    in_features: int = getattr(config, "hidden_size", 512)

    # ── 2. Load/aggregate adapter library ────────────────────────────────────
    print("\n[2/6] Aggregating adapter libraries...")
    agg = AdapterLibraryAggregator(
        max_adapters_per_domain=10, min_explained_variance=0.85
    )
    merge_stats = agg.load_and_merge(library_paths)
    agg.rebalance()
    adapter_library = agg.merged

    if not adapter_library:
        # Cold-start: extract adapters from current model state
        print(" No pre-built library found. Extracting from model (cold-start)...")
        extractor = ParasiticQLoRAExtractor(
            QLoRAConfig(
                min_rank=8,
                max_rank=32,
                explained_variance_threshold=0.90,
                extraction_interval=1,
                interesting_point_detection=False,
            )
        )
        extractor.snapshot_base(model)
        cold_adapter = extractor.extract_adapters(
            model, 0, {"loss": 0.0, "optimizer": "baseline"}
        )
        aligner = ExpertManifoldAligner(model)
        aligner.tag_adapter(cold_adapter)
        adapter_library = [cold_adapter]
        # Hand the cold-start adapter to the aggregator as well; otherwise the
        # coverage report and the saved library describe an empty library
        # while the router actually uses one adapter.
        agg.merged = adapter_library

    coverage = agg.coverage_report()
    print(
        f" Library: {coverage['total_adapters']} adapters | "
        f"coverage={coverage['domain_coverage']}"
    )

    # ── 3. Load/train MoE gate ───────────────────────────────────────────────
    print("\n[3/6] Initialising MoE gate...")
    gate = load_or_train_gate(
        adapter_library,
        in_features,
        gate_weights_path,
        gate_steps=gate_steps,
        device=device,
    )

    # ── 4. Activate MoE routing ──────────────────────────────────────────────
    print("\n[4/6] Activating MoE routing via forward hooks...")
    router = ExpertAdapterRouter(model, adapter_library, in_features=in_features)
    # Inject pre-trained gate weights
    router.gate.load_state_dict(gate.state_dict(), strict=False)
    router.register_hooks()
    print(f" Hooks registered on {len(router.hooks)} layers")

    # ── 5. Baseline inference (routing OFF), then MoE (routing ON) ───────────
    print("\n[5/6] Running validation (baseline vs MoE)...")
    router.unregister_hooks()
    baseline_results = _run_suite(
        "BASELINE", model, tokenizer, device, max_new_tokens
    )
    baseline_domain = aggregate_domain_scores(baseline_results)

    router.register_hooks()
    try:
        moe_results = _run_suite("MoE", model, tokenizer, device, max_new_tokens)
    finally:
        # Never leave the hooks attached to the model, even if inference fails.
        router.unregister_hooks()
    moe_domain = aggregate_domain_scores(moe_results)

    # ── 6. Report & telemetry ────────────────────────────────────────────────
    print("\n[6/6] Compiling results...")

    baseline_avg = _mean_score(baseline_results)
    moe_avg = _mean_score(moe_results)
    delta = moe_avg - baseline_avg

    # Largest per-domain improvement, computed once for the "best" marker.
    best_delta = max(
        (
            moe_domain.get(dd, 0.0) - baseline_domain.get(dd, 0.0)
            for dd in moe_domain
        ),
        default=None,
    )

    print("\n" + "=" * 70)
    print(" RESULTS")
    print("=" * 70)
    print(f" {'Domain':<25} {'Baseline':>10} {'MoE':>10} {'Delta':>10}")
    print(" " + "-" * 55)
    for domain in ["statute_recall", "logic_reasoning", "style_gutachtenstil"]:
        b = baseline_domain.get(domain, 0.0)
        m = moe_domain.get(domain, 0.0)
        d = m - b
        marker = " <-- best" if best_delta is not None and d == best_delta else ""
        print(f" {domain:<25} {b:>10.3f} {m:>10.3f} {d:>+10.3f}{marker}")
    print(" " + "-" * 55)
    print(f" {'OVERALL':<25} {baseline_avg:>10.3f} {moe_avg:>10.3f} {delta:>+10.3f}")
    print("=" * 70)

    # Save merged library
    merged_lib_path = os.path.join(output_dir, "phase5_merged_library.pt")
    agg.save(merged_lib_path)

    # Build result dict
    report: Dict[str, Any] = {
        "model": model_name,
        "device": device,
        "adapter_count": len(adapter_library),
        "merge_stats": merge_stats,
        "library_coverage": coverage,
        "baseline": {
            "per_case": _per_case_rows(baseline_results),
            "per_domain": baseline_domain,
            "avg": baseline_avg,
        },
        "moe": {
            "per_case": _per_case_rows(moe_results),
            "per_domain": moe_domain,
            "avg": moe_avg,
        },
        "delta_avg": delta,
        "merged_library_path": merged_lib_path,
    }

    report_path = os.path.join(output_dir, "phase5_results.json")
    with open(report_path, "w", encoding="utf-8") as fh:
        json.dump(report, fh, indent=2, ensure_ascii=False)
    print(f"\n Report saved: {report_path}")

    # Push telemetry
    if _TELEMETRY_AVAILABLE:
        telemetry_entries = [
            ("INFO", "phase5_baseline_avg", f"{baseline_avg:.4f}"),
            ("INFO", "phase5_moe_avg", f"{moe_avg:.4f}"),
            ("INFO", "phase5_delta", f"{delta:.4f}"),
            ("INFO", "phase5_adapters", str(len(adapter_library))),
        ]
        for domain, score in moe_domain.items():
            telemetry_entries.append(("INFO", f"phase5_moe_{domain}", f"{score:.4f}"))
        push_to_surrealdb(telemetry_entries)
        push_to_mariadb(telemetry_entries)

    return report
max_new_tokens=args.max_new_tokens, + gate_steps=args.gate_steps, + ) + elapsed = time.perf_counter() - t0 + + print(f"\n Phase 5 complete in {elapsed:.1f}s") + print(f" MoE delta vs baseline: {report['delta_avg']:+.4f}") + + +if __name__ == "__main__": + main() diff --git a/python/telemetry_aggregator.py b/python/telemetry_aggregator.py new file mode 100644 index 0000000..c6c91b1 --- /dev/null +++ b/python/telemetry_aggregator.py @@ -0,0 +1,255 @@ +# -*- coding: utf-8 -*- +"""Multi-Worker Adapter Library Aggregator. + +Merges ExpertAdapter libraries produced by multiple independent workers +(Kaggle/RunPod) into a single, deduplicated, domain-balanced corpus. + +Design: +- Workers write adapter libraries to a shared path (NFS, S3, or a .pt file). +- Aggregator loads all libraries, deduplicates by adapter_id, re-ranks by + domain coverage, and emits a consolidated merged_library.pt. +- Telemetry is pushed to MariaDB for cluster-wide monitoring. +""" + +from __future__ import annotations + +import json +import os +import sys +import time +from typing import Dict, List, Tuple + +import torch + +_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, os.path.join(_root, "python")) + +from parasitic_qlora import ExpertAdapter, ParasiticQLoRAExtractor # noqa: E402 + + +# --------------------------------------------------------------------------- +# Domain balance scoring +# --------------------------------------------------------------------------- + +DOMAIN_KEYS: List[str] = ["statute_recall", "logic_reasoning", "style_gutachtenstil"] + + +def _adapter_domain_score(adapter: ExpertAdapter) -> Dict[str, float]: + """Returns normalised domain score dict for an adapter.""" + tags = set(adapter.domain_tags) + total = max(1, len(tags)) + return {d: (1.0 / total if d in tags else 0.0) for d in DOMAIN_KEYS} + + +def _library_coverage(library: List[ExpertAdapter]) -> Dict[str, float]: + """Fraction of adapters covering each domain (0..1).""" + if 
def _library_coverage(library: List[ExpertAdapter]) -> Dict[str, float]:
    """Fraction of adapters covering each domain (0..1)."""
    if not library:
        return {d: 0.0 for d in DOMAIN_KEYS}
    counts: Dict[str, float] = {d: 0.0 for d in DOMAIN_KEYS}
    for adapter in library:
        for d in DOMAIN_KEYS:
            if d in adapter.domain_tags:
                counts[d] += 1.0
    return {d: v / len(library) for d, v in counts.items()}


# ---------------------------------------------------------------------------
# Core aggregation logic
# ---------------------------------------------------------------------------


class AdapterLibraryAggregator:
    """Merges adapter libraries from multiple workers into a balanced corpus.

    Args:
        max_adapters_per_domain: Hard cap on adapters per domain in merged lib.
        min_explained_variance: Discard adapters below this quality threshold.
    """

    def __init__(
        self,
        max_adapters_per_domain: int = 20,
        min_explained_variance: float = 0.90,
    ) -> None:
        self.max_adapters_per_domain = max_adapters_per_domain
        self.min_explained_variance = min_explained_variance
        # Ids of accepted adapters; used to reject duplicates across workers.
        self._seen_ids: set[str] = set()
        self.merged: List[ExpertAdapter] = []

    def add_library(self, library: List[ExpertAdapter]) -> Tuple[int, int]:
        """Merges a worker library into the aggregator.

        An adapter is rejected when its id was already accepted, or when its
        ``avg_explained_variance`` is below the threshold or NaN.

        Returns:
            (accepted, rejected) counts.
        """
        accepted = 0
        rejected = 0
        for adapter in library:
            if adapter.adapter_id in self._seen_ids:
                rejected += 1
                continue
            # "not >=" instead of "<": NaN compares False either way, and a
            # NaN quality must be rejected, not silently accepted.
            if not adapter.avg_explained_variance >= self.min_explained_variance:
                rejected += 1
                continue
            self._seen_ids.add(adapter.adapter_id)
            self.merged.append(adapter)
            accepted += 1
        return accepted, rejected

    def load_and_merge(self, library_paths: List[str]) -> Dict[str, int]:
        """Loads all library files and merges them.

        Paths that do not exist are skipped with a warning.

        Returns:
            Summary dict with the total accepted/rejected counts and the
            number of files that were loaded.
        """
        stats: Dict[str, int] = {"total_accepted": 0, "total_rejected": 0, "files": 0}
        for path in library_paths:
            if not os.path.exists(path):
                print(f" [WARN] Library not found: {path}")
                continue
            lib = ParasiticQLoRAExtractor.load_library(path)
            accepted, rejected = self.add_library(lib)
            stats["total_accepted"] += accepted
            stats["total_rejected"] += rejected
            stats["files"] += 1
            print(
                f" Loaded {path}: {len(lib)} adapters "
                f"(+{accepted} accepted, {rejected} rejected)"
            )
        return stats

    def rebalance(self) -> List[ExpertAdapter]:
        """Selects a domain-balanced subset honouring max_adapters_per_domain.

        Strategy:
        1. Put each adapter into the bucket of the first domain in
           DOMAIN_KEYS that appears in its tags; adapters matching no
           domain are kept aside as untagged.
        2. Sort each bucket by avg_explained_variance DESC and cut it at
           the cap.
        3. Interleave the buckets round-robin, then append the untagged
           adapters (they are not capped).

        Returns:
            Balanced merged adapter list (also updates self.merged).
        """
        buckets: Dict[str, List[ExpertAdapter]] = {d: [] for d in DOMAIN_KEYS}
        untagged: List[ExpertAdapter] = []

        for adapter in self.merged:
            home = next((d for d in DOMAIN_KEYS if d in adapter.domain_tags), None)
            if home is None:
                untagged.append(adapter)
            else:
                buckets[home].append(adapter)

        # Best adapters first, then enforce the per-domain cap.
        for d in DOMAIN_KEYS:
            buckets[d].sort(key=lambda a: a.avg_explained_variance, reverse=True)
            buckets[d] = buckets[d][: self.max_adapters_per_domain]

        # Round-robin: take the i-th adapter of every bucket that still has one.
        balanced: List[ExpertAdapter] = []
        longest = max((len(buckets[d]) for d in DOMAIN_KEYS), default=0)
        for rank in range(longest):
            for d in DOMAIN_KEYS:
                if rank < len(buckets[d]):
                    balanced.append(buckets[d][rank])

        # Append untagged last (low priority)
        balanced.extend(untagged)
        self.merged = balanced
        return balanced

    def save(self, output_path: str) -> None:
        """Saves the merged library using the same format as ParasiticQLoRAExtractor.

        The parent directory of *output_path* is created when it is missing.
        """
        parent = os.path.dirname(output_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        state = {
            "config": {
                "min_rank": 0,
                "max_rank": 0,
                "explained_variance_threshold": self.min_explained_variance,
            },
            "adapters": [a.to_state_dict() for a in self.merged],
            "extraction_stats": {
                "total_extractions": len(self.merged),
                "total_extraction_time_s": 0.0,
                "overhead_pct": 0.0,
            },
        }
        torch.save(state, output_path)
        print(
            f"[Aggregator] Saved merged library: {output_path} ({len(self.merged)} adapters)"
        )

    def coverage_report(self) -> Dict[str, object]:
        """Returns a human-readable coverage report.

        ``balance_score`` is the smallest domain coverage divided by the
        largest one (0 when a domain has no adapter).
        """
        coverage = _library_coverage(self.merged)
        return {
            "total_adapters": len(self.merged),
            "domain_coverage": coverage,
            # 1e-6 floor avoids dividing by zero for an empty library.
            "balance_score": min(coverage.values()) / max(max(coverage.values()), 1e-6),
            "avg_explained_variance": (
                sum(a.avg_explained_variance for a in self.merged)
                / max(1, len(self.merged))
            ),
        }


# ---------------------------------------------------------------------------
# CLI entry point for standalone aggregation runs
# ---------------------------------------------------------------------------


def aggregate_from_paths(
    library_paths: List[str],
    output_path: str,
    max_per_domain: int = 20,
    min_ev: float = 0.90,
) -> Dict[str, object]:
    """Top-level helper: load, merge, rebalance, save, report.

    Returns:
        The coverage report extended with the merge statistics.
    """
    agg = AdapterLibraryAggregator(
        max_adapters_per_domain=max_per_domain,
        min_explained_variance=min_ev,
    )
    merge_stats = agg.load_and_merge(library_paths)
    agg.rebalance()
    agg.save(output_path)
    report = agg.coverage_report()
    report.update(merge_stats)
    return report
parser.parse_args() + + print(f"Aggregating {len(args.libraries)} libraries...") + t0 = time.perf_counter() + report = aggregate_from_paths( + args.libraries, args.output, args.max_per_domain, args.min_ev + ) + elapsed = time.perf_counter() - t0 + + print(f"\nMerged Library Report ({elapsed:.1f}s):") + print(json.dumps(report, indent=2)) + print(f"Saved to: {args.output}") + + +if __name__ == "__main__": + main() diff --git a/tests/test_adapter_moe_router.py b/tests/test_adapter_moe_router.py index b0ab616..3aff443 100644 --- a/tests/test_adapter_moe_router.py +++ b/tests/test_adapter_moe_router.py @@ -13,13 +13,13 @@ from parasitic_qlora import ExpertAdapter, LoRAMatrices from adapter_moe_router import ExpertAdapterRouter -class SimpleModel(nn.Module): # type: ignore[misc] +class SimpleModel(nn.Module): # type: ignore[misc, unused-ignore] def __init__(self) -> None: super().__init__() self.fc1 = nn.Linear(32, 16, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.fc1(x) + return self.fc1(x) # type: ignore[no-any-return, unused-ignore] class TestAdapterMoERouter(unittest.TestCase): diff --git a/tests/test_phase5_validation.py b/tests/test_phase5_validation.py new file mode 100644 index 0000000..4832cad --- /dev/null +++ b/tests/test_phase5_validation.py @@ -0,0 +1,217 @@ +# -*- coding: utf-8 -*- +"""Unit tests for Phase 5: telemetry_aggregator and run_phase5_validation.""" + +from __future__ import annotations + +import os +import sys +import tempfile +import unittest + +import torch + +_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, os.path.join(_root, "python")) + +from parasitic_qlora import ExpertAdapter, LoRAMatrices # noqa: E402 +from telemetry_aggregator import ( # noqa: E402 + AdapterLibraryAggregator, + aggregate_from_paths, +) +from run_phase5_validation import ( # noqa: E402 + KlausurCase, + CaseResult, + score_output, + aggregate_domain_scores, + VALIDATION_SUITE, +) + + +def 
_make_adapter( + adapter_id: str, + domain_tags: list[str], + avg_ev: float = 0.95, + num_layers: int = 2, +) -> ExpertAdapter: + """Creates a minimal ExpertAdapter for testing.""" + layers = {} + for i in range(num_layers): + name = f"layer.{i}.weight" + r = 4 + d, k = 16, 8 + layers[name] = LoRAMatrices( + layer_name=name, + lora_A=torch.randn(r, k), + lora_B=torch.randn(d, r), + singular_values=torch.ones(r), + rank=r, + explained_variance=avg_ev, + original_shape=(d, k), + ) + adapter = ExpertAdapter( + adapter_id=adapter_id, + step=0, + layers=layers, + avg_explained_variance=avg_ev, + optimizer_type="AdamW", + extraction_trigger="test", + domain_tags=domain_tags, + ) + return adapter + + +class TestAdapterLibraryAggregator(unittest.TestCase): + def _make_library( + self, n: int, domain: str, id_prefix: str = "a" + ) -> list[ExpertAdapter]: + return [_make_adapter(f"{id_prefix}_{i}", [domain]) for i in range(n)] + + def test_deduplication(self) -> None: + """Adding the same adapter_id twice should only keep one copy.""" + lib = self._make_library(3, "statute_recall", "dup") + agg = AdapterLibraryAggregator() + agg.add_library(lib) + agg.add_library(lib) # second add should be all rejected + self.assertEqual(len(agg.merged), 3) + + def test_min_ev_filter(self) -> None: + """Adapters below min_ev should be rejected.""" + good = _make_adapter("good_1", ["logic_reasoning"], avg_ev=0.95) + bad = _make_adapter("bad_1", ["logic_reasoning"], avg_ev=0.70) + agg = AdapterLibraryAggregator(min_explained_variance=0.90) + accepted, rejected = agg.add_library([good, bad]) + self.assertEqual(accepted, 1) + self.assertEqual(rejected, 1) + + def test_rebalance_caps_per_domain(self) -> None: + """rebalance() should cap adapters per domain at max_adapters_per_domain.""" + statute_lib = self._make_library(15, "statute_recall", "s") + logic_lib = self._make_library(3, "logic_reasoning", "l") + agg = AdapterLibraryAggregator(max_adapters_per_domain=5) + agg.add_library(statute_lib) 
+ agg.add_library(logic_lib) + agg.rebalance() + statute_count = sum(1 for a in agg.merged if "statute_recall" in a.domain_tags) + self.assertLessEqual(statute_count, 5) + + def test_coverage_report_keys(self) -> None: + """coverage_report() must return all required keys.""" + lib = self._make_library(2, "statute_recall") + agg = AdapterLibraryAggregator() + agg.add_library(lib) + report = agg.coverage_report() + self.assertIn("total_adapters", report) + self.assertIn("domain_coverage", report) + self.assertIn("balance_score", report) + self.assertIn("avg_explained_variance", report) + + def test_save_and_reload(self) -> None: + """save() must produce a file loadable by ParasiticQLoRAExtractor.""" + from parasitic_qlora import ParasiticQLoRAExtractor + + lib = self._make_library(3, "logic_reasoning", "save") + agg = AdapterLibraryAggregator() + agg.add_library(lib) + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + path = f.name + try: + agg.save(path) + loaded = ParasiticQLoRAExtractor.load_library(path) + self.assertEqual(len(loaded), 3) + finally: + os.unlink(path) + + def test_load_and_merge_missing_file(self) -> None: + """load_and_merge should skip missing files without crashing.""" + agg = AdapterLibraryAggregator() + stats = agg.load_and_merge(["/nonexistent/path.pt"]) + self.assertEqual(stats["files"], 0) + + def test_aggregate_from_paths_two_files(self) -> None: + """aggregate_from_paths integrates two library files correctly.""" + + lib_a = self._make_library(2, "statute_recall", "aa") + lib_b = self._make_library(2, "logic_reasoning", "bb") + + with tempfile.TemporaryDirectory() as tmpdir: + path_a = os.path.join(tmpdir, "lib_a.pt") + path_b = os.path.join(tmpdir, "lib_b.pt") + out_path = os.path.join(tmpdir, "merged.pt") + + # Save using the same torch.save format as save_library/aggregator + for lib, path in [(lib_a, path_a), (lib_b, path_b)]: + state = { + "config": { + "min_rank": 8, + "max_rank": 32, + 
"explained_variance_threshold": 0.95, + }, + "adapters": [a.to_state_dict() for a in lib], + "extraction_stats": { + "total_extractions": len(lib), + "total_extraction_time_s": 0.0, + "overhead_pct": 0.0, + }, + } + torch.save(state, path) + + report = aggregate_from_paths([path_a, path_b], out_path) + self.assertEqual(report["total_accepted"], 4) + self.assertTrue(os.path.exists(out_path)) + + +class TestScoring(unittest.TestCase): + def _make_case(self, keywords: list[str], wrong: list[str]) -> KlausurCase: + return KlausurCase( + case_id="test_case", + description="Test", + statutes="§ 1 BGB", + correct_keywords=keywords, + wrong_patterns=wrong, + domain="logic_reasoning", + ) + + def test_perfect_score(self) -> None: + """All keywords present, no wrong pattern → score=1.0.""" + case = self._make_case(["536b", "Kenntnis"], []) + result = score_output(case, "§ 536b BGB schließt Kenntnis des Mieters ein.") + self.assertAlmostEqual(result.score, 1.0) + + def test_wrong_pattern_zeroes_score(self) -> None: + """Wrong pattern present → score=0 regardless of keywords.""" + case = self._make_case(["536b"], ["kann mindern"]) + result = score_output(case, "§ 536b, aber M kann mindern hier.") + self.assertEqual(result.score, 0.0) + self.assertTrue(result.wrong_hit) + + def test_partial_keyword_hit(self) -> None: + """Only some keywords present → fractional score.""" + case = self._make_case(["Kenntnis", "vorbehaltlos", "536b"], []) + result = score_output(case, "Kenntnis war vorhanden.") + self.assertAlmostEqual(result.score, 1 / 3, places=5) + + def test_aggregate_domain_scores(self) -> None: + """Domain aggregation averages scores correctly.""" + results = [ + CaseResult("c1", "statute_recall", "", 2, 2, False, 1.0), + CaseResult("c2", "statute_recall", "", 1, 2, False, 0.5), + CaseResult("c3", "logic_reasoning", "", 0, 1, True, 0.0), + ] + agg = aggregate_domain_scores(results) + self.assertAlmostEqual(agg["statute_recall"], 0.75) + 
self.assertAlmostEqual(agg["logic_reasoning"], 0.0) + + def test_validation_suite_completeness(self) -> None: + """VALIDATION_SUITE covers all 3 domains.""" + domains = {c.domain for c in VALIDATION_SUITE} + self.assertIn("statute_recall", domains) + self.assertIn("logic_reasoning", domains) + self.assertIn("style_gutachtenstil", domains) + # At least 2 cases per domain + for domain in ["statute_recall", "logic_reasoning", "style_gutachtenstil"]: + count = sum(1 for c in VALIDATION_SUITE if c.domain == domain) + self.assertGreaterEqual(count, 2, f"Too few cases for domain {domain}") + + +if __name__ == "__main__": + unittest.main()