# -*- coding: utf-8 -*- """Kaggle Preparation Pipeline for FCES-native. Führt alle 5 Vorbereitungsschritte auf Kaggle T4/P100 durch, bevor teure RunPod A100-Zeit genutzt wird. Alle Artefakte werden in /kaggle/working/ gespeichert und können als Kaggle Dataset veröffentlicht werden. Ausführung auf Kaggle: !git clone https://git.zky.de/sven/FCES-native !cd FCES-native && pip install -r requirements.txt !cd FCES-native && python python/setup.py build_ext --inplace !cd FCES-native && python kaggle_prep.py Produzierte Artefakte: /output/step1_best_config.json ← Beste Hyperparameter /output/step2_adapter_library.pt ← Pre-built Expert Adapter Library /output/step3_gate_weights.pt ← Vortrainierte MoE Gate Weights /output/step4_warm_checkpoint/ ← FCES Warm-Start Checkpoint /output/step5_dataset.pt ← Tokenisiertes ORPO-Dataset /output/prep_summary.json ← Zusammenfassung aller Schritte """ from __future__ import annotations import json import os import sys import time from dataclasses import asdict, dataclass from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn.functional as F # Ensure python/ is importable _root = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, os.path.join(_root, "python")) from parasitic_qlora import ExpertAdapter, ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402 from expert_manifold_alignment import ExpertManifoldAligner # noqa: E402 from adapter_moe_router import LearnableGate # noqa: E402 OUTPUT_DIR = "/kaggle/working/output" if not os.path.exists("/kaggle"): OUTPUT_DIR = os.path.join(_root, "kaggle_output") os.makedirs(OUTPUT_DIR, exist_ok=True) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" print(f"[Kaggle Prep] Device: {DEVICE} | Output: {OUTPUT_DIR}") # ============================================================================== # LEGAL KLAUSUREN CASES (expanded for Kaggle prep) # ============================================================================== KLAUSUREN_CASES: List[Dict[str, str]] 
= [ # --- Mietrecht --- { "id": "miet_01", "description": ( "Mieter M besichtigt im Januar eine Wohnung des Vermieters V. " "Schimmelfleck im Schlafzimmer. V sagt er wird es 'irgendwann mal wegmachen'. " "M unterschreibt Mietvertrag vorbehaltlos und zieht am 1. Februar ein. " "Im März mindert M die Miete selbstständig um 20% wegen des Schimmels. " "V verlangt volle Miete." ), "statutes": "§ 535 Abs. 2 BGB, § 536 Abs. 1 BGB, § 536b BGB", "correct": ( "M hat gemäß § 536b BGB das Minderungsrecht verloren, da er bei " "Vertragsschluss positive Kenntnis vom Schimmelfleck hatte und " "vorbehaltlos unterschrieben hat. V kann volle Miete verlangen." ), "flawed": ( "M kann die Miete mindern, weil Schimmel ein Sachmangel ist. " "§ 536b BGB ist irrelevant, da V eine Zusage gemacht hat." ), "domain": "statute_recall", }, { "id": "miet_02", "description": ( "Toilettenspülung bei F fällt am Freitagabend aus, Wasser läuft " "ununterbrochen und droht überzulaufen. V ist auf Segeltrip und nicht " "erreichbar. Kein Hausmeister. F beauftragt am Samstagmorgen den " "Notdienst K, bezahlt 300 € und fordert Erstattung von V. V weigert sich." ), "statutes": "§ 536a Abs. 2 Nr. 2 BGB, § 535 BGB, § 683 BGB", "correct": ( "F hat nach § 536a Abs. 2 Nr. 2 BGB das Selbstbeseitigungsrecht bei " "Gefahr im Verzug. V war nicht erreichbar, der Schaden drohend. " "F kann 300 € Erstattung verlangen." ), "flawed": ( "F hätte länger auf V warten müssen. Das Selbstbeseitigungsrecht " "besteht nur bei völliger Unerreichbarkeit über mehrere Tage." ), "domain": "logic_reasoning", }, { "id": "miet_03", "description": ( "Vermieter V kündigt dem Mieter M ordentlich wegen Eigenbedarfs (§ 573 BGB). " "M ist 78 Jahre alt und schwer krank. Er beruft sich auf § 574 BGB " "(Sozialklausel). V bestreitet die Härte." ), "statutes": "§ 573 Abs. 2 Nr. 2 BGB, § 574 BGB, § 574a BGB", "correct": ( "Die Kündigung ist formell wirksam (Eigenbedarf). Jedoch kann M gemäß " "§ 574 BGB Widerspruch einlegen. 
Das Gericht wägt ab: Alter 78, schwere " "Krankheit begründen in der Regel eine unzumutbare Härte. Kündigung " "wird voraussichtlich ausgesetzt oder Fortsetzung angeordnet." ), "flawed": ( "Eigenbedarf schlägt immer die Sozialklausel. M muss ausziehen. " "§ 574 BGB gilt nur bei Vermietern ohne eigentlichen Bedarf." ), "domain": "logic_reasoning", }, # --- Kaufrecht --- { "id": "kauf_01", "description": ( "K kauft bei V einen Gebrauchtwagen für 8.000 €. Nach 3 Monaten " "stellt sich heraus, dass der Kilometerzähler manipuliert war. " "Tatsächlich 180.000 km statt 80.000 km. V beruft sich auf " "'gekauft wie gesehen' Klausel." ), "statutes": "§ 434 BGB, § 437 BGB, § 444 BGB, § 476 BGB", "correct": ( "Die 'gekauft wie gesehen' Klausel schließt Gewährleistung nicht aus, " "wenn arglistige Täuschung vorliegt (§ 444 BGB). Kilometerzähler-" "manipulation ist arglistige Täuschung → Klausel unwirksam. " "K kann Rücktritt (§ 437 Nr. 2 BGB) oder Schadensersatz verlangen." ), "flawed": ( "'Gekauft wie gesehen' schließt alle Mängel aus. K hatte Gelegenheit " "zur Besichtigung und muss das Fahrzeug so akzeptieren wie es ist." ), "domain": "statute_recall", }, { "id": "kauf_02", "description": ( "Unternehmerin U bestellt bei Lieferant L 500 Stühle für ein Büro. " "Lieferdatum: 15. März. Am 20. März noch keine Lieferung. U setzt " "Nachfrist bis 25. März. L liefert nicht. U kauft Stühle teurer " "woanders und fordert Mehrkosten von 3.000 € von L." ), "statutes": "§ 280 Abs. 1, 3 BGB, § 281 BGB, § 323 BGB, § 286 BGB", "correct": ( "L ist durch Überschreiten des Liefertermins in Verzug (§ 286 BGB). " "Nach erfolglosem Ablauf der Nachfrist kann U nach § 281 BGB " "Schadensersatz statt der Leistung verlangen. Die 3.000 € " "Mehrkosten sind als Deckungsschaden erstattungsfähig." ), "flawed": ( "U muss zunächst Klage erheben und kann erst nach rechtskräftigem Urteil " "Ersatz verlangen. Die Nachfrist allein reicht nicht aus." 
), "domain": "style_gutachtenstil", }, # --- Deliktsrecht --- { "id": "delikt_01", "description": ( "A fährt nachts betrunken (1,5 Promille) mit seinem Auto und fährt " "das geparkte Auto des B zu. Schaden 5.000 €. A hat keine " "Kfz-Haftpflichtversicherung. B verlangt Schadensersatz." ), "statutes": "§ 823 Abs. 1 BGB, § 827 BGB, § 828 BGB, § 7 StVG", "correct": ( "A haftet nach § 823 Abs. 1 BGB (Eigentumsverletzung). Trunkenheit " "schließt die Deliktsfähigkeit nicht aus, da sie selbst herbeigeführt " "wurde (actio libera in causa). Zudem haftet A nach § 7 StVG als " "Halter/Fahrer. Vollständiger Schadensersatz 5.000 €." ), "flawed": ( "A ist wegen Trunkenheit schuldunfähig (§ 827 BGB) und haftet nicht. " "B muss seinen Schaden selbst tragen oder die eigene Versicherung " "bemühen." ), "domain": "logic_reasoning", }, ] # ORPO Preference Pairs: (preferred=correct, dispreferred=flawed) def build_orpo_pairs( cases: List[Dict[str, str]], tokenizer: Any, max_length: int = 128 ) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: """Baut ORPO-Trainingspaare aus den Klausuren-Cases.""" dataset: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = [] for case in cases: prompt = ( f"Fall: {case['description']}\n" f"Normen: {case['statutes']}\n" "Analysiere im Gutachtenstil (FIRT):\n" ) text_win = prompt + case["correct"] text_lose = prompt + case["flawed"] enc_win = tokenizer( text_win, truncation=True, max_length=max_length, padding="max_length", return_tensors="pt", ) enc_lose = tokenizer( text_lose, truncation=True, max_length=max_length, padding="max_length", return_tensors="pt", ) labels_win = enc_win["input_ids"].clone() labels_win[enc_win["attention_mask"] == 0] = -100 labels_lose = enc_lose["input_ids"].clone() labels_lose[enc_lose["attention_mask"] == 0] = -100 dataset.append( ( enc_win["input_ids"], labels_win, enc_lose["input_ids"], labels_lose, ) ) return dataset # 
# ==============================================================================
# SHARED TRAINING HELPERS
# ==============================================================================


def _sft_step(model: Any, optimizer: Any, batch: Tuple[torch.Tensor, ...]) -> float:
    """Run one supervised LM step on the preferred half of an ORPO tuple.

    Only ``batch[0]`` (input ids) and ``batch[1]`` (labels) are moved to
    ``DEVICE`` and used; the dispreferred half is ignored. ``optimizer`` can be
    any object with ``zero_grad()`` / ``step()`` (torch optimizers and the FCES
    optimizer are both used this way). The forward outputs are local to this
    function, so they are not kept alive between steps.

    Returns:
        The scalar loss as a Python float.
    """
    input_ids, labels = batch[0].to(DEVICE), batch[1].to(DEVICE)
    optimizer.zero_grad()
    loss = model(input_ids, labels=labels).loss
    loss.backward()
    optimizer.step()
    return float(loss.item())


def _empty_cuda_cache() -> None:
    """Release cached CUDA allocator blocks; no-op on CPU."""
    if DEVICE == "cuda":
        torch.cuda.empty_cache()


# ==============================================================================
# STEP 1: HYPERPARAMETER SWEEP
# ==============================================================================


@dataclass
class HParamConfig:
    """One candidate hyperparameter set evaluated by the Step-1 sweep."""

    lr: float  # learning rate (AdamW in Steps 1/2/4, FCES in Step 4)
    population_size: int  # FCES population size (consumed in Step 4)
    rank_min: int  # QLoRAConfig.min_rank for adapter extraction
    rank_max: int  # QLoRAConfig.max_rank for adapter extraction
    explained_variance: float  # QLoRAConfig.explained_variance_threshold


def run_hparam_sweep(
    model_name: str,
    dataset: List[Tuple[torch.Tensor, ...]],
    steps: int = 15,
) -> HParamConfig:
    """Grid-search the critical hyperparameters on a Kaggle T4.

    Each candidate trains a fresh copy of ``model_name`` with AdamW for
    ``steps`` steps (preferred answers only) and then extracts one adapter.
    It is scored as ``final_loss * (1 - avg_explained_variance)``, where
    lower is better. The winner and all per-config results are written to
    ``OUTPUT_DIR/step1_best_config.json``.

    Args:
        model_name: Hugging Face model id or local path.
        dataset: ORPO tuples as produced by ``build_orpo_pairs``; cycled over.
        steps: Training steps per candidate.

    Returns:
        The best-scoring config. If no score is comparable (every score NaN
        or inf), the first candidate is returned and a warning is printed.

    Raises:
        ValueError: If ``steps`` < 1 or ``dataset`` is empty.
    """
    from transformers import AutoModelForCausalLM

    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    if not dataset:
        raise ValueError("dataset must not be empty")

    configs: List[HParamConfig] = [
        HParamConfig(
            lr=1e-4, population_size=8, rank_min=8, rank_max=32, explained_variance=0.95
        ),
        HParamConfig(
            lr=5e-5,
            population_size=16,
            rank_min=16,
            rank_max=64,
            explained_variance=0.95,
        ),
        HParamConfig(
            lr=2e-4, population_size=4, rank_min=4, rank_max=16, explained_variance=0.90
        ),
        HParamConfig(
            lr=1e-4, population_size=8, rank_min=8, rank_max=64, explained_variance=0.97
        ),
        HParamConfig(
            lr=8e-5,
            population_size=12,
            rank_min=8,
            rank_max=32,
            explained_variance=0.95,
        ),
    ]
    best_config: Optional[HParamConfig] = None
    best_score = float("inf")
    results: List[Dict[str, Any]] = []

    print("\n" + "=" * 60)
    print("STEP 1: HYPERPARAMETER SWEEP")
    print("=" * 60)

    for i, cfg in enumerate(configs):
        print(
            f"\n[Sweep {i+1}/{len(configs)}] lr={cfg.lr}, pop={cfg.population_size}, "
            f"rank=[{cfg.rank_min},{cfg.rank_max}]"
        )
        model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
        optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
        extractor = ParasiticQLoRAExtractor(
            QLoRAConfig(
                min_rank=cfg.rank_min,
                max_rank=cfg.rank_max,
                explained_variance_threshold=cfg.explained_variance,
                extraction_interval=steps,
                interesting_point_detection=False,
            )
        )
        extractor.snapshot_base(model)
        model.train()

        losses: List[float] = []
        t0 = time.perf_counter()
        for step in range(steps):
            losses.append(_sft_step(model, optimizer, dataset[step % len(dataset)]))
        wall_time = time.perf_counter() - t0

        # Extract one adapter to measure quality.
        final_loss = losses[-1]
        adapter = extractor.extract_adapters(
            model, steps, {"loss": final_loss, "optimizer": "AdamW"}
        )
        avg_ev = adapter.avg_explained_variance
        score = final_loss * (1.0 - avg_ev)  # lower is better

        print(
            f" -> final_loss={final_loss:.4f} | avg_ev={avg_ev:.3f} | "
            f"score={score:.4f} | time={wall_time:.1f}s"
        )
        results.append(
            {
                "config": asdict(cfg),
                "score": score,
                "final_loss": final_loss,
                "avg_ev": avg_ev,
            }
        )
        if score < best_score:
            best_score = score
            best_config = cfg

        # AdamW keeps references to every parameter (plus its moment buffers),
        # and the extractor presumably holds the base snapshot. All of them
        # must go, otherwise empty_cache() cannot return this model's memory
        # before the next candidate is loaded.
        del model, optimizer, extractor, adapter
        _empty_cuda_cache()

    if best_config is None:
        # No score compared below +inf (e.g. every loss was NaN).
        best_config = configs[0]
        print("[WARN] No comparable sweep score (NaN/inf); falling back to config 1.")

    out_path = os.path.join(OUTPUT_DIR, "step1_best_config.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump({"best": asdict(best_config), "all_results": results}, f, indent=2)

    print(f"\n✅ Step 1 done. Best: {best_config} → saved to {out_path}")
    return best_config


# ==============================================================================
# STEP 2: BUILD ADAPTER LIBRARY ON SMALL MODELS
# ==============================================================================


def build_adapter_library(
    model_name: str,
    best_cfg: HParamConfig,
    dataset: List[Tuple[torch.Tensor, ...]],
    steps: int = 50,
) -> List[ExpertAdapter]:
    """Train a small model (e.g. Pythia-70m/160m) and build a diverse adapter library.

    The model is trained with AdamW using ``best_cfg``. ``extractor.should_extract``
    decides when to extract an adapter; it is configured with
    ``extraction_interval=max(5, steps // 10)`` and interesting-point detection
    enabled. Each extracted adapter is tagged and profiled by the
    ``ExpertManifoldAligner``.

    Outputs:
        - the trained weights in ``OUTPUT_DIR/step4_warm_checkpoint/``;
        - the library in ``OUTPUT_DIR/step2_adapter_library.pt``.

    Returns:
        A list copy of the extractor's adapter library.
    """
    from transformers import AutoModelForCausalLM

    print("\n" + "=" * 60)
    print("STEP 2: BUILD ADAPTER LIBRARY")
    print("=" * 60)

    model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
    aligner = ExpertManifoldAligner(model)
    extractor = ParasiticQLoRAExtractor(
        QLoRAConfig(
            min_rank=best_cfg.rank_min,
            max_rank=best_cfg.rank_max,
            explained_variance_threshold=best_cfg.explained_variance,
            extraction_interval=max(5, steps // 10),  # Extract ~10 checkpoints
            interesting_point_detection=True,
        )
    )
    extractor.snapshot_base(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=best_cfg.lr)
    model.train()

    print(
        f"Training {model_name} for {steps} steps, extracting adapters every "
        f"{extractor.config.extraction_interval} steps..."
    )

    for step in range(steps):
        loss_value = _sft_step(model, optimizer, dataset[step % len(dataset)])
        aligner.track_step(model)

        if extractor.should_extract(step, loss_value):
            metrics: Dict[str, Any] = {"loss": loss_value, "optimizer": "AdamW"}
            adapter = extractor.extract_adapters(model, step, metrics)
            aligner.tag_adapter(adapter)
            profile = aligner.profile_adapter(adapter)
            print(
                f" Step {step:3d} | loss={loss_value:.4f} | "
                f"tags={adapter.domain_tags} | "
                f"statute={profile['statute_recall']:.2f} "
                f"logic={profile['logic_reasoning']:.2f} "
                f"style={profile['style_gutachtenstil']:.2f}"
            )

        if step % 10 == 0:
            print(f" Step {step}/{steps} | loss={loss_value:.4f}")

    # Save the trained weights. Step 4 loads a fresh ``model_name`` and writes
    # its own checkpoint to this same directory, so this copy survives only
    # when Step 4 is skipped.
    ckpt_dir = os.path.join(OUTPUT_DIR, "step4_warm_checkpoint")
    os.makedirs(ckpt_dir, exist_ok=True)
    model.save_pretrained(ckpt_dir)
    print(f" Checkpoint saved to {ckpt_dir}")

    # Save adapter library.
    lib_path = os.path.join(OUTPUT_DIR, "step2_adapter_library.pt")
    extractor.save_library(lib_path)
    print(
        f"\n✅ Step 2 done. Library: {len(extractor.adapter_library)} adapters → {lib_path}"
    )

    library = list(extractor.adapter_library)
    # The optimizer and the aligner (constructed with ``model``) also reference
    # the model, so drop them before asking CUDA to release memory.
    del model, optimizer, aligner, extractor
    _empty_cuda_cache()
    return library


# ==============================================================================
# STEP 3: PRETRAIN MOE GATE
# ==============================================================================


def pretrain_moe_gate(
    adapter_library: List[ExpertAdapter],
    in_features: int = 512,
    gate_steps: int = 200,
) -> None:
    """Pre-train the LearnableGate MLP on synthetic domain-routing labels.

    Domain mapping: each adapter gets a domain index from its ``domain_tags``
    (0=statute_recall, 1=logic_reasoning, 2=style_gutachtenstil; the first
    match in that order wins, and untagged adapters fall back to 0).

    Each of the 8 samples per step:
        - picks a random adapter;
        - draws Gaussian noise of size ``in_features``;
        - adds +2.0 to that adapter's domain third of the vector;
        - uses as target the uniform distribution over all adapters of that
          domain.

    The gate is trained with KL divergence, and its ``state_dict`` is saved
    to ``OUTPUT_DIR/step3_gate_weights.pt``. Nothing is trained or written
    when the library is empty.
    """
    print("\n" + "=" * 60)
    print("STEP 3: PRETRAIN MoE GATE")
    print("=" * 60)

    num_adapters = len(adapter_library)
    if num_adapters == 0:
        print(" No adapters in library, skipping.")
        return

    gate = LearnableGate(in_features, num_adapters).to(DEVICE)
    gate_optimizer = torch.optim.AdamW(gate.parameters(), lr=1e-3)

    # Dominant domain per adapter (first matching tag in this order wins).
    domain_order = ("statute_recall", "logic_reasoning", "style_gutachtenstil")
    adapter_targets = [
        next(
            (d for d, tag in enumerate(domain_order) if tag in adapter.domain_tags),
            0,
        )
        for adapter in adapter_library
    ]
    adapter_domains = torch.tensor(adapter_targets, device=DEVICE)

    # Row d: uniform distribution over the adapters of domain d. Rows of
    # domains without adapters stay zero but are never selected, because
    # domains are always drawn through an adapter.
    domain_targets = torch.zeros(len(domain_order), num_adapters, device=DEVICE)
    for d in range(len(domain_order)):
        members = adapter_domains == d
        count = int(members.sum())
        if count:
            domain_targets[d, members] = 1.0 / count

    batch_size = 8
    band = in_features // 3  # width of each domain's "signature" slice
    cols = torch.arange(in_features, device=DEVICE)

    print(f" Training gate for {num_adapters} adapters, {gate_steps} steps...")
    losses: List[float] = []
    for step in range(gate_steps):
        # Pick an independent random adapter per row so every batch mixes domains.
        adapter_idx = torch.randint(num_adapters, (batch_size,), device=DEVICE)
        doms = adapter_domains[adapter_idx]  # [B]

        # Synthetic input: noise plus +2.0 on the domain's feature band.
        band_start = doms.unsqueeze(1) * band  # [B, 1]
        signature = (cols >= band_start) & (cols < band_start + band)  # [B, F]
        x_batch = torch.randn(batch_size, in_features, device=DEVICE) + 2.0 * signature
        y_batch = domain_targets[doms]  # [B, num_adapters]

        gate_optimizer.zero_grad()
        logits = gate(x_batch)  # [B, num_adapters]
        loss = F.kl_div(
            F.log_softmax(logits, dim=-1),
            y_batch,
            reduction="batchmean",
        )
        loss.backward()
        gate_optimizer.step()
        losses.append(float(loss.item()))

        if step % 50 == 0:
            print(f" Gate step {step}/{gate_steps} | loss={loss.item():.4f}")

    gate_path = os.path.join(OUTPUT_DIR, "step3_gate_weights.pt")
    torch.save(gate.state_dict(), gate_path)

    # Mean of the last (up to) 20 losses; NaN if no step ran.
    tail = losses[-20:]
    final_loss = sum(tail) / len(tail) if tail else float("nan")
    print(f"\n✅ Step 3 done. Gate loss={final_loss:.4f} → {gate_path}")


# ==============================================================================
# STEP 4: FCES WARM-START CHECKPOINT
# ==============================================================================


def build_fces_warm_start(
    model_name: str,
    best_cfg: HParamConfig,
    dataset: List[Tuple[torch.Tensor, ...]],
    warm_steps: int = 30,
) -> None:
    """Run ``warm_steps`` of FCES and save a checkpoint plus metadata.

    FCESOptimizer has no full population serialization yet (TODO in C++), so
    this saves:
      1. the model weights after the final step (``step4_warm_checkpoint/``);
         ``best_loss`` / ``best_step`` are only recorded, not checkpointed;
      2. per-step active-controller data (id, fitness) when FCES is used;
      3. the best hyperparameters and a RunPod resume hint
         (``step4_warm_metadata.json``).

    If the ``fces_native`` extension cannot be imported, AdamW with
    ``best_cfg.lr`` is used instead and the controller log stays empty.
    """
    import math

    print("\n" + "=" * 60)
    print("STEP 4: FCES WARM-START")
    print("=" * 60)

    # Try to import fces_native; fall back to AdamW warm-start if not available.
    try:
        import fces_native
    except ImportError:
        print(" [WARN] fces_native not compiled. Using AdamW warm-start instead.")
        fces_native = None
    fces_available = fces_native is not None

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
    model.train()

    controller_log: List[Dict[str, Any]] = []
    best_loss = float("inf")
    best_step = 0

    if fces_available:
        cfg = fces_native.FCESConfig()
        cfg.lr = best_cfg.lr
        cfg.population_size = best_cfg.population_size
        cfg.total_steps = warm_steps
        optimizer = fces_native.FCESOptimizer(list(model.parameters()), cfg)
        print(f" Running FCES warm-start for {warm_steps} steps...")
    else:
        # AdamW warm-start as fallback.
        optimizer = torch.optim.AdamW(model.parameters(), lr=best_cfg.lr)

    for step in range(warm_steps):
        loss_value = _sft_step(model, optimizer, dataset[step % len(dataset)])

        if fces_available:
            optimizer.update_fitness(loss_value)
            ctrl_id = optimizer.get_active_controller_id()
            ctrl_fit = optimizer.get_active_controller_fitness()
            controller_log.append(
                {
                    "step": step,
                    "loss": loss_value,
                    "controller_id": int(ctrl_id),
                    "controller_fitness": float(ctrl_fit),
                }
            )

        if loss_value < best_loss:
            best_loss = loss_value
            best_step = step

        if step % 10 == 0:
            if fces_available:
                print(
                    f" Step {step:3d} | loss={loss_value:.4f} | "
                    f"ctrl_id={ctrl_id} | fitness={ctrl_fit:.4f}"
                )
            else:
                print(f" Step {step:3d} | loss={loss_value:.4f}")

    # Save warm checkpoint (final weights).
    ckpt_dir = os.path.join(OUTPUT_DIR, "step4_warm_checkpoint")
    os.makedirs(ckpt_dir, exist_ok=True)
    model.save_pretrained(ckpt_dir)

    # Save warm-start metadata. A non-finite best_loss (no step ran, or only
    # NaN losses) is stored as null, because JSON has no Infinity.
    meta = {
        "model_name": model_name,
        "warm_steps": warm_steps,
        "best_loss": best_loss if math.isfinite(best_loss) else None,
        "best_step": best_step,
        "fces_available": fces_available,
        "best_hparams": asdict(best_cfg),
        "controller_log": controller_log,
        "runpod_hint": (
            f"Load from {ckpt_dir}. FCES already explored {warm_steps} steps. "
            f"Resume from step {warm_steps} with lr={best_cfg.lr}, "
            f"population_size={best_cfg.population_size}."
        ),
    }
    meta_path = os.path.join(OUTPUT_DIR, "step4_warm_metadata.json")
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    print(
        f"\n✅ Step 4 done. Warm checkpoint (best_loss={best_loss:.4f} @ step {best_step})"
    )
    print(f" Checkpoint: {ckpt_dir}")
    print(f" Metadata: {meta_path}")

    # The optimizer references the model parameters; drop both before
    # releasing CUDA memory.
    del model, optimizer
    _empty_cuda_cache()


# ==============================================================================
# STEP 5: TOKENIZE DATASET + ORPO PAIRS
# ==============================================================================


def build_full_dataset(model_name: str) -> None:
    """Tokenize all exam cases as ORPO pairs and save them as a ``.pt`` file.

    Pairs are built at max lengths 64, 128 and 256, with the leading batch
    dimension squeezed away. They are stored in ``OUTPUT_DIR/step5_dataset.pt``
    under the keys ``max_len_<N>``, together with a ``metadata`` entry
    (model, case summaries, sorted domain list, vocab size).
    """
    from transformers import AutoTokenizer

    print("\n" + "=" * 60)
    print("STEP 5: TOKENIZE DATASET + BUILD ORPO PAIRS")
    print("=" * 60)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Build at multiple sequence lengths for flexible training.
    datasets: Dict[str, Any] = {}
    for max_len in [64, 128, 256]:
        pairs = build_orpo_pairs(KLAUSUREN_CASES, tokenizer, max_length=max_len)
        datasets[f"max_len_{max_len}"] = [
            (iw.squeeze(0), lw.squeeze(0), il.squeeze(0), ll.squeeze(0))
            for (iw, lw, il, ll) in pairs
        ]
        print(f" max_len={max_len}: {len(pairs)} pairs built")

    case_meta = [
        {
            "id": c["id"],
            "domain": c["domain"],
            "statutes": c["statutes"],
            "description_len": len(c["description"]),
        }
        for c in KLAUSUREN_CASES
    ]
    datasets["metadata"] = {
        "model_name": model_name,
        "num_cases": len(KLAUSUREN_CASES),
        "cases": case_meta,
        # Sorted so the artifact is identical across runs (set order is not).
        "domains": sorted({c["domain"] for c in KLAUSUREN_CASES}),
        "vocab_size": tokenizer.vocab_size,
    }

    out_path = os.path.join(OUTPUT_DIR, "step5_dataset.pt")
    torch.save(datasets, out_path)
    size_mb = os.path.getsize(out_path) / 1024 / 1024
    print(f"\n✅ Step 5 done. Dataset ({size_mb:.2f} MB) → {out_path}")
    print(f" Domains: {datasets['metadata']['domains']}")


# ==============================================================================
# MAIN PIPELINE
# ==============================================================================


def main() -> None:
    """CLI entry point: run Steps 1-5 (except ``--skip-steps``) and write a summary.

    A shared dataset (``max_length=64``) is tokenized once and reused by
    Steps 1, 2 and 4. When a step is skipped:
        - Step 1 falls back to a default config;
        - Step 2 loads an existing library from ``OUTPUT_DIR`` if present.

    ``prep_summary.json`` lists the files found in ``OUTPUT_DIR`` at the end.
    """
    import argparse

    from transformers import AutoTokenizer

    parser = argparse.ArgumentParser(description="Kaggle Preparation Pipeline")
    parser.add_argument("--model", type=str, default="EleutherAI/pythia-70m")
    parser.add_argument(
        "--sweep-steps", type=int, default=15, help="Steps per hyperparameter sweep run"
    )
    parser.add_argument(
        "--adapter-steps", type=int, default=50, help="Steps for adapter library build"
    )
    parser.add_argument(
        "--warm-steps", type=int, default=30, help="Steps for FCES warm-start"
    )
    parser.add_argument(
        "--skip-steps",
        type=str,
        default="",
        help="Comma-separated step numbers to skip, e.g. '1,3'",
    )
    args = parser.parse_args()
    try:
        skip = {int(s) for s in args.skip_steps.split(",") if s.strip()}
    except ValueError:
        parser.error(
            f"--skip-steps expects comma-separated step numbers, got {args.skip_steps!r}"
        )

    t_total = time.perf_counter()
    print("\n" + "=" * 70)
    print(" KAGGLE PREPARATION PIPELINE — FCES-native")
    print(f" Model: {args.model} | Device: {DEVICE}")
    print("=" * 70)

    # Build shared tokenized dataset (used across Steps 1, 2 and 4).
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    shared_dataset = build_orpo_pairs(KLAUSUREN_CASES, tokenizer, max_length=64)

    # ── Step 1: Hyperparameter Sweep ─────────────────────────────────────────
    if 1 not in skip:
        best_cfg = run_hparam_sweep(args.model, shared_dataset, steps=args.sweep_steps)
    else:
        print("\n[SKIP] Step 1: Using default config")
        best_cfg = HParamConfig(
            lr=1e-4, population_size=8, rank_min=8, rank_max=32, explained_variance=0.95
        )

    # ── Step 2: Adapter Library ───────────────────────────────────────────────
    if 2 not in skip:
        adapter_library = build_adapter_library(
            args.model, best_cfg, shared_dataset, steps=args.adapter_steps
        )
    else:
        print("\n[SKIP] Step 2: Loading existing library")
        lib_path = os.path.join(OUTPUT_DIR, "step2_adapter_library.pt")
        adapter_library = (
            ParasiticQLoRAExtractor.load_library(lib_path)
            if os.path.exists(lib_path)
            else []
        )

    # ── Step 3: MoE Gate Pre-training ────────────────────────────────────────
    if 3 not in skip:
        # Use the model's hidden size as gate input width when available.
        try:
            from transformers import AutoConfig

            config = AutoConfig.from_pretrained(args.model)
            in_features = getattr(config, "hidden_size", 512)
        except Exception as err:  # best effort: any config problem -> default
            print(f" [WARN] Could not read model config ({err}); using in_features=512.")
            in_features = 512
        pretrain_moe_gate(adapter_library, in_features=in_features, gate_steps=200)
    else:
        print("\n[SKIP] Step 3: Gate pre-training skipped")

    # ── Step 4: FCES Warm-Start ───────────────────────────────────────────────
    if 4 not in skip:
        build_fces_warm_start(
            args.model, best_cfg, shared_dataset, warm_steps=args.warm_steps
        )
    else:
        print("\n[SKIP] Step 4: Warm-start skipped")

    # ── Step 5: Full Dataset Tokenization ─────────────────────────────────────
    if 5 not in skip:
        build_full_dataset(args.model)
    else:
        print("\n[SKIP] Step 5: Dataset tokenization skipped")

    # ── Summary ───────────────────────────────────────────────────────────────
    total_time = time.perf_counter() - t_total
    output_files = sorted(f for f in os.listdir(OUTPUT_DIR) if not f.startswith("."))
    # Sizes of regular files only (directories such as the checkpoint are skipped).
    file_sizes = {
        fname: os.path.getsize(os.path.join(OUTPUT_DIR, fname))
        for fname in output_files
        if os.path.isfile(os.path.join(OUTPUT_DIR, fname))
    }
    total_size_mb = sum(file_sizes.values()) / 1024 / 1024

    summary = {
        "model": args.model,
        "device": DEVICE,
        "total_wall_time_s": round(total_time, 1),
        "output_files": output_files,
        "total_size_mb": round(total_size_mb, 2),
        "runpod_savings_hint": (
            "Upload /output/ as Kaggle Dataset. "
            "RunPod: load step1_best_config.json for hparams, "
            "step2_adapter_library.pt for warm adapter library, "
            "step3_gate_weights.pt for MoE router, "
            "step4_warm_checkpoint/ as starting model weights, "
            "step5_dataset.pt for pre-tokenized training data. "
            "Estimated RunPod time savings: 60-80% vs cold start."
        ),
    }
    summary_path = os.path.join(OUTPUT_DIR, "prep_summary.json")
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    print("\n" + "=" * 70)
    print(" KAGGLE PREP COMPLETE")
    print("=" * 70)
    print(f" Total time : {total_time:.1f}s")
    print(f" Artifacts : {len(output_files)} files, {total_size_mb:.1f} MB")
    print(f" Output dir : {OUTPUT_DIR}")
    print("\n RunPod Ladeanweisung:")
    print(" kaggle datasets download sven/fces-prep -p /data/")
    for fname in sorted(file_sizes):
        print(f" /data/{fname:<35} ({file_sizes[fname] / 1024:.0f} KB)")
    print("=" * 70)


if __name__ == "__main__":
    main()