Files
FCES-native/kaggle_prep.py

876 lines
32 KiB
Python

# -*- coding: utf-8 -*-
"""Kaggle Preparation Pipeline for FCES-native.
Führt alle 5 Vorbereitungsschritte auf Kaggle T4/P100 durch, bevor
teure RunPod A100-Zeit genutzt wird. Alle Artefakte werden in /kaggle/working/
gespeichert und können als Kaggle Dataset veröffentlicht werden.
Ausführung auf Kaggle:
!git clone https://git.zky.de/sven/FCES-native
!cd FCES-native && pip install -r requirements.txt
!cd FCES-native && python python/setup.py build_ext --inplace
!cd FCES-native && python kaggle_prep.py
Produzierte Artefakte:
/output/step1_best_config.json ← Beste Hyperparameter
/output/step2_adapter_library.pt ← Pre-built Expert Adapter Library
/output/step3_gate_weights.pt ← Vortrainierte MoE Gate Weights
/output/step4_warm_checkpoint/ ← FCES Warm-Start Checkpoint
/output/step5_dataset.pt ← Tokenisiertes ORPO-Dataset
/output/prep_summary.json ← Zusammenfassung aller Schritte
"""
from __future__ import annotations
import json
import os
import sys
import time
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
# Ensure python/ is importable
_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(_root, "python"))
from parasitic_qlora import ExpertAdapter, ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402
from expert_manifold_alignment import ExpertManifoldAligner # noqa: E402
from adapter_moe_router import LearnableGate # noqa: E402
OUTPUT_DIR = "/kaggle/working/output"
if not os.path.exists("/kaggle"):
OUTPUT_DIR = os.path.join(_root, "kaggle_output")
os.makedirs(OUTPUT_DIR, exist_ok=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[Kaggle Prep] Device: {DEVICE} | Output: {OUTPUT_DIR}")
# ==============================================================================
# LEGAL KLAUSUREN CASES (expanded for Kaggle prep)
# ==============================================================================
KLAUSUREN_CASES: List[Dict[str, str]] = [
# --- Mietrecht ---
{
"id": "miet_01",
"description": (
"Mieter M besichtigt im Januar eine Wohnung des Vermieters V. "
"Schimmelfleck im Schlafzimmer. V sagt er wird es 'irgendwann mal wegmachen'. "
"M unterschreibt Mietvertrag vorbehaltlos und zieht am 1. Februar ein. "
"Im März mindert M die Miete selbstständig um 20% wegen des Schimmels. "
"V verlangt volle Miete."
),
"statutes": "§ 535 Abs. 2 BGB, § 536 Abs. 1 BGB, § 536b BGB",
"correct": (
"M hat gemäß § 536b BGB das Minderungsrecht verloren, da er bei "
"Vertragsschluss positive Kenntnis vom Schimmelfleck hatte und "
"vorbehaltlos unterschrieben hat. V kann volle Miete verlangen."
),
"flawed": (
"M kann die Miete mindern, weil Schimmel ein Sachmangel ist. "
"§ 536b BGB ist irrelevant, da V eine Zusage gemacht hat."
),
"domain": "statute_recall",
},
{
"id": "miet_02",
"description": (
"Toilettenspülung bei F fällt am Freitagabend aus, Wasser läuft "
"ununterbrochen und droht überzulaufen. V ist auf Segeltrip und nicht "
"erreichbar. Kein Hausmeister. F beauftragt am Samstagmorgen den "
"Notdienst K, bezahlt 300 € und fordert Erstattung von V. V weigert sich."
),
"statutes": "§ 536a Abs. 2 Nr. 2 BGB, § 535 BGB, § 683 BGB",
"correct": (
"F hat nach § 536a Abs. 2 Nr. 2 BGB das Selbstbeseitigungsrecht bei "
"Gefahr im Verzug. V war nicht erreichbar, der Schaden drohend. "
"F kann 300 € Erstattung verlangen."
),
"flawed": (
"F hätte länger auf V warten müssen. Das Selbstbeseitigungsrecht "
"besteht nur bei völliger Unerreichbarkeit über mehrere Tage."
),
"domain": "logic_reasoning",
},
{
"id": "miet_03",
"description": (
"Vermieter V kündigt dem Mieter M ordentlich wegen Eigenbedarfs (§ 573 BGB). "
"M ist 78 Jahre alt und schwer krank. Er beruft sich auf § 574 BGB "
"(Sozialklausel). V bestreitet die Härte."
),
"statutes": "§ 573 Abs. 2 Nr. 2 BGB, § 574 BGB, § 574a BGB",
"correct": (
"Die Kündigung ist formell wirksam (Eigenbedarf). Jedoch kann M gemäß "
"§ 574 BGB Widerspruch einlegen. Das Gericht wägt ab: Alter 78, schwere "
"Krankheit begründen in der Regel eine unzumutbare Härte. Kündigung "
"wird voraussichtlich ausgesetzt oder Fortsetzung angeordnet."
),
"flawed": (
"Eigenbedarf schlägt immer die Sozialklausel. M muss ausziehen. "
"§ 574 BGB gilt nur bei Vermietern ohne eigentlichen Bedarf."
),
"domain": "logic_reasoning",
},
# --- Kaufrecht ---
{
"id": "kauf_01",
"description": (
"K kauft bei V einen Gebrauchtwagen für 8.000 €. Nach 3 Monaten "
"stellt sich heraus, dass der Kilometerzähler manipuliert war. "
"Tatsächlich 180.000 km statt 80.000 km. V beruft sich auf "
"'gekauft wie gesehen' Klausel."
),
"statutes": "§ 434 BGB, § 437 BGB, § 444 BGB, § 476 BGB",
"correct": (
"Die 'gekauft wie gesehen' Klausel schließt Gewährleistung nicht aus, "
"wenn arglistige Täuschung vorliegt (§ 444 BGB). Kilometerzähler-"
"manipulation ist arglistige Täuschung → Klausel unwirksam. "
"K kann Rücktritt (§ 437 Nr. 2 BGB) oder Schadensersatz verlangen."
),
"flawed": (
"'Gekauft wie gesehen' schließt alle Mängel aus. K hatte Gelegenheit "
"zur Besichtigung und muss das Fahrzeug so akzeptieren wie es ist."
),
"domain": "statute_recall",
},
{
"id": "kauf_02",
"description": (
"Unternehmerin U bestellt bei Lieferant L 500 Stühle für ein Büro. "
"Lieferdatum: 15. März. Am 20. März noch keine Lieferung. U setzt "
"Nachfrist bis 25. März. L liefert nicht. U kauft Stühle teurer "
"woanders und fordert Mehrkosten von 3.000 € von L."
),
"statutes": "§ 280 Abs. 1, 3 BGB, § 281 BGB, § 323 BGB, § 286 BGB",
"correct": (
"L ist durch Überschreiten des Liefertermins in Verzug (§ 286 BGB). "
"Nach erfolglosem Ablauf der Nachfrist kann U nach § 281 BGB "
"Schadensersatz statt der Leistung verlangen. Die 3.000 € "
"Mehrkosten sind als Deckungsschaden erstattungsfähig."
),
"flawed": (
"U muss zunächst Klage erheben und kann erst nach rechtskräftigem Urteil "
"Ersatz verlangen. Die Nachfrist allein reicht nicht aus."
),
"domain": "style_gutachtenstil",
},
# --- Deliktsrecht ---
{
"id": "delikt_01",
"description": (
"A fährt nachts betrunken (1,5 Promille) mit seinem Auto und fährt "
"das geparkte Auto des B zu. Schaden 5.000 €. A hat keine "
"Kfz-Haftpflichtversicherung. B verlangt Schadensersatz."
),
"statutes": "§ 823 Abs. 1 BGB, § 827 BGB, § 828 BGB, § 7 StVG",
"correct": (
"A haftet nach § 823 Abs. 1 BGB (Eigentumsverletzung). Trunkenheit "
"schließt die Deliktsfähigkeit nicht aus, da sie selbst herbeigeführt "
"wurde (actio libera in causa). Zudem haftet A nach § 7 StVG als "
"Halter/Fahrer. Vollständiger Schadensersatz 5.000 €."
),
"flawed": (
"A ist wegen Trunkenheit schuldunfähig (§ 827 BGB) und haftet nicht. "
"B muss seinen Schaden selbst tragen oder die eigene Versicherung "
"bemühen."
),
"domain": "logic_reasoning",
},
]
# ORPO Preference Pairs: (preferred=correct, dispreferred=flawed)
def build_orpo_pairs(
cases: List[Dict[str, str]], tokenizer: Any, max_length: int = 128
) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Baut ORPO-Trainingspaare aus den Klausuren-Cases."""
dataset: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = []
for case in cases:
prompt = (
f"Fall: {case['description']}\n"
f"Normen: {case['statutes']}\n"
"Analysiere im Gutachtenstil (FIRT):\n"
)
text_win = prompt + case["correct"]
text_lose = prompt + case["flawed"]
enc_win = tokenizer(
text_win,
truncation=True,
max_length=max_length,
padding="max_length",
return_tensors="pt",
)
enc_lose = tokenizer(
text_lose,
truncation=True,
max_length=max_length,
padding="max_length",
return_tensors="pt",
)
labels_win = enc_win["input_ids"].clone()
labels_win[enc_win["attention_mask"] == 0] = -100
labels_lose = enc_lose["input_ids"].clone()
labels_lose[enc_lose["attention_mask"] == 0] = -100
dataset.append(
(
enc_win["input_ids"],
labels_win,
enc_lose["input_ids"],
labels_lose,
)
)
return dataset
# ==============================================================================
# STEP 1: HYPERPARAMETER SWEEP
# ==============================================================================
@dataclass
class HParamConfig:
lr: float
population_size: int
rank_min: int
rank_max: int
explained_variance: float
def run_hparam_sweep(
model_name: str,
dataset: List[Tuple[torch.Tensor, ...]],
steps: int = 15,
) -> HParamConfig:
"""Grid Search über kritische Hyperparameter auf Kaggle T4."""
from transformers import AutoModelForCausalLM
configs: List[HParamConfig] = [
HParamConfig(
lr=1e-4, population_size=8, rank_min=8, rank_max=32, explained_variance=0.95
),
HParamConfig(
lr=5e-5,
population_size=16,
rank_min=16,
rank_max=64,
explained_variance=0.95,
),
HParamConfig(
lr=2e-4, population_size=4, rank_min=4, rank_max=16, explained_variance=0.90
),
HParamConfig(
lr=1e-4, population_size=8, rank_min=8, rank_max=64, explained_variance=0.97
),
HParamConfig(
lr=8e-5,
population_size=12,
rank_min=8,
rank_max=32,
explained_variance=0.95,
),
]
best_config: Optional[HParamConfig] = None
best_score = float("inf")
results = []
print("\n" + "=" * 60)
print("STEP 1: HYPERPARAMETER SWEEP")
print("=" * 60)
for i, cfg in enumerate(configs):
print(
f"\n[Sweep {i+1}/{len(configs)}] lr={cfg.lr}, pop={cfg.population_size}, "
f"rank=[{cfg.rank_min},{cfg.rank_max}]"
)
model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
extractor = ParasiticQLoRAExtractor(
QLoRAConfig(
min_rank=cfg.rank_min,
max_rank=cfg.rank_max,
explained_variance_threshold=cfg.explained_variance,
extraction_interval=steps,
interesting_point_detection=False,
)
)
extractor.snapshot_base(model)
model.train()
losses = []
t0 = time.perf_counter()
for step in range(steps):
batch = dataset[step % len(dataset)]
input_win, labels_win, input_lose, labels_lose = [
t.to(DEVICE) for t in batch
]
optimizer.zero_grad()
outputs = model(input_win, labels=labels_win)
loss = outputs.loss
loss.backward()
optimizer.step()
losses.append(float(loss.item()))
wall_time = time.perf_counter() - t0
# Extract one adapter to measure quality
adapter = extractor.extract_adapters(
model, steps, {"loss": losses[-1], "optimizer": "AdamW"}
)
avg_ev = adapter.avg_explained_variance
final_loss = losses[-1]
score = final_loss * (1.0 - avg_ev) # lower is better
print(
f" -> final_loss={final_loss:.4f} | avg_ev={avg_ev:.3f} | "
f"score={score:.4f} | time={wall_time:.1f}s"
)
results.append(
{
"config": asdict(cfg),
"score": score,
"final_loss": final_loss,
"avg_ev": avg_ev,
}
)
if score < best_score:
best_score = score
best_config = cfg
del model
if DEVICE == "cuda":
torch.cuda.empty_cache()
out_path = os.path.join(OUTPUT_DIR, "step1_best_config.json")
with open(out_path, "w", encoding="utf-8") as f:
json.dump({"best": asdict(best_config), "all_results": results}, f, indent=2) # type: ignore
print(f"\n✅ Step 1 done. Best: {best_config} → saved to {out_path}")
return best_config # type: ignore
# ==============================================================================
# STEP 2: BUILD ADAPTER LIBRARY ON SMALL MODELS
# ==============================================================================
def build_adapter_library(
model_name: str,
best_cfg: HParamConfig,
dataset: List[Tuple[torch.Tensor, ...]],
steps: int = 50,
) -> List[ExpertAdapter]:
"""Trainiert Pythia-70m/160m auf Kaggle und baut eine diverse Adapter Library auf."""
from transformers import AutoModelForCausalLM
print("\n" + "=" * 60)
print("STEP 2: BUILD ADAPTER LIBRARY")
print("=" * 60)
model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
aligner = ExpertManifoldAligner(model)
extractor = ParasiticQLoRAExtractor(
QLoRAConfig(
min_rank=best_cfg.rank_min,
max_rank=best_cfg.rank_max,
explained_variance_threshold=best_cfg.explained_variance,
extraction_interval=max(5, steps // 10), # Extract ~10 checkpoints
interesting_point_detection=True,
)
)
extractor.snapshot_base(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=best_cfg.lr)
model.train()
print(
f"Training {model_name} for {steps} steps, extracting adapters every "
f"{extractor.config.extraction_interval} steps..."
)
for step in range(steps):
batch = dataset[step % len(dataset)]
input_win, labels_win, input_lose, labels_lose = [t.to(DEVICE) for t in batch]
optimizer.zero_grad()
outputs = model(input_win, labels=labels_win)
loss = outputs.loss
loss.backward()
optimizer.step()
aligner.track_step(model)
if extractor.should_extract(step, float(loss.item())):
metrics: Dict[str, Any] = {
"loss": float(loss.item()),
"optimizer": "AdamW",
}
adapter = extractor.extract_adapters(model, step, metrics)
aligner.tag_adapter(adapter)
profile = aligner.profile_adapter(adapter)
print(
f" Step {step:3d} | loss={loss.item():.4f} | "
f"tags={adapter.domain_tags} | "
f"statute={profile['statute_recall']:.2f} "
f"logic={profile['logic_reasoning']:.2f} "
f"style={profile['style_gutachtenstil']:.2f}"
)
if step % 10 == 0:
print(f" Step {step}/{steps} | loss={loss.item():.4f}")
# Save checkpoint for warm-start (Step 4 will also use this)
ckpt_dir = os.path.join(OUTPUT_DIR, "step4_warm_checkpoint")
os.makedirs(ckpt_dir, exist_ok=True)
model.save_pretrained(ckpt_dir)
print(f" Checkpoint saved to {ckpt_dir}")
# Save adapter library
lib_path = os.path.join(OUTPUT_DIR, "step2_adapter_library.pt")
extractor.save_library(lib_path)
print(
f"\n✅ Step 2 done. Library: {len(extractor.adapter_library)} adapters → {lib_path}"
)
del model
if DEVICE == "cuda":
torch.cuda.empty_cache()
return list(extractor.adapter_library)
# ==============================================================================
# STEP 3: PRETRAIN MOE GATE
# ==============================================================================
def pretrain_moe_gate(
adapter_library: List[ExpertAdapter],
in_features: int = 512,
gate_steps: int = 200,
) -> None:
"""Trainiert den LearnableGate MLP mit synthetischen Domain-Routing-Labels."""
print("\n" + "=" * 60)
print("STEP 3: PRETRAIN MoE GATE")
print("=" * 60)
num_adapters = len(adapter_library)
if num_adapters == 0:
print(" No adapters in library, skipping.")
return
gate = LearnableGate(in_features, num_adapters).to(DEVICE)
gate_optimizer = torch.optim.AdamW(gate.parameters(), lr=1e-3)
# Build synthetic supervision: for each adapter, its dominant domain
# determines what input pattern should route to it.
# We use one-hot targets based on domain tags.
adapter_targets = []
for adapter in adapter_library:
# Find dominant domain
if "statute_recall" in adapter.domain_tags:
d = 0
elif "logic_reasoning" in adapter.domain_tags:
d = 1
elif "style_gutachtenstil" in adapter.domain_tags:
d = 2
else:
d = 0
adapter_targets.append(d)
print(f" Training gate for {num_adapters} adapters, {gate_steps} steps...")
losses = []
for step in range(gate_steps):
# Synthetic input: each domain has a rough "signature" pattern
batch_size = 8
domain_inputs = []
domain_labels = []
for _ in range(batch_size):
# Pick a random adapter
adapter_idx = step % num_adapters
dom = adapter_targets[adapter_idx]
# Domain-specific input pattern (synthetic)
x = torch.randn(in_features, device=DEVICE)
# Amplify features in the domain's "direction"
x[dom * (in_features // 3) : (dom + 1) * (in_features // 3)] += 2.0
domain_inputs.append(x)
# Target: one-hot over adapters with this adapter having highest weight
# We want gate to learn to output high score for adapters of this domain
matching_adapters = [i for i, t in enumerate(adapter_targets) if t == dom]
target = torch.zeros(num_adapters, device=DEVICE)
for idx in matching_adapters:
target[idx] = 1.0 / max(1, len(matching_adapters))
domain_labels.append(target)
x_batch = torch.stack(domain_inputs) # [B, in_features]
y_batch = torch.stack(domain_labels) # [B, num_adapters]
gate_optimizer.zero_grad()
logits = gate(x_batch) # [B, num_adapters]
loss = F.kl_div(
F.log_softmax(logits, dim=-1),
y_batch,
reduction="batchmean",
)
loss.backward()
gate_optimizer.step()
losses.append(float(loss.item()))
if step % 50 == 0:
print(f" Gate step {step}/{gate_steps} | loss={loss.item():.4f}")
gate_path = os.path.join(OUTPUT_DIR, "step3_gate_weights.pt")
torch.save(gate.state_dict(), gate_path)
final_loss = sum(losses[-20:]) / 20
print(f"\n✅ Step 3 done. Gate loss={final_loss:.4f}{gate_path}")
# ==============================================================================
# STEP 4: FCES WARM-START CHECKPOINT
# ==============================================================================
def build_fces_warm_start(
model_name: str,
best_cfg: HParamConfig,
dataset: List[Tuple[torch.Tensor, ...]],
warm_steps: int = 30,
) -> None:
"""Führt FCES für warm_steps durch, speichert Checkpoint und Population-Metadaten.
Da FCESOptimizer keine vollständige Population-Serialisierung hat (TODO in C++),
speichern wir:
1. Model weights (bestes Checkpoint nach letztem Roll-Forward)
2. Aktive Controller-Metadata: ID, fitness, step_count
3. Optimale Hyperparameter für RunPod als JSON
"""
print("\n" + "=" * 60)
print("STEP 4: FCES WARM-START")
print("=" * 60)
# Try to import fces_native; fall back to AdamW warm-start if not available
try:
import fces_native # noqa: F401
fces_available = True
except ImportError:
print(" [WARN] fces_native not compiled. Using AdamW warm-start instead.")
fces_available = False
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
model.train()
controller_log: List[Dict[str, Any]] = []
best_loss = float("inf")
best_step = 0
if fces_available:
import fces_native
cfg = fces_native.FCESConfig()
cfg.lr = best_cfg.lr
cfg.population_size = best_cfg.population_size
cfg.total_steps = warm_steps
optimizer_fces = fces_native.FCESOptimizer(list(model.parameters()), cfg)
print(f" Running FCES warm-start for {warm_steps} steps...")
for step in range(warm_steps):
batch = dataset[step % len(dataset)]
input_win, labels_win, _, _ = [t.to(DEVICE) for t in batch]
optimizer_fces.zero_grad()
outputs = model(input_win, labels=labels_win)
loss = outputs.loss
loss.backward()
optimizer_fces.step()
optimizer_fces.update_fitness(float(loss.item()))
ctrl_id = optimizer_fces.get_active_controller_id()
ctrl_fit = optimizer_fces.get_active_controller_fitness()
controller_log.append(
{
"step": step,
"loss": float(loss.item()),
"controller_id": int(ctrl_id),
"controller_fitness": float(ctrl_fit),
}
)
if float(loss.item()) < best_loss:
best_loss = float(loss.item())
best_step = step
if step % 10 == 0:
print(
f" Step {step:3d} | loss={loss.item():.4f} | "
f"ctrl_id={ctrl_id} | fitness={ctrl_fit:.4f}"
)
else:
# AdamW warm-start as fallback
optimizer_adam = torch.optim.AdamW(model.parameters(), lr=best_cfg.lr)
for step in range(warm_steps):
batch = dataset[step % len(dataset)]
input_win, labels_win, _, _ = [t.to(DEVICE) for t in batch]
optimizer_adam.zero_grad()
outputs = model(input_win, labels=labels_win)
loss = outputs.loss
loss.backward()
optimizer_adam.step()
if float(loss.item()) < best_loss:
best_loss = float(loss.item())
best_step = step
if step % 10 == 0:
print(f" Step {step:3d} | loss={loss.item():.4f}")
# Save warm checkpoint
ckpt_dir = os.path.join(OUTPUT_DIR, "step4_warm_checkpoint")
os.makedirs(ckpt_dir, exist_ok=True)
model.save_pretrained(ckpt_dir)
# Save warm-start metadata
meta = {
"model_name": model_name,
"warm_steps": warm_steps,
"best_loss": best_loss,
"best_step": best_step,
"fces_available": fces_available,
"best_hparams": asdict(best_cfg),
"controller_log": controller_log,
"runpod_hint": (
f"Load from {ckpt_dir}. FCES already explored {warm_steps} steps. "
f"Resume from step {warm_steps} with lr={best_cfg.lr}, "
f"population_size={best_cfg.population_size}."
),
}
meta_path = os.path.join(OUTPUT_DIR, "step4_warm_metadata.json")
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
print(
f"\n✅ Step 4 done. Warm checkpoint (best_loss={best_loss:.4f} @ step {best_step})"
)
print(f" Checkpoint: {ckpt_dir}")
print(f" Metadata: {meta_path}")
del model
if DEVICE == "cuda":
torch.cuda.empty_cache()
# ==============================================================================
# STEP 5: TOKENIZE DATASET + ORPO PAIRS
# ==============================================================================
def build_full_dataset(model_name: str) -> None:
"""Tokenisiert alle Klausuren-Cases als ORPO-Pairs und speichert sie als .pt."""
from transformers import AutoTokenizer
print("\n" + "=" * 60)
print("STEP 5: TOKENIZE DATASET + BUILD ORPO PAIRS")
print("=" * 60)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Build at multiple sequence lengths for flexible training
datasets: Dict[str, Any] = {}
for max_len in [64, 128, 256]:
pairs = build_orpo_pairs(KLAUSUREN_CASES, tokenizer, max_length=max_len)
datasets[f"max_len_{max_len}"] = [
(iw.squeeze(0), lw.squeeze(0), il.squeeze(0), ll.squeeze(0))
for (iw, lw, il, ll) in pairs
]
print(f" max_len={max_len}: {len(pairs)} pairs built")
# Save metadata
case_meta = [
{
"id": c["id"],
"domain": c["domain"],
"statutes": c["statutes"],
"description_len": len(c["description"]),
}
for c in KLAUSUREN_CASES
]
datasets["metadata"] = {
"model_name": model_name,
"num_cases": len(KLAUSUREN_CASES),
"cases": case_meta,
"domains": list({c["domain"] for c in KLAUSUREN_CASES}),
"vocab_size": tokenizer.vocab_size,
}
out_path = os.path.join(OUTPUT_DIR, "step5_dataset.pt")
torch.save(datasets, out_path)
size_mb = os.path.getsize(out_path) / 1024 / 1024
print(f"\n✅ Step 5 done. Dataset ({size_mb:.2f} MB) → {out_path}")
print(f" Domains: {datasets['metadata']['domains']}")
# ==============================================================================
# MAIN PIPELINE
# ==============================================================================
def main() -> None:
import argparse
from transformers import AutoTokenizer
parser = argparse.ArgumentParser(description="Kaggle Preparation Pipeline")
parser.add_argument("--model", type=str, default="EleutherAI/pythia-70m")
parser.add_argument(
"--sweep-steps", type=int, default=15, help="Steps per hyperparameter sweep run"
)
parser.add_argument(
"--adapter-steps", type=int, default=50, help="Steps for adapter library build"
)
parser.add_argument(
"--warm-steps", type=int, default=30, help="Steps for FCES warm-start"
)
parser.add_argument(
"--skip-steps",
type=str,
default="",
help="Comma-separated step numbers to skip, e.g. '1,3'",
)
args = parser.parse_args()
skip = {int(s) for s in args.skip_steps.split(",") if s.strip()}
t_total = time.perf_counter()
print("\n" + "=" * 70)
print(" KAGGLE PREPARATION PIPELINE — FCES-native")
print(f" Model: {args.model} | Device: {DEVICE}")
print("=" * 70)
# Build shared tokenized dataset (used across all steps)
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
shared_dataset = build_orpo_pairs(KLAUSUREN_CASES, tokenizer, max_length=64)
# ── Step 1: Hyperparameter Sweep ─────────────────────────────────────────
if 1 not in skip:
best_cfg = run_hparam_sweep(args.model, shared_dataset, steps=args.sweep_steps)
else:
print("\n[SKIP] Step 1: Using default config")
best_cfg = HParamConfig(
lr=1e-4, population_size=8, rank_min=8, rank_max=32, explained_variance=0.95
)
# ── Step 2: Adapter Library ───────────────────────────────────────────────
if 2 not in skip:
adapter_library = build_adapter_library(
args.model, best_cfg, shared_dataset, steps=args.adapter_steps
)
else:
print("\n[SKIP] Step 2: Loading existing library")
lib_path = os.path.join(OUTPUT_DIR, "step2_adapter_library.pt")
adapter_library = (
ParasiticQLoRAExtractor.load_library(lib_path)
if os.path.exists(lib_path)
else []
)
# ── Step 3: MoE Gate Pre-training ────────────────────────────────────────
if 3 not in skip:
# Detect in_features from model config
try:
from transformers import AutoConfig
config = AutoConfig.from_pretrained(args.model)
in_features = getattr(config, "hidden_size", 512)
except Exception:
in_features = 512
pretrain_moe_gate(adapter_library, in_features=in_features, gate_steps=200)
else:
print("\n[SKIP] Step 3: Gate pre-training skipped")
# ── Step 4: FCES Warm-Start ───────────────────────────────────────────────
if 4 not in skip:
build_fces_warm_start(
args.model, best_cfg, shared_dataset, warm_steps=args.warm_steps
)
else:
print("\n[SKIP] Step 4: Warm-start skipped")
# ── Step 5: Full Dataset Tokenization ─────────────────────────────────────
if 5 not in skip:
build_full_dataset(args.model)
else:
print("\n[SKIP] Step 5: Dataset tokenization skipped")
# ── Summary ───────────────────────────────────────────────────────────────
total_time = time.perf_counter() - t_total
output_files = [f for f in os.listdir(OUTPUT_DIR) if not f.startswith(".")]
total_size_mb = (
sum(
os.path.getsize(os.path.join(OUTPUT_DIR, f))
for f in output_files
if os.path.isfile(os.path.join(OUTPUT_DIR, f))
)
/ 1024
/ 1024
)
summary = {
"model": args.model,
"device": DEVICE,
"total_wall_time_s": round(total_time, 1),
"output_files": output_files,
"total_size_mb": round(total_size_mb, 2),
"runpod_savings_hint": (
"Upload /output/ as Kaggle Dataset. "
"RunPod: load step1_best_config.json for hparams, "
"step2_adapter_library.pt for warm adapter library, "
"step3_gate_weights.pt for MoE router, "
"step4_warm_checkpoint/ as starting model weights, "
"step5_dataset.pt for pre-tokenized training data. "
"Estimated RunPod time savings: 60-80% vs cold start."
),
}
summary_path = os.path.join(OUTPUT_DIR, "prep_summary.json")
with open(summary_path, "w", encoding="utf-8") as f:
json.dump(summary, f, indent=2)
print("\n" + "=" * 70)
print(" KAGGLE PREP COMPLETE")
print("=" * 70)
print(f" Total time : {total_time:.1f}s")
print(f" Artifacts : {len(output_files)} files, {total_size_mb:.1f} MB")
print(f" Output dir : {OUTPUT_DIR}")
print("\n RunPod Ladeanweisung:")
print(" kaggle datasets download sven/fces-prep -p /data/")
for fname in sorted(output_files):
size = os.path.getsize(os.path.join(OUTPUT_DIR, fname)) / 1024
if os.path.isfile(os.path.join(OUTPUT_DIR, fname)):
print(f" /data/{fname:<35} ({size:.0f} KB)")
print("=" * 70)
if __name__ == "__main__":
main()