# -*- coding: utf-8 -*- """Phase 5: End-to-End Klausuren Validation with MoE Routing. Runs a full validation pipeline: 1. Load pre-built adapter library (from Kaggle prep or training run) 2. Pre-train MoE gate on domain routing signal 3. Activate MoE routing via forward hooks (zero model modification) 4. Evaluate on all 6 legal Klausuren cases (3 domains) 5. Score: statute_recall / logic_reasoning / style_gutachtenstil accuracy 6. Push results to MariaDB + SurrealDB telemetry 7. Save final merged library Usage: python python/run_phase5_validation.py --model EleutherAI/pythia-70m python python/run_phase5_validation.py --model EleutherAI/pythia-70m \\ --library parasitic_adapters_fces_step5.pt \\ --gate-weights kaggle_output/step3_gate_weights.pt """ from __future__ import annotations import argparse import json import os import sys import time from dataclasses import dataclass from typing import Any, Dict, List, Optional import torch import torch.nn as nn _root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.join(_root, "python")) from parasitic_qlora import ExpertAdapter, ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402 from expert_manifold_alignment import ExpertManifoldAligner # noqa: E402 from adapter_moe_router import ExpertAdapterRouter, LearnableGate # noqa: E402 from telemetry_aggregator import AdapterLibraryAggregator # noqa: E402 try: from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402 _TELEMETRY_AVAILABLE = True except ImportError: _TELEMETRY_AVAILABLE = False # ============================================================================== # KLAUSUREN VALIDATION SUITE # ============================================================================== @dataclass class KlausurCase: case_id: str description: str statutes: str correct_keywords: List[str] # must appear in model output wrong_patterns: List[str] # must NOT appear in model output domain: str VALIDATION_SUITE: List[KlausurCase] = [ # --- statute_recall --- KlausurCase( case_id="miet_01_statute", description=( "Mieter M unterschreibt Mietvertrag vorbehaltlos obwohl er den " "Schimmelfleck kannte. Er mindert Miete um 20%." ), statutes="§ 536b BGB", correct_keywords=["536b", "Kenntnis", "vorbehaltlos"], wrong_patterns=["kann mindern", "hat Recht auf Minderung"], domain="statute_recall", ), KlausurCase( case_id="kauf_01_statute", description=( "Kilometerstand manipuliert. Verkäufer beruft sich auf " "'gekauft wie gesehen'. Klausel wirksam?" ), statutes="§ 444 BGB", correct_keywords=["444", "arglistig"], wrong_patterns=["Klausel wirksam", "muss akzeptieren"], domain="statute_recall", ), # --- logic_reasoning --- KlausurCase( case_id="miet_02_logic", description=( "Toilettenspülung kaputt, Vermieter auf Segeltrip, Mieter beauftragt " "Notdienst für 300 EUR. Erstattungsanspruch?" ), statutes="§ 536a Abs. 2 Nr. 2 BGB", correct_keywords=["Selbstbeseitigungsrecht", "Gefahr im Verzug", "Erstattung"], wrong_patterns=["kein Anspruch", "muss selbst zahlen"], domain="logic_reasoning", ), KlausurCase( case_id="delikt_01_logic", description=( "A fährt betrunken (1,5 Promille), beschädigt geparktes Auto. " "Kann A sich auf § 827 BGB berufen?" ), statutes="§ 823 BGB, § 827 BGB", correct_keywords=["actio libera", "selbst herbeigeführt", "haftet"], wrong_patterns=["schuldunfähig", "haftet nicht"], domain="logic_reasoning", ), # --- style_gutachtenstil --- KlausurCase( case_id="kauf_02_style", description=( "Lieferant liefert nicht nach Nachfristablauf. Käuferin kauft " "teurer woanders. Schadensersatz?" ), statutes="§ 281 BGB, § 286 BGB", correct_keywords=["Verzug", "Nachfrist", "Deckungsschaden"], wrong_patterns=["kein Schadensersatz", "muss klagen"], domain="style_gutachtenstil", ), KlausurCase( case_id="miet_03_style", description=( "78-jähriger, schwer kranker Mieter. Vermieter kündigt wegen " "Eigenbedarfs. Sozialklausel greift?" ), statutes="§ 574 BGB", correct_keywords=["Sozialklausel", "Härte", "Widerspruch"], wrong_patterns=["muss ausziehen", "Eigenbedarf schlägt"], domain="style_gutachtenstil", ), ] # ============================================================================== # VALIDATION SCORING # ============================================================================== @dataclass class CaseResult: case_id: str domain: str output: str keyword_hits: int keyword_total: int wrong_hit: bool score: float # 0..1 def score_output(case: KlausurCase, output: str) -> CaseResult: """Scores a model output against correct keywords and wrong patterns.""" out_lower = output.lower() keyword_hits = sum(1 for kw in case.correct_keywords if kw.lower() in out_lower) wrong_hit = any(p.lower() in out_lower for p in case.wrong_patterns) # Score: fraction of keywords hit, penalised by wrong patterns raw = keyword_hits / max(1, len(case.correct_keywords)) score = raw * (0.0 if wrong_hit else 1.0) return CaseResult( case_id=case.case_id, domain=case.domain, output=output, keyword_hits=keyword_hits, keyword_total=len(case.correct_keywords), wrong_hit=wrong_hit, score=score, ) def aggregate_domain_scores(results: List[CaseResult]) -> Dict[str, float]: """Average score per domain.""" domain_scores: Dict[str, List[float]] = {} for r in results: domain_scores.setdefault(r.domain, []).append(r.score) return {d: sum(vs) / len(vs) for d, vs in domain_scores.items()} # ============================================================================== # MODEL INFERENCE # ============================================================================== def run_inference( model: nn.Module, tokenizer: Any, case: KlausurCase, device: str, max_new_tokens: int = 60, ) -> str: """Generates model output for a Klausur case.""" prompt = ( f"Fall: {case.description}\n" f"Normen: {case.statutes}\n" "Gutachten (Subsumtion):\n" ) enc = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128) input_ids = enc["input_ids"].to(device) model_any: Any = model with torch.no_grad(): out = model_any.generate( input_ids, max_new_tokens=max_new_tokens, do_sample=False, temperature=1.0, pad_token_id=tokenizer.eos_token_id, ) generated = out[0][input_ids.shape[1] :] return str(tokenizer.decode(generated, skip_special_tokens=True)) # ============================================================================== # GATE WARM-UP (use Kaggle pre-trained weights if available) # ============================================================================== def load_or_train_gate( adapter_library: List[ExpertAdapter], in_features: int, gate_weights_path: Optional[str], gate_steps: int = 100, device: str = "cpu", ) -> LearnableGate: """Loads pre-trained gate or trains from scratch on domain supervision.""" num_adapters = len(adapter_library) gate = LearnableGate(in_features, num_adapters).to(device) if gate_weights_path and os.path.exists(gate_weights_path): state = torch.load(gate_weights_path, map_location=device) # Handle size mismatch (library grew since pre-training) if list(state.values())[0].shape[0] == num_adapters or True: try: gate.load_state_dict(state, strict=False) print(f" [Gate] Loaded pre-trained weights from {gate_weights_path}") except RuntimeError as e: print(f" [Gate] Shape mismatch ({e}), training from scratch...") else: return gate # type: ignore[no-any-return, unused-ignore] # Train from scratch with domain supervision import torch.nn.functional as F adapter_targets = [] for adapter in adapter_library: if "statute_recall" in adapter.domain_tags: adapter_targets.append(0) elif "logic_reasoning" in adapter.domain_tags: adapter_targets.append(1) else: adapter_targets.append(2) gate_opt = torch.optim.AdamW(gate.parameters(), lr=1e-3) print(f" [Gate] Training gate for {num_adapters} adapters, {gate_steps} steps...") for step in range(gate_steps): x = torch.randn(8, in_features, device=device) dom = adapter_targets[step % num_adapters] x[:, dom * (in_features // 3) : (dom + 1) * (in_features // 3)] += 2.0 target = torch.zeros(8, num_adapters, device=device) for i, t in enumerate(adapter_targets): if t == dom: target[:, i] = 1.0 / max( 1, sum(1 for tt in adapter_targets if tt == dom) ) gate_opt.zero_grad() logits = gate(x) loss = F.kl_div(F.log_softmax(logits, dim=-1), target, reduction="batchmean") loss.backward() # type: ignore[no-untyped-call, unused-ignore] gate_opt.step() if step % 25 == 0: print(f" Gate step {step:3d}/{gate_steps} | loss={loss.item():.4f}") result: LearnableGate = gate return result # ============================================================================== # MAIN VALIDATION PIPELINE # ============================================================================== def run_phase5( model_name: str, library_paths: List[str], gate_weights_path: Optional[str], output_dir: str, max_new_tokens: int = 60, gate_steps: int = 100, ) -> Dict[str, Any]: """Full Phase 5 validation pipeline.""" from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig # type: ignore[import-untyped, unused-ignore] device = "cuda" if torch.cuda.is_available() else "cpu" os.makedirs(output_dir, exist_ok=True) print("\n" + "=" * 70) print(" PHASE 5: KLAUSUREN VALIDATION WITH MOE ROUTING") print(f" Model: {model_name} | Device: {device}") print("=" * 70) # ── 1. Load model ──────────────────────────────────────────────────────── print("\n[1/6] Loading model...") tokenizer = AutoTokenizer.from_pretrained(model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained(model_name).to(device) model.eval() config = AutoConfig.from_pretrained(model_name) in_features: int = getattr(config, "hidden_size", 512) # ── 2. Load/aggregate adapter library ──────────────────────────────────── print("\n[2/6] Aggregating adapter libraries...") agg = AdapterLibraryAggregator( max_adapters_per_domain=10, min_explained_variance=0.85 ) merge_stats = agg.load_and_merge(library_paths) agg.rebalance() adapter_library = agg.merged if not adapter_library: # Cold-start: extract adapters from current model state print(" No pre-built library found. Extracting from model (cold-start)...") extractor = ParasiticQLoRAExtractor( QLoRAConfig( min_rank=8, max_rank=32, explained_variance_threshold=0.90, extraction_interval=1, interesting_point_detection=False, ) ) extractor.snapshot_base(model) cold_adapter = extractor.extract_adapters( model, 0, {"loss": 0.0, "optimizer": "baseline"} ) aligner = ExpertManifoldAligner(model) aligner.tag_adapter(cold_adapter) adapter_library = [cold_adapter] coverage = agg.coverage_report() print( f" Library: {coverage['total_adapters']} adapters | " f"coverage={coverage['domain_coverage']}" ) # ── 3. Load/train MoE gate ──────────────────────────────────────────────── print("\n[3/6] Initialising MoE gate...") gate = load_or_train_gate( adapter_library, in_features, gate_weights_path, gate_steps=gate_steps, device=device, ) # ── 4. Activate MoE routing ─────────────────────────────────────────────── print("\n[4/6] Activating MoE routing via forward hooks...") router = ExpertAdapterRouter(model, adapter_library, in_features=in_features) # Inject pre-trained gate weights router.gate.load_state_dict(gate.state_dict(), strict=False) router.register_hooks() print(f" Hooks registered on {len(router.hooks)} layers") # ── 5. Baseline inference (routing OFF) ─────────────────────────────────── print("\n[5/6] Running validation (baseline vs MoE)...") router.unregister_hooks() baseline_results: List[CaseResult] = [] for case in VALIDATION_SUITE: output = run_inference(model, tokenizer, case, device, max_new_tokens) result = score_output(case, output) baseline_results.append(result) print( f" [BASELINE] {case.case_id:<25} score={result.score:.2f} " f"({result.keyword_hits}/{result.keyword_total} kw, " f"wrong={'YES' if result.wrong_hit else 'no'})" ) baseline_domain = aggregate_domain_scores(baseline_results) # MoE routing ON router.register_hooks() moe_results: List[CaseResult] = [] for case in VALIDATION_SUITE: output = run_inference(model, tokenizer, case, device, max_new_tokens) result = score_output(case, output) moe_results.append(result) print( f" [MoE] {case.case_id:<25} score={result.score:.2f} " f"({result.keyword_hits}/{result.keyword_total} kw, " f"wrong={'YES' if result.wrong_hit else 'no'})" ) router.unregister_hooks() moe_domain = aggregate_domain_scores(moe_results) # ── 6. Report & telemetry ───────────────────────────────────────────────── print("\n[6/6] Compiling results...") baseline_avg = sum(r.score for r in baseline_results) / len(baseline_results) moe_avg = sum(r.score for r in moe_results) / len(moe_results) delta = moe_avg - baseline_avg print("\n" + "=" * 70) print(" RESULTS") print("=" * 70) print(f" {'Domain':<25} {'Baseline':>10} {'MoE':>10} {'Delta':>10}") print(" " + "-" * 55) for domain in ["statute_recall", "logic_reasoning", "style_gutachtenstil"]: b = baseline_domain.get(domain, 0.0) m = moe_domain.get(domain, 0.0) d = m - b marker = ( " <-- best" if d == max( moe_domain.get(dd, 0.0) - baseline_domain.get(dd, 0.0) for dd in moe_domain ) else "" ) print(f" {domain:<25} {b:>10.3f} {m:>10.3f} {d:>+10.3f}{marker}") print(" " + "-" * 55) print(f" {'OVERALL':<25} {baseline_avg:>10.3f} {moe_avg:>10.3f} {delta:>+10.3f}") print("=" * 70) # Save merged library merged_lib_path = os.path.join(output_dir, "phase5_merged_library.pt") agg.save(merged_lib_path) # Build result dict report: Dict[str, Any] = { "model": model_name, "device": device, "adapter_count": len(adapter_library), "merge_stats": merge_stats, "library_coverage": coverage, "baseline": { "per_case": [ { "case_id": r.case_id, "score": r.score, "keyword_hits": r.keyword_hits, "wrong_hit": r.wrong_hit, } for r in baseline_results ], "per_domain": baseline_domain, "avg": baseline_avg, }, "moe": { "per_case": [ { "case_id": r.case_id, "score": r.score, "keyword_hits": r.keyword_hits, "wrong_hit": r.wrong_hit, } for r in moe_results ], "per_domain": moe_domain, "avg": moe_avg, }, "delta_avg": delta, "merged_library_path": merged_lib_path, } report_path = os.path.join(output_dir, "phase5_results.json") with open(report_path, "w", encoding="utf-8") as fh: json.dump(report, fh, indent=2, ensure_ascii=False) print(f"\n Report saved: {report_path}") # Push telemetry if _TELEMETRY_AVAILABLE: telemetry_entries = [ ("INFO", "phase5_baseline_avg", f"{baseline_avg:.4f}"), ("INFO", "phase5_moe_avg", f"{moe_avg:.4f}"), ("INFO", "phase5_delta", f"{delta:.4f}"), ("INFO", "phase5_adapters", str(len(adapter_library))), ] for domain, score in moe_domain.items(): telemetry_entries.append(("INFO", f"phase5_moe_{domain}", f"{score:.4f}")) push_to_surrealdb(telemetry_entries) push_to_mariadb(telemetry_entries) return report # ============================================================================== # CLI # ============================================================================== def main() -> None: parser = argparse.ArgumentParser( description="Phase 5: Klausuren Validation with MoE" ) parser.add_argument("--model", type=str, default="EleutherAI/pythia-70m") parser.add_argument( "--library", nargs="*", default=[], help="Paths to .pt adapter library files (can specify multiple)", ) parser.add_argument( "--gate-weights", type=str, default=None, help="Path to pre-trained gate weights .pt (from kaggle_prep step 3)", ) parser.add_argument( "--output-dir", type=str, default="phase5_output", help="Directory for output files", ) parser.add_argument("--max-new-tokens", type=int, default=60) parser.add_argument("--gate-steps", type=int, default=100) args = parser.parse_args() # Auto-discover adapter libraries in cwd if none specified library_paths: List[str] = list(args.library) if not library_paths: for fname in os.listdir("."): if fname.startswith("parasitic_adapters_") and fname.endswith(".pt"): library_paths.append(fname) if library_paths: print(f" Auto-discovered {len(library_paths)} libraries: {library_paths}") t0 = time.perf_counter() report = run_phase5( model_name=args.model, library_paths=library_paths, gate_weights_path=args.gate_weights, output_dir=args.output_dir, max_new_tokens=args.max_new_tokens, gate_steps=args.gate_steps, ) elapsed = time.perf_counter() - t0 print(f"\n Phase 5 complete in {elapsed:.1f}s") print(f" MoE delta vs baseline: {report['delta_avg']:+.4f}") if __name__ == "__main__": main()