555 lines
20 KiB
Python
555 lines
20 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Phase 5: End-to-End Klausuren Validation with MoE Routing.
|
|
|
|
Runs a full validation pipeline:
|
|
1. Load pre-built adapter library (from Kaggle prep or training run)
|
|
2. Load pre-trained MoE gate weights, or train the gate on a synthetic domain-routing signal
|
|
3. Activate MoE routing via forward hooks (zero model modification)
|
|
4. Evaluate all 6 legal Klausuren cases (3 domains), once without routing (baseline) and once with MoE routing
|
|
5. Score keyword hits and wrong-answer patterns, averaged per domain (statute_recall / logic_reasoning / style_gutachtenstil)
|
|
6. Push summary metrics to MariaDB + SurrealDB telemetry (if send_telemetry is importable)
|
|
7. Save the final merged adapter library and a JSON results report
|
|
|
|
Usage:
|
|
python python/run_phase5_validation.py --model EleutherAI/pythia-70m
|
|
python python/run_phase5_validation.py --model EleutherAI/pythia-70m \\
|
|
--library parasitic_adapters_fces_step5.pt \\
|
|
--gate-weights kaggle_output/step3_gate_weights.pt
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
sys.path.insert(0, os.path.join(_root, "python"))
|
|
|
|
from parasitic_qlora import ExpertAdapter, ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402
|
|
from expert_manifold_alignment import ExpertManifoldAligner # noqa: E402
|
|
from adapter_moe_router import ExpertAdapterRouter, LearnableGate # noqa: E402
|
|
from telemetry_aggregator import AdapterLibraryAggregator # noqa: E402
|
|
|
|
try:
|
|
from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402
|
|
|
|
_TELEMETRY_AVAILABLE = True
|
|
except ImportError:
|
|
_TELEMETRY_AVAILABLE = False
|
|
|
|
|
|
# ==============================================================================
|
|
# KLAUSUREN VALIDATION SUITE
|
|
# ==============================================================================
|
|
|
|
|
|
@dataclass
class KlausurCase:
    """A single exam-style (Klausur) test case of the validation suite.

    The model answer for a case is checked for expected keywords and for
    phrases that indicate a wrong legal conclusion.
    """

    # Short unique identifier, e.g. "miet_01_statute".
    case_id: str
    # Fact pattern; rendered into the prompt after "Fall: ".
    description: str
    # Relevant statutes; rendered into the prompt after "Normen: ".
    statutes: str
    # Terms expected in the model output; each one found raises the score.
    correct_keywords: List[str]
    # Phrases that must not appear; any match zeroes the score.
    wrong_patterns: List[str]
    # Evaluation domain used to group scores.
    domain: str
|
|
|
|
|
|
# Fixed validation suite: six German civil-law cases, two per domain. Each case
# lists keywords a correct answer should contain and phrases that signal a
# wrong conclusion; score_output combines them into a 0..1 score.
# NOTE(review): only the generated continuation is scored (run_inference strips
# the prompt), but a model that echoes the question text can still trigger a
# wrong pattern. Example: kauf_01's description contains "Klausel wirksam".
VALIDATION_SUITE: List[KlausurCase] = [
    # --- statute_recall ---
    # Tenant signed the lease without reservation although he knew of the mould
    # stain, then reduces the rent by 20%.
    KlausurCase(
        case_id="miet_01_statute",
        description=(
            "Mieter M unterschreibt Mietvertrag vorbehaltlos obwohl er den "
            "Schimmelfleck kannte. Er mindert Miete um 20%."
        ),
        statutes="§ 536b BGB",
        correct_keywords=["536b", "Kenntnis", "vorbehaltlos"],
        wrong_patterns=["kann mindern", "hat Recht auf Minderung"],
        domain="statute_recall",
    ),
    # Odometer was tampered with; seller relies on a "sold as seen" clause.
    # Is the clause effective?
    KlausurCase(
        case_id="kauf_01_statute",
        description=(
            "Kilometerstand manipuliert. Verkäufer beruft sich auf "
            "'gekauft wie gesehen'. Klausel wirksam?"
        ),
        statutes="§ 444 BGB",
        correct_keywords=["444", "arglistig"],
        wrong_patterns=["Klausel wirksam", "muss akzeptieren"],
        domain="statute_recall",
    ),
    # --- logic_reasoning ---
    # Toilet flush broken, landlord away sailing, tenant hires an emergency
    # service for 300 EUR. Reimbursement claim?
    KlausurCase(
        case_id="miet_02_logic",
        description=(
            "Toilettenspülung kaputt, Vermieter auf Segeltrip, Mieter beauftragt "
            "Notdienst für 300 EUR. Erstattungsanspruch?"
        ),
        statutes="§ 536a Abs. 2 Nr. 2 BGB",
        correct_keywords=["Selbstbeseitigungsrecht", "Gefahr im Verzug", "Erstattung"],
        wrong_patterns=["kein Anspruch", "muss selbst zahlen"],
        domain="logic_reasoning",
    ),
    # A drives drunk (1.5 per mille) and damages a parked car. Can A invoke
    # § 827 BGB?
    # NOTE(review): the keyword "haftet" is a substring of the wrong pattern
    # "haftet nicht", so an output containing "haftet nicht" counts the keyword
    # hit but is still zeroed by the wrong-pattern check.
    KlausurCase(
        case_id="delikt_01_logic",
        description=(
            "A fährt betrunken (1,5 Promille), beschädigt geparktes Auto. "
            "Kann A sich auf § 827 BGB berufen?"
        ),
        statutes="§ 823 BGB, § 827 BGB",
        correct_keywords=["actio libera", "selbst herbeigeführt", "haftet"],
        wrong_patterns=["schuldunfähig", "haftet nicht"],
        domain="logic_reasoning",
    ),
    # --- style_gutachtenstil ---
    # Supplier still does not deliver after the grace period expires; the buyer
    # purchases elsewhere at a higher price. Damages?
    KlausurCase(
        case_id="kauf_02_style",
        description=(
            "Lieferant liefert nicht nach Nachfristablauf. Käuferin kauft "
            "teurer woanders. Schadensersatz?"
        ),
        statutes="§ 281 BGB, § 286 BGB",
        correct_keywords=["Verzug", "Nachfrist", "Deckungsschaden"],
        wrong_patterns=["kein Schadensersatz", "muss klagen"],
        domain="style_gutachtenstil",
    ),
    # 78-year-old, seriously ill tenant; landlord terminates for personal use.
    # Does the social (hardship) clause apply?
    KlausurCase(
        case_id="miet_03_style",
        description=(
            "78-jähriger, schwer kranker Mieter. Vermieter kündigt wegen "
            "Eigenbedarfs. Sozialklausel greift?"
        ),
        statutes="§ 574 BGB",
        correct_keywords=["Sozialklausel", "Härte", "Widerspruch"],
        wrong_patterns=["muss ausziehen", "Eigenbedarf schlägt"],
        domain="style_gutachtenstil",
    ),
]
|
|
|
|
|
|
# ==============================================================================
|
|
# VALIDATION SCORING
|
|
# ==============================================================================
|
|
|
|
|
|
@dataclass
class CaseResult:
    """Scored outcome of evaluating one KlausurCase."""

    # Identifier and domain copied from the case that was scored.
    case_id: str
    domain: str
    # Generated text that was scored.
    output: str
    # How many expected keywords were found, out of keyword_total.
    keyword_hits: int
    keyword_total: int
    # True when at least one wrong pattern occurred in the output.
    wrong_hit: bool
    # Final score in the range 0..1.
    score: float
|
|
|
|
|
|
def _normalize_text(text: str) -> str:
    """Case-fold *text* and collapse every whitespace run to a single space."""
    return " ".join(text.casefold().split())


def score_output(case: KlausurCase, output: str) -> CaseResult:
    """Score a model output against a case's keywords and wrong patterns.

    Matching is case-insensitive (``str.casefold``) and whitespace-insensitive.
    Multi-word keywords such as "Gefahr im Verzug" therefore still match when
    the model wraps lines or emits repeated spaces.

    Args:
        case: The Klausur case whose keywords and wrong patterns are checked.
        output: Generated model text.

    Returns:
        A CaseResult. Its ``score`` is the fraction of ``correct_keywords``
        found in the output, or 0.0 if any ``wrong_patterns`` entry occurs.
        A case without keywords scores 0.0.
    """
    text = _normalize_text(output)
    keywords = [_normalize_text(kw) for kw in case.correct_keywords]

    keyword_hits = sum(1 for kw in keywords if kw in text)
    wrong_hit = any(_normalize_text(p) in text for p in case.wrong_patterns)

    keyword_total = len(keywords)
    # max(1, ...) avoids division by zero for a case without keywords.
    raw = keyword_hits / max(1, keyword_total)
    # A single wrong conclusion invalidates the whole answer.
    score = 0.0 if wrong_hit else raw

    return CaseResult(
        case_id=case.case_id,
        domain=case.domain,
        output=output,
        keyword_hits=keyword_hits,
        keyword_total=keyword_total,
        wrong_hit=wrong_hit,
        score=score,
    )
|
|
|
|
|
|
def aggregate_domain_scores(results: List[CaseResult]) -> Dict[str, float]:
    """Return the mean score of every domain that occurs in *results*.

    Domains appear in first-seen order. A domain without results is simply
    absent from the returned mapping.
    """
    grouped: Dict[str, List[float]] = {}
    for res in results:
        if res.domain not in grouped:
            grouped[res.domain] = []
        grouped[res.domain].append(res.score)

    means: Dict[str, float] = {}
    for name, scores in grouped.items():
        means[name] = sum(scores) / len(scores)
    return means
|
|
|
|
|
|
# ==============================================================================
|
|
# MODEL INFERENCE
|
|
# ==============================================================================
|
|
|
|
|
|
def run_inference(
    model: nn.Module,
    tokenizer: Any,
    case: KlausurCase,
    device: str,
    max_new_tokens: int = 60,
    max_prompt_length: int = 128,
) -> str:
    """Generate a greedy model answer for a Klausur case.

    Builds a "Fall / Normen / Gutachten" prompt from the case and decodes
    greedily. Only the newly generated text is returned; prompt tokens are
    sliced off before decoding.

    Args:
        model: Causal LM exposing a Hugging Face style ``generate`` method.
        tokenizer: Matching tokenizer; must be callable and expose
            ``eos_token_id`` and ``decode``.
        case: Case providing ``description`` and ``statutes``.
        device: Torch device string the inputs are moved to.
        max_new_tokens: Generation budget.
        max_prompt_length: Token limit passed to the tokenizer; longer prompts
            are truncated.

    Returns:
        The decoded continuation with special tokens removed.
    """
    prompt = (
        f"Fall: {case.description}\n"
        f"Normen: {case.statutes}\n"
        "Gutachten (Subsumtion):\n"
    )
    enc = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=max_prompt_length
    )
    input_ids = enc["input_ids"].to(device)

    gen_kwargs: Dict[str, Any] = {}
    # Pass the attention mask explicitly. pad_token_id is set to EOS below,
    # so generate() cannot reliably infer the mask from the input ids.
    attention_mask = enc.get("attention_mask")
    if attention_mask is not None:
        gen_kwargs["attention_mask"] = attention_mask.to(device)

    model_any: Any = model
    with torch.no_grad():
        out = model_any.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding -> deterministic outputs
            pad_token_id=tokenizer.eos_token_id,
            **gen_kwargs,
        )
    # Drop the prompt tokens so only the model's continuation is scored.
    generated = out[0][input_ids.shape[1]:]
    return str(tokenizer.decode(generated, skip_special_tokens=True))
|
|
|
|
|
|
# ==============================================================================
|
|
# GATE WARM-UP (use Kaggle pre-trained weights if available)
|
|
# ==============================================================================
|
|
|
|
|
|
def _gate_domain_index(adapter: ExpertAdapter) -> int:
    """Map an adapter to its synthetic routing class for gate training.

    0 = statute_recall, 1 = logic_reasoning, 2 = anything else. Tags are
    checked in that order, so an adapter with several tags gets the first match.
    """
    if "statute_recall" in adapter.domain_tags:
        return 0
    if "logic_reasoning" in adapter.domain_tags:
        return 1
    return 2


def _try_load_gate(gate: LearnableGate, gate_weights_path: str, device: str) -> bool:
    """Load gate weights from disk; return True only if every gate key was filled."""
    # NOTE: torch.load unpickles the file; only pass gate weights you trust.
    state = torch.load(gate_weights_path, map_location=device)
    try:
        incompatible = gate.load_state_dict(state, strict=False)
    except RuntimeError as e:
        # Size mismatches raise even with strict=False, e.g. when the adapter
        # library grew since the gate was pre-trained.
        print(f" [Gate] Shape mismatch ({e}), training from scratch...")
        return False
    if incompatible.missing_keys:
        # strict=False silently skips absent keys. Treat a partial load as a
        # failure rather than returning a half-initialised gate.
        print(
            f" [Gate] Missing keys {incompatible.missing_keys} in "
            f"{gate_weights_path}, training from scratch..."
        )
        return False
    print(f" [Gate] Loaded pre-trained weights from {gate_weights_path}")
    return True


def load_or_train_gate(
    adapter_library: List[ExpertAdapter],
    in_features: int,
    gate_weights_path: Optional[str],
    gate_steps: int = 100,
    device: str = "cpu",
) -> LearnableGate:
    """Return a routing gate, preferring pre-trained weights when they fit.

    If ``gate_weights_path`` names an existing file whose tensors fill every
    gate parameter with matching shapes, those weights are loaded and the gate
    is returned without further training.

    Otherwise a fresh gate is trained for ``gate_steps`` AdamW steps on a
    synthetic task. This covers: no path, missing file, shape mismatch, or
    missing keys. In the synthetic task, random inputs get a +2.0 offset in
    the feature band of one domain. The target is a uniform distribution over
    all adapters tagged with that domain.

    Args:
        adapter_library: Adapters the gate routes between (one output each).
        in_features: Input width of the gate.
        gate_weights_path: Optional path to pre-trained gate weights.
        gate_steps: Number of synthetic training steps when training.
        device: Torch device for the gate and training tensors.

    Returns:
        The loaded or trained gate.

    Raises:
        ValueError: If ``adapter_library`` is empty.
    """
    import torch.nn.functional as F

    num_adapters = len(adapter_library)
    if num_adapters == 0:
        raise ValueError("load_or_train_gate requires at least one adapter")

    gate = LearnableGate(in_features, num_adapters).to(device)

    if gate_weights_path and os.path.exists(gate_weights_path):
        if _try_load_gate(gate, gate_weights_path, device):
            return gate  # type: ignore[no-any-return, unused-ignore]
        # A failed load_state_dict may already have copied the tensors that
        # did match; start again from a clean gate so training is truly
        # from scratch.
        gate = LearnableGate(in_features, num_adapters).to(device)

    # Train from scratch with domain supervision
    adapter_targets = [_gate_domain_index(adapter) for adapter in adapter_library]

    # Precompute one soft target row per domain. Each row spreads probability
    # mass uniformly over the adapters of that domain.
    domain_targets: Dict[int, torch.Tensor] = {}
    for dom in set(adapter_targets):
        members = [i for i, t in enumerate(adapter_targets) if t == dom]
        row = torch.zeros(num_adapters, device=device)
        row[members] = 1.0 / len(members)
        domain_targets[dom] = row

    batch_size = 8
    band = in_features // 3  # width of each domain's feature band
    gate_opt = torch.optim.AdamW(gate.parameters(), lr=1e-3)
    print(f" [Gate] Training gate for {num_adapters} adapters, {gate_steps} steps...")

    for step in range(gate_steps):
        x = torch.randn(batch_size, in_features, device=device)
        # Cycle through adapters so every domain present is visited.
        dom = adapter_targets[step % num_adapters]
        x[:, dom * band : (dom + 1) * band] += 2.0
        target = domain_targets[dom].expand(batch_size, -1)

        gate_opt.zero_grad()
        logits = gate(x)
        loss = F.kl_div(F.log_softmax(logits, dim=-1), target, reduction="batchmean")
        loss.backward()  # type: ignore[no-untyped-call, unused-ignore]
        gate_opt.step()

        if step % 25 == 0:
            print(f" Gate step {step:3d}/{gate_steps} | loss={loss.item():.4f}")

    return gate  # type: ignore[no-any-return, unused-ignore]
|
|
|
|
|
|
# ==============================================================================
|
|
# MAIN VALIDATION PIPELINE
|
|
# ==============================================================================
|
|
|
|
|
|
def _evaluate_suite(
    model: nn.Module,
    tokenizer: Any,
    device: str,
    max_new_tokens: int,
    tag: str,
) -> List[CaseResult]:
    """Run and score every VALIDATION_SUITE case, printing one line per case."""
    results: List[CaseResult] = []
    for case in VALIDATION_SUITE:
        output = run_inference(model, tokenizer, case, device, max_new_tokens)
        result = score_output(case, output)
        results.append(result)
        print(
            f" {tag} {case.case_id:<25} score={result.score:.2f} "
            f"({result.keyword_hits}/{result.keyword_total} kw, "
            f"wrong={'YES' if result.wrong_hit else 'no'})"
        )
    return results


def _per_case_summary(results: List[CaseResult]) -> List[Dict[str, Any]]:
    """Serialise case results into the JSON-friendly dicts stored in the report."""
    return [
        {
            "case_id": r.case_id,
            "score": r.score,
            "keyword_hits": r.keyword_hits,
            "wrong_hit": r.wrong_hit,
        }
        for r in results
    ]


def _print_comparison(
    baseline_domain: Dict[str, float],
    moe_domain: Dict[str, float],
    baseline_avg: float,
    moe_avg: float,
) -> None:
    """Print the baseline-vs-MoE table per domain plus the overall row."""
    delta = moe_avg - baseline_avg
    print("\n" + "=" * 70)
    print(" RESULTS")
    print("=" * 70)
    print(f" {'Domain':<25} {'Baseline':>10} {'MoE':>10} {'Delta':>10}")
    print(" " + "-" * 55)
    # Largest per-domain improvement, computed once; tied domains are all marked.
    best_delta = max(
        (moe_domain.get(dd, 0.0) - baseline_domain.get(dd, 0.0) for dd in moe_domain),
        default=None,
    )
    for domain in ("statute_recall", "logic_reasoning", "style_gutachtenstil"):
        b = baseline_domain.get(domain, 0.0)
        m = moe_domain.get(domain, 0.0)
        d = m - b
        marker = " <-- best" if d == best_delta else ""
        print(f" {domain:<25} {b:>10.3f} {m:>10.3f} {d:>+10.3f}{marker}")
    print(" " + "-" * 55)
    print(f" {'OVERALL':<25} {baseline_avg:>10.3f} {moe_avg:>10.3f} {delta:>+10.3f}")
    print("=" * 70)


def _push_telemetry(
    baseline_avg: float,
    moe_avg: float,
    delta: float,
    adapter_count: int,
    moe_domain: Dict[str, float],
) -> None:
    """Best-effort push of summary metrics to SurrealDB and MariaDB.

    Each backend is attempted independently. A failure is printed and does
    not abort the run.
    """
    telemetry_entries = [
        ("INFO", "phase5_baseline_avg", f"{baseline_avg:.4f}"),
        ("INFO", "phase5_moe_avg", f"{moe_avg:.4f}"),
        ("INFO", "phase5_delta", f"{delta:.4f}"),
        ("INFO", "phase5_adapters", str(adapter_count)),
    ]
    for domain, score in moe_domain.items():
        telemetry_entries.append(("INFO", f"phase5_moe_{domain}", f"{score:.4f}"))
    for push in (push_to_surrealdb, push_to_mariadb):
        try:
            push(telemetry_entries)
        except Exception as exc:  # noqa: BLE001 - telemetry must not kill the run
            print(f" [Telemetry] {getattr(push, '__name__', push)} failed: {exc}")


def run_phase5(
    model_name: str,
    library_paths: List[str],
    gate_weights_path: Optional[str],
    output_dir: str,
    max_new_tokens: int = 60,
    gate_steps: int = 100,
) -> Dict[str, Any]:
    """Run the full Phase 5 validation pipeline and return the report.

    Pipeline steps:
      1. Load the model and tokenizer.
      2. Aggregate adapter libraries, or extract one cold-start adapter when
         none are available.
      3. Load or train the MoE gate.
      4. Score every VALIDATION_SUITE case with routing hooks off (baseline)
         and on (MoE).
      5. Print a per-domain comparison.
      6. Save the merged library and a JSON report into ``output_dir``, and
         push summary metrics to telemetry when it is importable.

    Args:
        model_name: Hugging Face model id or local path.
        library_paths: Adapter library ``.pt`` files to aggregate.
        gate_weights_path: Optional pre-trained gate weights.
        output_dir: Directory for ``phase5_merged_library.pt`` and
            ``phase5_results.json``; created if missing.
        max_new_tokens: Generation budget per case.
        gate_steps: Gate training steps when no usable weights are loaded.

    Returns:
        The report dict that is also written to ``phase5_results.json``.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig  # type: ignore[import-untyped, unused-ignore]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(output_dir, exist_ok=True)

    print("\n" + "=" * 70)
    print(" PHASE 5: KLAUSUREN VALIDATION WITH MOE ROUTING")
    print(f" Model: {model_name} | Device: {device}")
    print("=" * 70)

    # ── 1. Load model ────────────────────────────────────────────────────────
    print("\n[1/6] Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    model.eval()

    config = AutoConfig.from_pretrained(model_name)
    # 512 is a fallback for configs without a hidden_size attribute.
    in_features: int = getattr(config, "hidden_size", 512)

    # ── 2. Load/aggregate adapter library ────────────────────────────────────
    print("\n[2/6] Aggregating adapter libraries...")
    agg = AdapterLibraryAggregator(
        max_adapters_per_domain=10, min_explained_variance=0.85
    )
    merge_stats = agg.load_and_merge(library_paths)
    agg.rebalance()
    adapter_library = agg.merged

    if not adapter_library:
        # Cold-start: extract adapters from current model state
        print(" No pre-built library found. Extracting from model (cold-start)...")
        extractor = ParasiticQLoRAExtractor(
            QLoRAConfig(
                min_rank=8,
                max_rank=32,
                explained_variance_threshold=0.90,
                extraction_interval=1,
                interesting_point_detection=False,
            )
        )
        extractor.snapshot_base(model)
        cold_adapter = extractor.extract_adapters(
            model, 0, {"loss": 0.0, "optimizer": "baseline"}
        )
        aligner = ExpertManifoldAligner(model)
        aligner.tag_adapter(cold_adapter)
        adapter_library = [cold_adapter]

    # NOTE(review): coverage and the library saved below come from ``agg``.
    # A cold-start adapter is only held in ``adapter_library``, so it is
    # presumably missing from both.
    coverage = agg.coverage_report()
    print(
        f" Library: {coverage['total_adapters']} adapters | "
        f"coverage={coverage['domain_coverage']}"
    )

    # ── 3. Load/train MoE gate ────────────────────────────────────────────────
    print("\n[3/6] Initialising MoE gate...")
    gate = load_or_train_gate(
        adapter_library,
        in_features,
        gate_weights_path,
        gate_steps=gate_steps,
        device=device,
    )

    # ── 4. Activate MoE routing ───────────────────────────────────────────────
    print("\n[4/6] Activating MoE routing via forward hooks...")
    router = ExpertAdapterRouter(model, adapter_library, in_features=in_features)
    # Inject the prepared gate weights. strict=False skips absent keys
    # silently, and that would route with an untrained gate, so warn about it.
    load_result = router.gate.load_state_dict(gate.state_dict(), strict=False)
    missing_keys = getattr(load_result, "missing_keys", None)
    if missing_keys:
        print(f" [Gate] WARNING: router gate did not receive {missing_keys}")
    router.register_hooks()
    print(f" Hooks registered on {len(router.hooks)} layers")

    # ── 5. Baseline (routing OFF) vs MoE (routing ON) ─────────────────────────
    print("\n[5/6] Running validation (baseline vs MoE)...")
    # The baseline must run on the unmodified model, so remove the hooks first.
    router.unregister_hooks()
    baseline_results = _evaluate_suite(
        model, tokenizer, device, max_new_tokens, "[BASELINE]"
    )
    baseline_domain = aggregate_domain_scores(baseline_results)

    router.register_hooks()
    try:
        moe_results = _evaluate_suite(model, tokenizer, device, max_new_tokens, "[MoE]")
    finally:
        # Always detach the hooks, even if generation fails, so the model is
        # not left modified.
        router.unregister_hooks()
    moe_domain = aggregate_domain_scores(moe_results)

    # ── 6. Report & telemetry ─────────────────────────────────────────────────
    print("\n[6/6] Compiling results...")
    baseline_avg = sum(r.score for r in baseline_results) / len(baseline_results)
    moe_avg = sum(r.score for r in moe_results) / len(moe_results)
    delta = moe_avg - baseline_avg

    _print_comparison(baseline_domain, moe_domain, baseline_avg, moe_avg)

    # Save merged library
    merged_lib_path = os.path.join(output_dir, "phase5_merged_library.pt")
    agg.save(merged_lib_path)

    report: Dict[str, Any] = {
        "model": model_name,
        "device": device,
        "adapter_count": len(adapter_library),
        "merge_stats": merge_stats,
        "library_coverage": coverage,
        "baseline": {
            "per_case": _per_case_summary(baseline_results),
            "per_domain": baseline_domain,
            "avg": baseline_avg,
        },
        "moe": {
            "per_case": _per_case_summary(moe_results),
            "per_domain": moe_domain,
            "avg": moe_avg,
        },
        "delta_avg": delta,
        "merged_library_path": merged_lib_path,
    }

    report_path = os.path.join(output_dir, "phase5_results.json")
    with open(report_path, "w", encoding="utf-8") as fh:
        json.dump(report, fh, indent=2, ensure_ascii=False)
    print(f"\n Report saved: {report_path}")

    # Telemetry is optional: skipped when send_telemetry is not importable,
    # and push failures are reported instead of aborting the finished run.
    if _TELEMETRY_AVAILABLE:
        _push_telemetry(baseline_avg, moe_avg, delta, len(adapter_library), moe_domain)

    return report
|
|
|
|
|
|
# ==============================================================================
|
|
# CLI
|
|
# ==============================================================================
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point for the Phase 5 validation run.

    Parses CLI options. When ``--library`` is not given, it auto-discovers
    ``parasitic_adapters_*.pt`` files in the current directory. It then runs
    run_phase5 and prints the elapsed time and the overall MoE delta.
    """
    parser = argparse.ArgumentParser(
        description="Phase 5: Klausuren Validation with MoE"
    )
    parser.add_argument("--model", type=str, default="EleutherAI/pythia-70m")
    parser.add_argument(
        "--library",
        nargs="*",
        default=[],
        help="Paths to .pt adapter library files (can specify multiple)",
    )
    parser.add_argument(
        "--gate-weights",
        type=str,
        default=None,
        help="Path to pre-trained gate weights .pt (from kaggle_prep step 3)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="phase5_output",
        help="Directory for output files",
    )
    parser.add_argument("--max-new-tokens", type=int, default=60)
    parser.add_argument("--gate-steps", type=int, default=100)
    args = parser.parse_args()

    # Auto-discover adapter libraries in cwd if none specified. The names are
    # sorted because os.listdir order is arbitrary; sorting makes the merge
    # order, and hence the aggregated library, reproducible.
    library_paths: List[str] = list(args.library)
    if not library_paths:
        library_paths = sorted(
            fname
            for fname in os.listdir(".")
            if fname.startswith("parasitic_adapters_")
            and fname.endswith(".pt")
            and os.path.isfile(fname)
        )
        if library_paths:
            print(f" Auto-discovered {len(library_paths)} libraries: {library_paths}")

    t0 = time.perf_counter()
    report = run_phase5(
        model_name=args.model,
        library_paths=library_paths,
        gate_weights_path=args.gate_weights,
        output_dir=args.output_dir,
        max_new_tokens=args.max_new_tokens,
        gate_steps=args.gate_steps,
    )
    elapsed = time.perf_counter() - t0

    print(f"\n Phase 5 complete in {elapsed:.1f}s")
    print(f" MoE delta vs baseline: {report['delta_avg']:+.4f}")


if __name__ == "__main__":
    main()
|