Files
FCES-native/python/run_phase5_validation.py

555 lines
20 KiB
Python

# -*- coding: utf-8 -*-
"""Phase 5: End-to-End Klausuren Validation with MoE Routing.
Runs a full validation pipeline:
1. Load pre-built adapter library (from Kaggle prep or training run)
2. Pre-train MoE gate on domain routing signal
3. Activate MoE routing via forward hooks (zero model modification)
4. Evaluate on all 6 legal Klausuren cases (3 domains)
5. Score: statute_recall / logic_reasoning / style_gutachtenstil accuracy
6. Push results to MariaDB + SurrealDB telemetry
7. Save final merged library
Usage:
python python/run_phase5_validation.py --model EleutherAI/pythia-70m
python python/run_phase5_validation.py --model EleutherAI/pythia-70m \\
--library parasitic_adapters_fces_step5.pt \\
--gate-weights kaggle_output/step3_gate_weights.pt
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(_root, "python"))
from parasitic_qlora import ExpertAdapter, ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402
from expert_manifold_alignment import ExpertManifoldAligner # noqa: E402
from adapter_moe_router import ExpertAdapterRouter, LearnableGate # noqa: E402
from telemetry_aggregator import AdapterLibraryAggregator # noqa: E402
try:
from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402
_TELEMETRY_AVAILABLE = True
except ImportError:
_TELEMETRY_AVAILABLE = False
# ==============================================================================
# KLAUSUREN VALIDATION SUITE
# ==============================================================================
@dataclass
class KlausurCase:
    """One legal exam (Klausur) case used to probe the model.

    Scoring is substring based: every entry of ``correct_keywords`` that
    shows up in the answer earns credit, and any entry of ``wrong_patterns``
    that shows up zeroes the answer.
    """

    case_id: str  # unique identifier, e.g. "miet_01_statute"
    description: str  # fact pattern shown to the model
    statutes: str  # statute citation(s) shown to the model
    correct_keywords: List[str]  # expected to appear in the model output
    wrong_patterns: List[str]  # must not appear in the model output
    domain: str  # evaluation domain label used for per-domain averaging
# Fixed evaluation set: six cases, two per domain. The domain labels below
# are the keys that aggregate_domain_scores() groups by.
# NOTE(review): scoring is case-insensitive substring matching (score_output),
# so keywords and wrong patterns interact as plain text, not as legal meaning.
VALIDATION_SUITE: List[KlausurCase] = [
    # --- statute_recall ---
    KlausurCase(
        case_id="miet_01_statute",
        description=(
            "Mieter M unterschreibt Mietvertrag vorbehaltlos obwohl er den "
            "Schimmelfleck kannte. Er mindert Miete um 20%."
        ),
        statutes="§ 536b BGB",
        correct_keywords=["536b", "Kenntnis", "vorbehaltlos"],
        wrong_patterns=["kann mindern", "hat Recht auf Minderung"],
        domain="statute_recall",
    ),
    KlausurCase(
        case_id="kauf_01_statute",
        description=(
            "Kilometerstand manipuliert. Verkäufer beruft sich auf "
            "'gekauft wie gesehen'. Klausel wirksam?"
        ),
        statutes="§ 444 BGB",
        correct_keywords=["444", "arglistig"],
        # NOTE(review): "Klausel wirksam" also appears in the question text, so a
        # generation that merely restates the question is scored as wrong.
        wrong_patterns=["Klausel wirksam", "muss akzeptieren"],
        domain="statute_recall",
    ),
    # --- logic_reasoning ---
    KlausurCase(
        case_id="miet_02_logic",
        description=(
            "Toilettenspülung kaputt, Vermieter auf Segeltrip, Mieter beauftragt "
            "Notdienst für 300 EUR. Erstattungsanspruch?"
        ),
        statutes="§ 536a Abs. 2 Nr. 2 BGB",
        correct_keywords=["Selbstbeseitigungsrecht", "Gefahr im Verzug", "Erstattung"],
        wrong_patterns=["kein Anspruch", "muss selbst zahlen"],
        domain="logic_reasoning",
    ),
    KlausurCase(
        case_id="delikt_01_logic",
        description=(
            "A fährt betrunken (1,5 Promille), beschädigt geparktes Auto. "
            "Kann A sich auf § 827 BGB berufen?"
        ),
        statutes="§ 823 BGB, § 827 BGB",
        # "haftet" is a substring of the wrong pattern "haftet nicht": an answer
        # saying "haftet nicht" gets the keyword hit but is then zeroed.
        correct_keywords=["actio libera", "selbst herbeigeführt", "haftet"],
        wrong_patterns=["schuldunfähig", "haftet nicht"],
        domain="logic_reasoning",
    ),
    # --- style_gutachtenstil ---
    KlausurCase(
        case_id="kauf_02_style",
        description=(
            "Lieferant liefert nicht nach Nachfristablauf. Käuferin kauft "
            "teurer woanders. Schadensersatz?"
        ),
        statutes="§ 281 BGB, § 286 BGB",
        correct_keywords=["Verzug", "Nachfrist", "Deckungsschaden"],
        wrong_patterns=["kein Schadensersatz", "muss klagen"],
        domain="style_gutachtenstil",
    ),
    KlausurCase(
        case_id="miet_03_style",
        description=(
            "78-jähriger, schwer kranker Mieter. Vermieter kündigt wegen "
            "Eigenbedarfs. Sozialklausel greift?"
        ),
        statutes="§ 574 BGB",
        correct_keywords=["Sozialklausel", "Härte", "Widerspruch"],
        wrong_patterns=["muss ausziehen", "Eigenbedarf schlägt"],
        domain="style_gutachtenstil",
    ),
]
# ==============================================================================
# VALIDATION SCORING
# ==============================================================================
@dataclass
class CaseResult:
    """Outcome of scoring one generated answer against a KlausurCase."""

    case_id: str  # id of the scored case
    domain: str  # domain label copied from the case
    output: str  # generated text that was scored
    keyword_hits: int  # how many correct_keywords were found
    keyword_total: int  # how many correct_keywords the case defines
    wrong_hit: bool  # True when at least one wrong_pattern was found
    score: float  # 0..1
def score_output(case: KlausurCase, output: str) -> CaseResult:
    """Score a generated answer for ``case``.

    Matching is case-insensitive substring search. The score is the fraction
    of ``case.correct_keywords`` found in ``output``, forced to 0.0 if any of
    ``case.wrong_patterns`` is found. A case without keywords scores 0.0.
    """
    haystack = output.lower()
    expected = case.correct_keywords

    hits = 0
    for keyword in expected:
        if keyword.lower() in haystack:
            hits += 1

    penalised = False
    for pattern in case.wrong_patterns:
        if pattern.lower() in haystack:
            penalised = True
            break

    # max(1, ...) avoids a division by zero when no keywords are defined.
    fraction = hits / max(1, len(expected))
    final = 0.0 if penalised else fraction * 1.0

    return CaseResult(
        case_id=case.case_id,
        domain=case.domain,
        output=output,
        keyword_hits=hits,
        keyword_total=len(expected),
        wrong_hit=penalised,
        score=final,
    )
def aggregate_domain_scores(results: List[CaseResult]) -> Dict[str, float]:
    """Return the mean ``score`` per ``domain``.

    Domains appear in the order they are first seen in ``results``; an empty
    input yields an empty dict.
    """
    grouped: Dict[str, List[float]] = {}
    for item in results:
        if item.domain not in grouped:
            grouped[item.domain] = []
        grouped[item.domain].append(item.score)

    averages: Dict[str, float] = {}
    for name, scores in grouped.items():
        averages[name] = sum(scores) / len(scores)
    return averages
# ==============================================================================
# MODEL INFERENCE
# ==============================================================================
def run_inference(
    model: nn.Module,
    tokenizer: Any,
    case: KlausurCase,
    device: str,
    max_new_tokens: int = 60,
) -> str:
    """Generate a greedy model answer for a Klausur case.

    The prompt lists the case facts and statutes and ends with a
    "Gutachten (Subsumtion):" cue. Only the newly generated tokens are
    decoded, so the prompt itself is not part of the returned text.

    Args:
        model: Causal LM exposing a Hugging Face style ``generate`` method.
        tokenizer: Matching tokenizer; its ``eos_token_id`` is used as pad id.
        case: The case whose ``description`` and ``statutes`` form the prompt.
        device: Torch device string the inputs are moved to.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The decoded continuation with special tokens removed.
    """
    prompt = (
        f"Fall: {case.description}\n"
        f"Normen: {case.statutes}\n"
        "Gutachten (Subsumtion):\n"
    )
    # NOTE(review): default truncation cuts from the right, which would drop the
    # "Gutachten" cue for prompts longer than 128 tokens; current cases are short.
    enc = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
    input_ids = enc["input_ids"].to(device)

    gen_kwargs: Dict[str, Any] = {}
    # Forward the attention mask explicitly: pad_token_id is set to the EOS id
    # below, so generate() cannot infer the mask from the input ids on its own.
    attention_mask = enc.get("attention_mask")
    if attention_mask is not None:
        gen_kwargs["attention_mask"] = attention_mask.to(device)

    model_any: Any = model
    with torch.no_grad():
        out = model_any.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            temperature=1.0,
            pad_token_id=tokenizer.eos_token_id,
            **gen_kwargs,
        )
    # Drop the prompt tokens; only the continuation is scored.
    generated = out[0][input_ids.shape[1] :]
    return str(tokenizer.decode(generated, skip_special_tokens=True))
# ==============================================================================
# GATE WARM-UP (use Kaggle pre-trained weights if available)
# ==============================================================================
def load_or_train_gate(
    adapter_library: List[ExpertAdapter],
    in_features: int,
    gate_weights_path: Optional[str],
    gate_steps: int = 100,
    device: str = "cpu",
) -> LearnableGate:
    """Load a pre-trained MoE gate, or train one on synthetic domain supervision.

    If ``gate_weights_path`` points to an existing file whose state dict fills
    every parameter of a freshly built ``LearnableGate``, the loaded gate is
    returned without further training. Otherwise (no file, shape mismatch, or
    missing keys) the gate is trained for ``gate_steps`` steps.

    Args:
        adapter_library: Adapters the gate routes over (one gate logit each).
        in_features: Input width of the gate (model hidden size).
        gate_weights_path: Optional path to a saved gate ``state_dict``.
        gate_steps: Number of optimisation steps when training.
        device: Torch device string for the gate and the synthetic batches.

    Returns:
        The loaded or trained gate.

    Raises:
        ValueError: If ``adapter_library`` is empty.
    """
    import torch.nn.functional as F

    num_adapters = len(adapter_library)
    if num_adapters == 0:
        # Without adapters there is nothing to route to, and the training loop
        # below would divide by zero in ``step % num_adapters``.
        raise ValueError("load_or_train_gate requires at least one adapter")
    gate = LearnableGate(in_features, num_adapters).to(device)

    if gate_weights_path and os.path.exists(gate_weights_path):
        # NOTE(review): torch.load unpickles the file; only load gate weights
        # from trusted sources (consider weights_only=True on torch >= 1.13).
        state = torch.load(gate_weights_path, map_location=device)
        try:
            # strict=False tolerates extra/missing keys, but a size mismatch
            # (e.g. the library grew since pre-training) still raises.
            incompatible = gate.load_state_dict(state, strict=False)
        except RuntimeError as e:
            print(f" [Gate] Shape mismatch ({e}), training from scratch...")
        else:
            if not incompatible.missing_keys:
                print(f" [Gate] Loaded pre-trained weights from {gate_weights_path}")
                return gate  # type: ignore[no-any-return, unused-ignore]
            print(
                f" [Gate] Pre-trained weights incomplete "
                f"(missing {incompatible.missing_keys}), training gate..."
            )

    # Map each adapter to one of three synthetic domain classes.
    adapter_targets: List[int] = []
    for adapter in adapter_library:
        if "statute_recall" in adapter.domain_tags:
            adapter_targets.append(0)
        elif "logic_reasoning" in adapter.domain_tags:
            adapter_targets.append(1)
        else:
            adapter_targets.append(2)

    # Target distribution per domain: uniform over the adapters of that domain.
    # Computed once instead of rebuilding it on every step.
    target_rows: Dict[int, torch.Tensor] = {}
    for dom in set(adapter_targets):
        members = [i for i, t in enumerate(adapter_targets) if t == dom]
        row = torch.zeros(num_adapters, device=device)
        row[members] = 1.0 / len(members)
        target_rows[dom] = row

    batch_size = 8
    band = in_features // 3  # width of the feature slice that signals a domain
    gate_opt = torch.optim.AdamW(gate.parameters(), lr=1e-3)
    print(f" [Gate] Training gate for {num_adapters} adapters, {gate_steps} steps...")
    for step in range(gate_steps):
        # Cycle through the adapters' domains and shift that domain's slice of
        # the random input so the gate can learn to tell domains apart.
        dom = adapter_targets[step % num_adapters]
        x = torch.randn(batch_size, in_features, device=device)
        x[:, dom * band : (dom + 1) * band] += 2.0
        target = target_rows[dom].unsqueeze(0).expand(batch_size, -1)

        gate_opt.zero_grad()
        logits = gate(x)
        loss = F.kl_div(F.log_softmax(logits, dim=-1), target, reduction="batchmean")
        loss.backward()  # type: ignore[no-untyped-call, unused-ignore]
        gate_opt.step()
        if step % 25 == 0:
            print(f" Gate step {step:3d}/{gate_steps} | loss={loss.item():.4f}")
    return gate  # type: ignore[no-any-return, unused-ignore]
# ==============================================================================
# MAIN VALIDATION PIPELINE
# ==============================================================================
def _evaluate_suite(
    model: nn.Module,
    tokenizer: Any,
    device: str,
    max_new_tokens: int,
    label: str,
) -> List[CaseResult]:
    """Run and score every case in VALIDATION_SUITE, printing one line per case."""
    results: List[CaseResult] = []
    for case in VALIDATION_SUITE:
        output = run_inference(model, tokenizer, case, device, max_new_tokens)
        result = score_output(case, output)
        results.append(result)
        print(
            f" {label} {case.case_id:<25} score={result.score:.2f} "
            f"({result.keyword_hits}/{result.keyword_total} kw, "
            f"wrong={'YES' if result.wrong_hit else 'no'})"
        )
    return results


def _serialize_results(results: List[CaseResult]) -> List[Dict[str, Any]]:
    """Convert per-case results into JSON-friendly dicts for the report."""
    return [
        {
            "case_id": r.case_id,
            "score": r.score,
            "keyword_hits": r.keyword_hits,
            "wrong_hit": r.wrong_hit,
        }
        for r in results
    ]


def run_phase5(
    model_name: str,
    library_paths: List[str],
    gate_weights_path: Optional[str],
    output_dir: str,
    max_new_tokens: int = 60,
    gate_steps: int = 100,
) -> Dict[str, Any]:
    """Full Phase 5 validation pipeline.

    Loads the model, aggregates adapter libraries (falling back to a single
    cold-start adapter), prepares the MoE gate, evaluates VALIDATION_SUITE
    with routing off and on, prints a per-domain comparison, saves the merged
    library and ``phase5_results.json`` under ``output_dir``, and pushes
    telemetry when the telemetry module is importable.

    Args:
        model_name: Hugging Face model id or local path.
        library_paths: Adapter library ``.pt`` files to merge.
        gate_weights_path: Optional pre-trained gate weights.
        output_dir: Directory for the merged library and the JSON report.
        max_new_tokens: Generation budget per case.
        gate_steps: Gate training steps if no usable weights are loaded.

    Returns:
        The report dict that is also written to ``phase5_results.json``.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig  # type: ignore[import-untyped, unused-ignore]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(output_dir, exist_ok=True)
    print("\n" + "=" * 70)
    print(" PHASE 5: KLAUSUREN VALIDATION WITH MOE ROUTING")
    print(f" Model: {model_name} | Device: {device}")
    print("=" * 70)

    # ── 1. Load model ────────────────────────────────────────────────────────
    print("\n[1/6] Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    model.eval()
    config = AutoConfig.from_pretrained(model_name)
    in_features: int = getattr(config, "hidden_size", 512)

    # ── 2. Load/aggregate adapter library ────────────────────────────────────
    print("\n[2/6] Aggregating adapter libraries...")
    agg = AdapterLibraryAggregator(
        max_adapters_per_domain=10, min_explained_variance=0.85
    )
    merge_stats = agg.load_and_merge(library_paths)
    agg.rebalance()
    adapter_library = agg.merged
    if not adapter_library:
        # Cold-start: extract adapters from current model state
        print(" No pre-built library found. Extracting from model (cold-start)...")
        extractor = ParasiticQLoRAExtractor(
            QLoRAConfig(
                min_rank=8,
                max_rank=32,
                explained_variance_threshold=0.90,
                extraction_interval=1,
                interesting_point_detection=False,
            )
        )
        extractor.snapshot_base(model)
        cold_adapter = extractor.extract_adapters(
            model, 0, {"loss": 0.0, "optimizer": "baseline"}
        )
        aligner = ExpertManifoldAligner(model)
        aligner.tag_adapter(cold_adapter)
        adapter_library = [cold_adapter]
    # NOTE(review): in the cold-start case the aggregator never sees
    # cold_adapter, so coverage below and agg.save() later describe an
    # empty library while adapter_count reports 1 — confirm intended.
    coverage = agg.coverage_report()
    print(
        f" Library: {coverage['total_adapters']} adapters | "
        f"coverage={coverage['domain_coverage']}"
    )

    # ── 3. Load/train MoE gate ────────────────────────────────────────────────
    print("\n[3/6] Initialising MoE gate...")
    gate = load_or_train_gate(
        adapter_library,
        in_features,
        gate_weights_path,
        gate_steps=gate_steps,
        device=device,
    )

    # ── 4. Activate MoE routing ───────────────────────────────────────────────
    print("\n[4/6] Activating MoE routing via forward hooks...")
    router = ExpertAdapterRouter(model, adapter_library, in_features=in_features)
    # Inject pre-trained gate weights
    router.gate.load_state_dict(gate.state_dict(), strict=False)
    router.register_hooks()
    print(f" Hooks registered on {len(router.hooks)} layers")

    # ── 5. Baseline (routing OFF) vs MoE (routing ON) ─────────────────────────
    print("\n[5/6] Running validation (baseline vs MoE)...")
    router.unregister_hooks()
    baseline_results = _evaluate_suite(
        model, tokenizer, device, max_new_tokens, "[BASELINE]"
    )
    baseline_domain = aggregate_domain_scores(baseline_results)

    router.register_hooks()
    try:
        moe_results = _evaluate_suite(
            model, tokenizer, device, max_new_tokens, "[MoE]"
        )
    finally:
        # Always detach the hooks, even if generation fails, so the model is
        # not left patched for any caller that catches the exception.
        router.unregister_hooks()
    moe_domain = aggregate_domain_scores(moe_results)

    # ── 6. Report & telemetry ─────────────────────────────────────────────────
    print("\n[6/6] Compiling results...")
    baseline_avg = sum(r.score for r in baseline_results) / len(baseline_results)
    moe_avg = sum(r.score for r in moe_results) / len(moe_results)
    delta = moe_avg - baseline_avg
    print("\n" + "=" * 70)
    print(" RESULTS")
    print("=" * 70)
    print(f" {'Domain':<25} {'Baseline':>10} {'MoE':>10} {'Delta':>10}")
    print(" " + "-" * 55)
    # Largest per-domain improvement; computed once, None if there are no domains.
    best_delta = max(
        (moe_domain.get(dd, 0.0) - baseline_domain.get(dd, 0.0) for dd in moe_domain),
        default=None,
    )
    for domain in ["statute_recall", "logic_reasoning", "style_gutachtenstil"]:
        b = baseline_domain.get(domain, 0.0)
        m = moe_domain.get(domain, 0.0)
        d = m - b
        marker = " <-- best" if d == best_delta else ""
        print(f" {domain:<25} {b:>10.3f} {m:>10.3f} {d:>+10.3f}{marker}")
    print(" " + "-" * 55)
    print(f" {'OVERALL':<25} {baseline_avg:>10.3f} {moe_avg:>10.3f} {delta:>+10.3f}")
    print("=" * 70)

    # Save merged library
    merged_lib_path = os.path.join(output_dir, "phase5_merged_library.pt")
    agg.save(merged_lib_path)

    # Build result dict
    report: Dict[str, Any] = {
        "model": model_name,
        "device": device,
        "adapter_count": len(adapter_library),
        "merge_stats": merge_stats,
        "library_coverage": coverage,
        "baseline": {
            "per_case": _serialize_results(baseline_results),
            "per_domain": baseline_domain,
            "avg": baseline_avg,
        },
        "moe": {
            "per_case": _serialize_results(moe_results),
            "per_domain": moe_domain,
            "avg": moe_avg,
        },
        "delta_avg": delta,
        "merged_library_path": merged_lib_path,
    }
    report_path = os.path.join(output_dir, "phase5_results.json")
    with open(report_path, "w", encoding="utf-8") as fh:
        json.dump(report, fh, indent=2, ensure_ascii=False)
    print(f"\n Report saved: {report_path}")

    # Push telemetry (best-effort: the report is already on disk, so a
    # database outage should not abort the run).
    if _TELEMETRY_AVAILABLE:
        telemetry_entries = [
            ("INFO", "phase5_baseline_avg", f"{baseline_avg:.4f}"),
            ("INFO", "phase5_moe_avg", f"{moe_avg:.4f}"),
            ("INFO", "phase5_delta", f"{delta:.4f}"),
            ("INFO", "phase5_adapters", str(len(adapter_library))),
        ]
        for domain, score in moe_domain.items():
            telemetry_entries.append(("INFO", f"phase5_moe_{domain}", f"{score:.4f}"))
        try:
            push_to_surrealdb(telemetry_entries)
            push_to_mariadb(telemetry_entries)
        except Exception as exc:  # external services; report and continue
            print(f" [Telemetry] Push failed: {exc!r}")
    return report
# ==============================================================================
# CLI
# ==============================================================================
def main() -> None:
    """CLI entry point: parse arguments, locate libraries, run Phase 5, print a summary.

    When no ``--library`` is given, files in the current directory named
    ``parasitic_adapters_*.pt`` are used, in sorted order.
    """
    parser = argparse.ArgumentParser(
        description="Phase 5: Klausuren Validation with MoE"
    )
    parser.add_argument("--model", type=str, default="EleutherAI/pythia-70m")
    parser.add_argument(
        "--library",
        nargs="*",
        default=[],
        help="Paths to .pt adapter library files (can specify multiple)",
    )
    parser.add_argument(
        "--gate-weights",
        type=str,
        default=None,
        help="Path to pre-trained gate weights .pt (from kaggle_prep step 3)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="phase5_output",
        help="Directory for output files",
    )
    parser.add_argument("--max-new-tokens", type=int, default=60)
    parser.add_argument("--gate-steps", type=int, default=100)
    args = parser.parse_args()

    # Auto-discover adapter libraries in cwd if none specified
    library_paths: List[str] = list(args.library)
    if not library_paths:
        # os.listdir order is arbitrary; sort so the merge order (and thus the
        # aggregated library) is reproducible across runs and filesystems.
        library_paths = sorted(
            fname
            for fname in os.listdir(".")
            if fname.startswith("parasitic_adapters_") and fname.endswith(".pt")
        )
        if library_paths:
            print(f" Auto-discovered {len(library_paths)} libraries: {library_paths}")

    t0 = time.perf_counter()
    report = run_phase5(
        model_name=args.model,
        library_paths=library_paths,
        gate_weights_path=args.gate_weights,
        output_dir=args.output_dir,
        max_new_tokens=args.max_new_tokens,
        gate_steps=args.gate_steps,
    )
    elapsed = time.perf_counter() - t0
    print(f"\n Phase 5 complete in {elapsed:.1f}s")
    print(f" MoE delta vs baseline: {report['delta_avg']:+.4f}")


if __name__ == "__main__":
    main()