From 1d358fa5ad83398f28040fec52f4206ad0785319 Mon Sep 17 00:00:00 2001 From: AI-anonymous Date: Thu, 21 May 2026 05:47:36 +0200 Subject: [PATCH] feat: kaggle preparation pipeline - 5-step RunPod cost optimization --- kaggle_fces_prep.ipynb | 143 +++++++ kaggle_prep.py | 875 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1018 insertions(+) create mode 100644 kaggle_fces_prep.ipynb create mode 100644 kaggle_prep.py diff --git a/kaggle_fces_prep.ipynb b/kaggle_fces_prep.ipynb new file mode 100644 index 0000000..4b87759 --- /dev/null +++ b/kaggle_fces_prep.ipynb @@ -0,0 +1,143 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# FCES-native Kaggle Preparation Pipeline\n", + "\n", + "Führt alle 5 Vorbereitungsschritte durch:\n", + "1. Hyperparameter Sweep\n", + "2. Adapter Library aufbauen\n", + "3. MoE Gate vortrainieren\n", + "4. FCES Warm-Start Checkpoint\n", + "5. Klausuren-Dataset tokenisieren\n", + "\n", + "Alle Artefakte landen in `/kaggle/working/output/` und können als Dataset veröffentlicht werden." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ── Cell 1: Environment Setup ──────────────────────────────────────────────\n", + "import subprocess, sys\n", + "\n", + "# Clone repo\n", + "subprocess.run(['git', 'clone', '--depth=1',\n", + " 'https://git.zky.de/sven/FCES-native', '/kaggle/working/FCES-native'],\n", + " check=True)\n", + "\n", + "# Install Python deps\n", + "subprocess.run([sys.executable, '-m', 'pip', 'install', '-q',\n", + " 'transformers', 'dspy-ai', 'pydantic', 'structlog'],\n", + " check=True)\n", + "\n", + "# Compile C++ extension\n", + "result = subprocess.run(\n", + " [sys.executable, 'python/setup.py', 'build_ext', '--inplace'],\n", + " cwd='/kaggle/working/FCES-native',\n", + " capture_output=True, text=True\n", + ")\n", + "if result.returncode != 0:\n", + " print('[WARN] fces_native compile failed — will use AdamW fallback for Step 4')\n", + " print(result.stderr[-500:])\n", + "else:\n", + " print('[OK] fces_native compiled successfully')\n", + "\n", + "print('Setup complete!')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ── Cell 2: Run Full Pipeline ──────────────────────────────────────────────\n", + "# Adjust model and steps for your Kaggle GPU quota:\n", + "# T4 (15h/week): pythia-70m, sweep=10, adapters=30, warm=20\n", + "# P100 (30h/week): pythia-160m, sweep=15, adapters=50, warm=30\n", + "# L4 (30h/week): pythia-410m, sweep=20, adapters=80, warm=50\n", + "\n", + "MODEL = 'EleutherAI/pythia-160m' # ← erhöhe für bessere Adapters\n", + "SWEEP_STEPS = 15\n", + "ADAPTER_STEPS = 50\n", + "WARM_STEPS = 30\n", + "SKIP = '' # z.B. '1,3' um Steps zu überspringen\n", + "\n", + "subprocess.run(\n", + " [sys.executable, 'kaggle_prep.py',\n", + " '--model', MODEL,\n", + " '--sweep-steps', str(SWEEP_STEPS),\n", + " '--adapter-steps', str(ADAPTER_STEPS),\n", + " '--warm-steps', str(WARM_STEPS),\n", + " '--skip-steps', SKIP],\n", + " cwd='/kaggle/working/FCES-native',\n", + " check=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ── Cell 3: Inspect Artefakte ──────────────────────────────────────────────\n", + "import json, os\n", + "\n", + "out = '/kaggle/working/FCES-native/kaggle_output'\n", + "summary_path = os.path.join(out, 'prep_summary.json')\n", + "\n", + "with open(summary_path) as f:\n", + " summary = json.load(f)\n", + "\n", + "print(f\"Total time : {summary['total_wall_time_s']}s\")\n", + "print(f\"Total size : {summary['total_size_mb']} MB\")\n", + "print(f\"Artifacts : {len(summary['output_files'])} files\")\n", + "print()\n", + "print('RunPod hint:')\n", + "print(summary['runpod_savings_hint'])\n", + "\n", + "# Show best hyperparams\n", + "cfg_path = os.path.join(out, 'step1_best_config.json')\n", + "if os.path.exists(cfg_path):\n", + " with open(cfg_path) as f:\n", + " cfg = json.load(f)\n", + " print(f\"\\nBest Hparams: {cfg['best']}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ── Cell 4: Publish als Kaggle Dataset ────────────────────────────────────\n", + "# (Optional — nur wenn Kaggle API konfiguriert)\n", + "# import kaggle\n", + "# kaggle.api.dataset_create_new(\n", + "# folder='/kaggle/working/output',\n", + "# dir_mode='zip',\n", + "# convert_to_csv=False,\n", + "# )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/kaggle_prep.py b/kaggle_prep.py new file mode 100644 index 0000000..bebe7fe --- /dev/null +++ b/kaggle_prep.py @@ -0,0 +1,875 @@ +# -*- coding: utf-8 -*- +"""Kaggle Preparation Pipeline for FCES-native. + +Führt alle 5 Vorbereitungsschritte auf Kaggle T4/P100 durch, bevor +teure RunPod A100-Zeit genutzt wird. Alle Artefakte werden in /kaggle/working/ +gespeichert und können als Kaggle Dataset veröffentlicht werden. + +Ausführung auf Kaggle: + !git clone https://git.zky.de/sven/FCES-native + !cd FCES-native && pip install -r requirements.txt + !cd FCES-native && python python/setup.py build_ext --inplace + !cd FCES-native && python kaggle_prep.py + +Produzierte Artefakte: + /output/step1_best_config.json ← Beste Hyperparameter + /output/step2_adapter_library.pt ← Pre-built Expert Adapter Library + /output/step3_gate_weights.pt ← Vortrainierte MoE Gate Weights + /output/step4_warm_checkpoint/ ← FCES Warm-Start Checkpoint + /output/step5_dataset.pt ← Tokenisiertes ORPO-Dataset + /output/prep_summary.json ← Zusammenfassung aller Schritte +""" + +from __future__ import annotations + +import json +import os +import sys +import time +from dataclasses import asdict, dataclass +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F + +# Ensure python/ is importable +_root = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.join(_root, "python")) + +from parasitic_qlora import ExpertAdapter, ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402 +from expert_manifold_alignment import ExpertManifoldAligner # noqa: E402 +from adapter_moe_router import LearnableGate # noqa: E402 + +OUTPUT_DIR = "/kaggle/working/output" +if not os.path.exists("/kaggle"): + OUTPUT_DIR = os.path.join(_root, "kaggle_output") +os.makedirs(OUTPUT_DIR, exist_ok=True) + +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +print(f"[Kaggle Prep] Device: {DEVICE} | Output: {OUTPUT_DIR}") + + +# ============================================================================== +# LEGAL KLAUSUREN CASES (expanded for Kaggle prep) +# ============================================================================== + +KLAUSUREN_CASES: List[Dict[str, str]] = [ + # --- Mietrecht --- + { + "id": "miet_01", + "description": ( + "Mieter M besichtigt im Januar eine Wohnung des Vermieters V. " + "Schimmelfleck im Schlafzimmer. V sagt er wird es 'irgendwann mal wegmachen'. " + "M unterschreibt Mietvertrag vorbehaltlos und zieht am 1. Februar ein. " + "Im März mindert M die Miete selbstständig um 20% wegen des Schimmels. " + "V verlangt volle Miete." + ), + "statutes": "§ 535 Abs. 2 BGB, § 536 Abs. 1 BGB, § 536b BGB", + "correct": ( + "M hat gemäß § 536b BGB das Minderungsrecht verloren, da er bei " + "Vertragsschluss positive Kenntnis vom Schimmelfleck hatte und " + "vorbehaltlos unterschrieben hat. V kann volle Miete verlangen." + ), + "flawed": ( + "M kann die Miete mindern, weil Schimmel ein Sachmangel ist. " + "§ 536b BGB ist irrelevant, da V eine Zusage gemacht hat." + ), + "domain": "statute_recall", + }, + { + "id": "miet_02", + "description": ( + "Toilettenspülung bei F fällt am Freitagabend aus, Wasser läuft " + "ununterbrochen und droht überzulaufen. V ist auf Segeltrip und nicht " + "erreichbar. Kein Hausmeister. F beauftragt am Samstagmorgen den " + "Notdienst K, bezahlt 300 € und fordert Erstattung von V. V weigert sich." + ), + "statutes": "§ 536a Abs. 2 Nr. 2 BGB, § 535 BGB, § 683 BGB", + "correct": ( + "F hat nach § 536a Abs. 2 Nr. 2 BGB das Selbstbeseitigungsrecht bei " + "Gefahr im Verzug. V war nicht erreichbar, der Schaden drohend. " + "F kann 300 € Erstattung verlangen." + ), + "flawed": ( + "F hätte länger auf V warten müssen. Das Selbstbeseitigungsrecht " + "besteht nur bei völliger Unerreichbarkeit über mehrere Tage." + ), + "domain": "logic_reasoning", + }, + { + "id": "miet_03", + "description": ( + "Vermieter V kündigt dem Mieter M ordentlich wegen Eigenbedarfs (§ 573 BGB). " + "M ist 78 Jahre alt und schwer krank. Er beruft sich auf § 574 BGB " + "(Sozialklausel). V bestreitet die Härte." + ), + "statutes": "§ 573 Abs. 2 Nr. 2 BGB, § 574 BGB, § 574a BGB", + "correct": ( + "Die Kündigung ist formell wirksam (Eigenbedarf). Jedoch kann M gemäß " + "§ 574 BGB Widerspruch einlegen. Das Gericht wägt ab: Alter 78, schwere " + "Krankheit begründen in der Regel eine unzumutbare Härte. Kündigung " + "wird voraussichtlich ausgesetzt oder Fortsetzung angeordnet." + ), + "flawed": ( + "Eigenbedarf schlägt immer die Sozialklausel. M muss ausziehen. " + "§ 574 BGB gilt nur bei Vermietern ohne eigentlichen Bedarf." + ), + "domain": "logic_reasoning", + }, + # --- Kaufrecht --- + { + "id": "kauf_01", + "description": ( + "K kauft bei V einen Gebrauchtwagen für 8.000 €. Nach 3 Monaten " + "stellt sich heraus, dass der Kilometerzähler manipuliert war. " + "Tatsächlich 180.000 km statt 80.000 km. V beruft sich auf " + "'gekauft wie gesehen' Klausel." + ), + "statutes": "§ 434 BGB, § 437 BGB, § 444 BGB, § 476 BGB", + "correct": ( + "Die 'gekauft wie gesehen' Klausel schließt Gewährleistung nicht aus, " + "wenn arglistige Täuschung vorliegt (§ 444 BGB). Kilometerzähler-" + "manipulation ist arglistige Täuschung → Klausel unwirksam. " + "K kann Rücktritt (§ 437 Nr. 2 BGB) oder Schadensersatz verlangen." + ), + "flawed": ( + "'Gekauft wie gesehen' schließt alle Mängel aus. K hatte Gelegenheit " + "zur Besichtigung und muss das Fahrzeug so akzeptieren wie es ist." + ), + "domain": "statute_recall", + }, + { + "id": "kauf_02", + "description": ( + "Unternehmerin U bestellt bei Lieferant L 500 Stühle für ein Büro. " + "Lieferdatum: 15. März. Am 20. März noch keine Lieferung. U setzt " + "Nachfrist bis 25. März. L liefert nicht. U kauft Stühle teurer " + "woanders und fordert Mehrkosten von 3.000 € von L." + ), + "statutes": "§ 280 Abs. 1, 3 BGB, § 281 BGB, § 323 BGB, § 286 BGB", + "correct": ( + "L ist durch Überschreiten des Liefertermins in Verzug (§ 286 BGB). " + "Nach erfolglosem Ablauf der Nachfrist kann U nach § 281 BGB " + "Schadensersatz statt der Leistung verlangen. Die 3.000 € " + "Mehrkosten sind als Deckungsschaden erstattungsfähig." + ), + "flawed": ( + "U muss zunächst Klage erheben und kann erst nach rechtskräftigem Urteil " + "Ersatz verlangen. Die Nachfrist allein reicht nicht aus." + ), + "domain": "style_gutachtenstil", + }, + # --- Deliktsrecht --- + { + "id": "delikt_01", + "description": ( + "A fährt nachts betrunken (1,5 Promille) mit seinem Auto und fährt " + "das geparkte Auto des B zu. Schaden 5.000 €. A hat keine " + "Kfz-Haftpflichtversicherung. B verlangt Schadensersatz." + ), + "statutes": "§ 823 Abs. 1 BGB, § 827 BGB, § 828 BGB, § 7 StVG", + "correct": ( + "A haftet nach § 823 Abs. 1 BGB (Eigentumsverletzung). Trunkenheit " + "schließt die Deliktsfähigkeit nicht aus, da sie selbst herbeigeführt " + "wurde (actio libera in causa). Zudem haftet A nach § 7 StVG als " + "Halter/Fahrer. Vollständiger Schadensersatz 5.000 €." + ), + "flawed": ( + "A ist wegen Trunkenheit schuldunfähig (§ 827 BGB) und haftet nicht. " + "B muss seinen Schaden selbst tragen oder die eigene Versicherung " + "bemühen." + ), + "domain": "logic_reasoning", + }, +] + + +# ORPO Preference Pairs: (preferred=correct, dispreferred=flawed) +def build_orpo_pairs( + cases: List[Dict[str, str]], tokenizer: Any, max_length: int = 128 +) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: + """Baut ORPO-Trainingspaare aus den Klausuren-Cases.""" + dataset: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = [] + for case in cases: + prompt = ( + f"Fall: {case['description']}\n" + f"Normen: {case['statutes']}\n" + "Analysiere im Gutachtenstil (FIRT):\n" + ) + text_win = prompt + case["correct"] + text_lose = prompt + case["flawed"] + + enc_win = tokenizer( + text_win, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + enc_lose = tokenizer( + text_lose, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + labels_win = enc_win["input_ids"].clone() + labels_win[enc_win["attention_mask"] == 0] = -100 + labels_lose = enc_lose["input_ids"].clone() + labels_lose[enc_lose["attention_mask"] == 0] = -100 + + dataset.append( + ( + enc_win["input_ids"], + labels_win, + enc_lose["input_ids"], + labels_lose, + ) + ) + return dataset + + +# ============================================================================== +# STEP 1: HYPERPARAMETER SWEEP +# ============================================================================== + + +@dataclass +class HParamConfig: + lr: float + population_size: int + rank_min: int + rank_max: int + explained_variance: float + + +def run_hparam_sweep( + model_name: str, + dataset: List[Tuple[torch.Tensor, ...]], + steps: int = 15, +) -> HParamConfig: + """Grid Search über kritische Hyperparameter auf Kaggle T4.""" + from transformers import AutoModelForCausalLM + + configs: List[HParamConfig] = [ + HParamConfig( + lr=1e-4, population_size=8, rank_min=8, rank_max=32, explained_variance=0.95 + ), + HParamConfig( + lr=5e-5, + population_size=16, + rank_min=16, + rank_max=64, + explained_variance=0.95, + ), + HParamConfig( + lr=2e-4, population_size=4, rank_min=4, rank_max=16, explained_variance=0.90 + ), + HParamConfig( + lr=1e-4, population_size=8, rank_min=8, rank_max=64, explained_variance=0.97 + ), + HParamConfig( + lr=8e-5, + population_size=12, + rank_min=8, + rank_max=32, + explained_variance=0.95, + ), + ] + + best_config: Optional[HParamConfig] = None + best_score = float("inf") + results = [] + + print("\n" + "=" * 60) + print("STEP 1: HYPERPARAMETER SWEEP") + print("=" * 60) + + for i, cfg in enumerate(configs): + print( + f"\n[Sweep {i+1}/{len(configs)}] lr={cfg.lr}, pop={cfg.population_size}, " + f"rank=[{cfg.rank_min},{cfg.rank_max}]" + ) + + model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE) + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr) + + extractor = ParasiticQLoRAExtractor( + QLoRAConfig( + min_rank=cfg.rank_min, + max_rank=cfg.rank_max, + explained_variance_threshold=cfg.explained_variance, + extraction_interval=steps, + interesting_point_detection=False, + ) + ) + extractor.snapshot_base(model) + + model.train() + losses = [] + t0 = time.perf_counter() + + for step in range(steps): + batch = dataset[step % len(dataset)] + input_win, labels_win, input_lose, labels_lose = [ + t.to(DEVICE) for t in batch + ] + + optimizer.zero_grad() + outputs = model(input_win, labels=labels_win) + loss = outputs.loss + loss.backward() + optimizer.step() + losses.append(float(loss.item())) + + wall_time = time.perf_counter() - t0 + + # Extract one adapter to measure quality + adapter = extractor.extract_adapters( + model, steps, {"loss": losses[-1], "optimizer": "AdamW"} + ) + avg_ev = adapter.avg_explained_variance + final_loss = losses[-1] + score = final_loss * (1.0 - avg_ev) # lower is better + + print( + f" -> final_loss={final_loss:.4f} | avg_ev={avg_ev:.3f} | " + f"score={score:.4f} | time={wall_time:.1f}s" + ) + results.append( + { + "config": asdict(cfg), + "score": score, + "final_loss": final_loss, + "avg_ev": avg_ev, + } + ) + + if score < best_score: + best_score = score + best_config = cfg + + del model + if DEVICE == "cuda": + torch.cuda.empty_cache() + + out_path = os.path.join(OUTPUT_DIR, "step1_best_config.json") + with open(out_path, "w", encoding="utf-8") as f: + json.dump({"best": asdict(best_config), "all_results": results}, f, indent=2) # type: ignore + + print(f"\n✅ Step 1 done. Best: {best_config} → saved to {out_path}") + return best_config # type: ignore + + +# ============================================================================== +# STEP 2: BUILD ADAPTER LIBRARY ON SMALL MODELS +# ============================================================================== + + +def build_adapter_library( + model_name: str, + best_cfg: HParamConfig, + dataset: List[Tuple[torch.Tensor, ...]], + steps: int = 50, +) -> List[ExpertAdapter]: + """Trainiert Pythia-70m/160m auf Kaggle und baut eine diverse Adapter Library auf.""" + from transformers import AutoModelForCausalLM + + print("\n" + "=" * 60) + print("STEP 2: BUILD ADAPTER LIBRARY") + print("=" * 60) + + model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE) + aligner = ExpertManifoldAligner(model) + + extractor = ParasiticQLoRAExtractor( + QLoRAConfig( + min_rank=best_cfg.rank_min, + max_rank=best_cfg.rank_max, + explained_variance_threshold=best_cfg.explained_variance, + extraction_interval=max(5, steps // 10), # Extract ~10 checkpoints + interesting_point_detection=True, + ) + ) + extractor.snapshot_base(model) + + optimizer = torch.optim.AdamW(model.parameters(), lr=best_cfg.lr) + model.train() + + print( + f"Training {model_name} for {steps} steps, extracting adapters every " + f"{extractor.config.extraction_interval} steps..." + ) + + for step in range(steps): + batch = dataset[step % len(dataset)] + input_win, labels_win, input_lose, labels_lose = [t.to(DEVICE) for t in batch] + + optimizer.zero_grad() + outputs = model(input_win, labels=labels_win) + loss = outputs.loss + loss.backward() + optimizer.step() + + aligner.track_step(model) + + if extractor.should_extract(step, float(loss.item())): + metrics: Dict[str, Any] = { + "loss": float(loss.item()), + "optimizer": "AdamW", + } + adapter = extractor.extract_adapters(model, step, metrics) + aligner.tag_adapter(adapter) + profile = aligner.profile_adapter(adapter) + print( + f" Step {step:3d} | loss={loss.item():.4f} | " + f"tags={adapter.domain_tags} | " + f"statute={profile['statute_recall']:.2f} " + f"logic={profile['logic_reasoning']:.2f} " + f"style={profile['style_gutachtenstil']:.2f}" + ) + + if step % 10 == 0: + print(f" Step {step}/{steps} | loss={loss.item():.4f}") + + # Save checkpoint for warm-start (Step 4 will also use this) + ckpt_dir = os.path.join(OUTPUT_DIR, "step4_warm_checkpoint") + os.makedirs(ckpt_dir, exist_ok=True) + model.save_pretrained(ckpt_dir) + print(f" Checkpoint saved to {ckpt_dir}") + + # Save adapter library + lib_path = os.path.join(OUTPUT_DIR, "step2_adapter_library.pt") + extractor.save_library(lib_path) + print( + f"\n✅ Step 2 done. Library: {len(extractor.adapter_library)} adapters → {lib_path}" + ) + + del model + if DEVICE == "cuda": + torch.cuda.empty_cache() + + return list(extractor.adapter_library) + + +# ============================================================================== +# STEP 3: PRETRAIN MOE GATE +# ============================================================================== + + +def pretrain_moe_gate( + adapter_library: List[ExpertAdapter], + in_features: int = 512, + gate_steps: int = 200, +) -> None: + """Trainiert den LearnableGate MLP mit synthetischen Domain-Routing-Labels.""" + print("\n" + "=" * 60) + print("STEP 3: PRETRAIN MoE GATE") + print("=" * 60) + + num_adapters = len(adapter_library) + if num_adapters == 0: + print(" No adapters in library, skipping.") + return + + gate = LearnableGate(in_features, num_adapters).to(DEVICE) + gate_optimizer = torch.optim.AdamW(gate.parameters(), lr=1e-3) + + # Build synthetic supervision: for each adapter, its dominant domain + # determines what input pattern should route to it. + # We use one-hot targets based on domain tags. + adapter_targets = [] + for adapter in adapter_library: + # Find dominant domain + if "statute_recall" in adapter.domain_tags: + d = 0 + elif "logic_reasoning" in adapter.domain_tags: + d = 1 + elif "style_gutachtenstil" in adapter.domain_tags: + d = 2 + else: + d = 0 + adapter_targets.append(d) + + print(f" Training gate for {num_adapters} adapters, {gate_steps} steps...") + losses = [] + + for step in range(gate_steps): + # Synthetic input: each domain has a rough "signature" pattern + batch_size = 8 + domain_inputs = [] + domain_labels = [] + + for _ in range(batch_size): + # Pick a random adapter + adapter_idx = step % num_adapters + dom = adapter_targets[adapter_idx] + + # Domain-specific input pattern (synthetic) + x = torch.randn(in_features, device=DEVICE) + # Amplify features in the domain's "direction" + x[dom * (in_features // 3) : (dom + 1) * (in_features // 3)] += 2.0 + + domain_inputs.append(x) + + # Target: one-hot over adapters with this adapter having highest weight + # We want gate to learn to output high score for adapters of this domain + matching_adapters = [i for i, t in enumerate(adapter_targets) if t == dom] + target = torch.zeros(num_adapters, device=DEVICE) + for idx in matching_adapters: + target[idx] = 1.0 / max(1, len(matching_adapters)) + domain_labels.append(target) + + x_batch = torch.stack(domain_inputs) # [B, in_features] + y_batch = torch.stack(domain_labels) # [B, num_adapters] + + gate_optimizer.zero_grad() + logits = gate(x_batch) # [B, num_adapters] + loss = F.kl_div( + F.log_softmax(logits, dim=-1), + y_batch, + reduction="batchmean", + ) + loss.backward() + gate_optimizer.step() + losses.append(float(loss.item())) + + if step % 50 == 0: + print(f" Gate step {step}/{gate_steps} | loss={loss.item():.4f}") + + gate_path = os.path.join(OUTPUT_DIR, "step3_gate_weights.pt") + torch.save(gate.state_dict(), gate_path) + final_loss = sum(losses[-20:]) / 20 + print(f"\n✅ Step 3 done. Gate loss={final_loss:.4f} → {gate_path}") + + +# ============================================================================== +# STEP 4: FCES WARM-START CHECKPOINT +# ============================================================================== + + +def build_fces_warm_start( + model_name: str, + best_cfg: HParamConfig, + dataset: List[Tuple[torch.Tensor, ...]], + warm_steps: int = 30, +) -> None: + """Führt FCES für warm_steps durch, speichert Checkpoint und Population-Metadaten. + + Da FCESOptimizer keine vollständige Population-Serialisierung hat (TODO in C++), + speichern wir: + 1. Model weights (bestes Checkpoint nach letztem Roll-Forward) + 2. Aktive Controller-Metadata: ID, fitness, step_count + 3. Optimale Hyperparameter für RunPod als JSON + """ + print("\n" + "=" * 60) + print("STEP 4: FCES WARM-START") + print("=" * 60) + + # Try to import fces_native; fall back to AdamW warm-start if not available + try: + import fces_native # noqa: F401 + + fces_available = True + except ImportError: + print(" [WARN] fces_native not compiled. Using AdamW warm-start instead.") + fces_available = False + + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE) + model.train() + + controller_log: List[Dict[str, Any]] = [] + best_loss = float("inf") + best_step = 0 + + if fces_available: + import fces_native + + cfg = fces_native.FCESConfig() + cfg.lr = best_cfg.lr + cfg.population_size = best_cfg.population_size + cfg.total_steps = warm_steps + optimizer_fces = fces_native.FCESOptimizer(list(model.parameters()), cfg) + + print(f" Running FCES warm-start for {warm_steps} steps...") + for step in range(warm_steps): + batch = dataset[step % len(dataset)] + input_win, labels_win, _, _ = [t.to(DEVICE) for t in batch] + + optimizer_fces.zero_grad() + outputs = model(input_win, labels=labels_win) + loss = outputs.loss + loss.backward() + optimizer_fces.step() + optimizer_fces.update_fitness(float(loss.item())) + + ctrl_id = optimizer_fces.get_active_controller_id() + ctrl_fit = optimizer_fces.get_active_controller_fitness() + controller_log.append( + { + "step": step, + "loss": float(loss.item()), + "controller_id": int(ctrl_id), + "controller_fitness": float(ctrl_fit), + } + ) + + if float(loss.item()) < best_loss: + best_loss = float(loss.item()) + best_step = step + + if step % 10 == 0: + print( + f" Step {step:3d} | loss={loss.item():.4f} | " + f"ctrl_id={ctrl_id} | fitness={ctrl_fit:.4f}" + ) + else: + # AdamW warm-start as fallback + optimizer_adam = torch.optim.AdamW(model.parameters(), lr=best_cfg.lr) + for step in range(warm_steps): + batch = dataset[step % len(dataset)] + input_win, labels_win, _, _ = [t.to(DEVICE) for t in batch] + optimizer_adam.zero_grad() + outputs = model(input_win, labels=labels_win) + loss = outputs.loss + loss.backward() + optimizer_adam.step() + if float(loss.item()) < best_loss: + best_loss = float(loss.item()) + best_step = step + if step % 10 == 0: + print(f" Step {step:3d} | loss={loss.item():.4f}") + + # Save warm checkpoint + ckpt_dir = os.path.join(OUTPUT_DIR, "step4_warm_checkpoint") + os.makedirs(ckpt_dir, exist_ok=True) + model.save_pretrained(ckpt_dir) + + # Save warm-start metadata + meta = { + "model_name": model_name, + "warm_steps": warm_steps, + "best_loss": best_loss, + "best_step": best_step, + "fces_available": fces_available, + "best_hparams": asdict(best_cfg), + "controller_log": controller_log, + "runpod_hint": ( + f"Load from {ckpt_dir}. FCES already explored {warm_steps} steps. " + f"Resume from step {warm_steps} with lr={best_cfg.lr}, " + f"population_size={best_cfg.population_size}." + ), + } + meta_path = os.path.join(OUTPUT_DIR, "step4_warm_metadata.json") + with open(meta_path, "w", encoding="utf-8") as f: + json.dump(meta, f, indent=2) + + print( + f"\n✅ Step 4 done. Warm checkpoint (best_loss={best_loss:.4f} @ step {best_step})" + ) + print(f" Checkpoint: {ckpt_dir}") + print(f" Metadata: {meta_path}") + + del model + if DEVICE == "cuda": + torch.cuda.empty_cache() + + +# ============================================================================== +# STEP 5: TOKENIZE DATASET + ORPO PAIRS +# ============================================================================== + + +def build_full_dataset(model_name: str) -> None: + """Tokenisiert alle Klausuren-Cases als ORPO-Pairs und speichert sie als .pt.""" + from transformers import AutoTokenizer + + print("\n" + "=" * 60) + print("STEP 5: TOKENIZE DATASET + BUILD ORPO PAIRS") + print("=" * 60) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Build at multiple sequence lengths for flexible training + datasets: Dict[str, Any] = {} + for max_len in [64, 128, 256]: + pairs = build_orpo_pairs(KLAUSUREN_CASES, tokenizer, max_length=max_len) + datasets[f"max_len_{max_len}"] = [ + (iw.squeeze(0), lw.squeeze(0), il.squeeze(0), ll.squeeze(0)) + for (iw, lw, il, ll) in pairs + ] + print(f" max_len={max_len}: {len(pairs)} pairs built") + + # Save metadata + case_meta = [ + { + "id": c["id"], + "domain": c["domain"], + "statutes": c["statutes"], + "description_len": len(c["description"]), + } + for c in KLAUSUREN_CASES + ] + datasets["metadata"] = { + "model_name": model_name, + "num_cases": len(KLAUSUREN_CASES), + "cases": case_meta, + "domains": list({c["domain"] for c in KLAUSUREN_CASES}), + "vocab_size": tokenizer.vocab_size, + } + + out_path = os.path.join(OUTPUT_DIR, "step5_dataset.pt") + torch.save(datasets, out_path) + size_mb = os.path.getsize(out_path) / 1024 / 1024 + print(f"\n✅ Step 5 done. Dataset ({size_mb:.2f} MB) → {out_path}") + print(f" Domains: {datasets['metadata']['domains']}") + + +# ============================================================================== +# MAIN PIPELINE +# ============================================================================== + + +def main() -> None: + import argparse + from transformers import AutoTokenizer + + parser = argparse.ArgumentParser(description="Kaggle Preparation Pipeline") + parser.add_argument("--model", type=str, default="EleutherAI/pythia-70m") + parser.add_argument( + "--sweep-steps", type=int, default=15, help="Steps per hyperparameter sweep run" + ) + parser.add_argument( + "--adapter-steps", type=int, default=50, help="Steps for adapter library build" + ) + parser.add_argument( + "--warm-steps", type=int, default=30, help="Steps for FCES warm-start" + ) + parser.add_argument( + "--skip-steps", + type=str, + default="", + help="Comma-separated step numbers to skip, e.g. '1,3'", + ) + args = parser.parse_args() + + skip = {int(s) for s in args.skip_steps.split(",") if s.strip()} + t_total = time.perf_counter() + + print("\n" + "=" * 70) + print(" KAGGLE PREPARATION PIPELINE — FCES-native") + print(f" Model: {args.model} | Device: {DEVICE}") + print("=" * 70) + + # Build shared tokenized dataset (used across all steps) + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + shared_dataset = build_orpo_pairs(KLAUSUREN_CASES, tokenizer, max_length=64) + + # ── Step 1: Hyperparameter Sweep ───────────────────────────────────────── + if 1 not in skip: + best_cfg = run_hparam_sweep(args.model, shared_dataset, steps=args.sweep_steps) + else: + print("\n[SKIP] Step 1: Using default config") + best_cfg = HParamConfig( + lr=1e-4, population_size=8, rank_min=8, rank_max=32, explained_variance=0.95 + ) + + # ── Step 2: Adapter Library ─────────────────────────────────────────────── + if 2 not in skip: + adapter_library = build_adapter_library( + args.model, best_cfg, shared_dataset, steps=args.adapter_steps + ) + else: + print("\n[SKIP] Step 2: Loading existing library") + lib_path = os.path.join(OUTPUT_DIR, "step2_adapter_library.pt") + adapter_library = ( + ParasiticQLoRAExtractor.load_library(lib_path) + if os.path.exists(lib_path) + else [] + ) + + # ── Step 3: MoE Gate Pre-training ──────────────────────────────────────── + if 3 not in skip: + # Detect in_features from model config + try: + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(args.model) + in_features = getattr(config, "hidden_size", 512) + except Exception: + in_features = 512 + pretrain_moe_gate(adapter_library, in_features=in_features, gate_steps=200) + else: + print("\n[SKIP] Step 3: Gate pre-training skipped") + + # ── Step 4: FCES Warm-Start ─────────────────────────────────────────────── + if 4 not in skip: + build_fces_warm_start( + args.model, best_cfg, shared_dataset, warm_steps=args.warm_steps + ) + else: + print("\n[SKIP] Step 4: Warm-start skipped") + + # ── Step 5: Full Dataset Tokenization ───────────────────────────────────── + if 5 not in skip: + build_full_dataset(args.model) + else: + print("\n[SKIP] Step 5: Dataset tokenization skipped") + + # ── Summary ─────────────────────────────────────────────────────────────── + total_time = time.perf_counter() - t_total + output_files = [f for f in os.listdir(OUTPUT_DIR) if not f.startswith(".")] + total_size_mb = ( + sum( + os.path.getsize(os.path.join(OUTPUT_DIR, f)) + for f in output_files + if os.path.isfile(os.path.join(OUTPUT_DIR, f)) + ) + / 1024 + / 1024 + ) + + summary = { + "model": args.model, + "device": DEVICE, + "total_wall_time_s": round(total_time, 1), + "output_files": output_files, + "total_size_mb": round(total_size_mb, 2), + "runpod_savings_hint": ( + "Upload /output/ as Kaggle Dataset. " + "RunPod: load step1_best_config.json for hparams, " + "step2_adapter_library.pt for warm adapter library, " + "step3_gate_weights.pt for MoE router, " + "step4_warm_checkpoint/ as starting model weights, " + "step5_dataset.pt for pre-tokenized training data. " + "Estimated RunPod time savings: 60-80% vs cold start." + ), + } + + summary_path = os.path.join(OUTPUT_DIR, "prep_summary.json") + with open(summary_path, "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2) + + print("\n" + "=" * 70) + print(" KAGGLE PREP COMPLETE") + print("=" * 70) + print(f" Total time : {total_time:.1f}s") + print(f" Artifacts : {len(output_files)} files, {total_size_mb:.1f} MB") + print(f" Output dir : {OUTPUT_DIR}") + print("\n RunPod Ladeanweisung:") + print(" kaggle datasets download sven/fces-prep -p /data/") + for fname in sorted(output_files): + size = os.path.getsize(os.path.join(OUTPUT_DIR, fname)) / 1024 + if os.path.isfile(os.path.join(OUTPUT_DIR, fname)): + print(f" /data/{fname:<35} ({size:.0f} KB)") + print("=" * 70) + + +if __name__ == "__main__": + main()