feat: kaggle preparation pipeline - 5-step RunPod cost optimization
This commit is contained in:
875
kaggle_prep.py
Normal file
875
kaggle_prep.py
Normal file
@@ -0,0 +1,875 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Kaggle preparation pipeline for FCES-native.

Runs all five preparation steps on a Kaggle T4/P100 before any expensive
RunPod A100 time is used. All artifacts are written to
/kaggle/working/output/ (or to ./kaggle_output/ next to this script when
/kaggle does not exist) and can be published as a Kaggle Dataset.

Running on Kaggle:
    !git clone https://git.zky.de/sven/FCES-native
    !cd FCES-native && pip install -r requirements.txt
    !cd FCES-native && python python/setup.py build_ext --inplace
    !cd FCES-native && python kaggle_prep.py

Produced artifacts (relative to the output directory):
    step1_best_config.json      <- best hyperparameters
    step2_adapter_library.pt    <- pre-built expert adapter library
    step3_gate_weights.pt       <- pre-trained MoE gate weights
    step4_warm_checkpoint/      <- FCES warm-start checkpoint
    step4_warm_metadata.json    <- warm-start metadata
    step5_dataset.pt            <- tokenized ORPO dataset
    prep_summary.json           <- summary of all steps
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Put the python/ directory that sits next to this script at the front of
# sys.path so the project-local imports below resolve.
_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(_root, "python"))
|
||||
|
||||
from parasitic_qlora import ExpertAdapter, ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402
|
||||
from expert_manifold_alignment import ExpertManifoldAligner # noqa: E402
|
||||
from adapter_moe_router import LearnableGate # noqa: E402
|
||||
|
||||
# Artifacts go to /kaggle/working/output when a /kaggle directory exists,
# otherwise to a kaggle_output/ directory next to this script.
OUTPUT_DIR = (
    "/kaggle/working/output"
    if os.path.exists("/kaggle")
    else os.path.join(_root, "kaggle_output")
)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Use the GPU whenever torch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[Kaggle Prep] Device: {DEVICE} | Output: {OUTPUT_DIR}")
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# LEGAL KLAUSUREN CASES (expanded for Kaggle prep)
|
||||
# ==============================================================================
|
||||
|
||||
# German civil-law exam cases used as training material.
#
# Every entry is a flat str -> str mapping with these keys:
#   id          - short unique case identifier
#   description - the facts of the case
#   statutes    - comma-separated list of the relevant statutes
#   correct     - the preferred analysis (the "win" side of an ORPO pair)
#   flawed      - the dispreferred analysis (the "lose" side of an ORPO pair)
#   domain      - one of "statute_recall", "logic_reasoning",
#                 "style_gutachtenstil"
KLAUSUREN_CASES: List[Dict[str, str]] = [
    # --- Tenancy law (Mietrecht) ---
    {
        "id": "miet_01",
        "description": (
            "Mieter M besichtigt im Januar eine Wohnung des Vermieters V. "
            "Schimmelfleck im Schlafzimmer. V sagt er wird es 'irgendwann mal wegmachen'. "
            "M unterschreibt Mietvertrag vorbehaltlos und zieht am 1. Februar ein. "
            "Im März mindert M die Miete selbstständig um 20% wegen des Schimmels. "
            "V verlangt volle Miete."
        ),
        "statutes": "§ 535 Abs. 2 BGB, § 536 Abs. 1 BGB, § 536b BGB",
        "correct": (
            "M hat gemäß § 536b BGB das Minderungsrecht verloren, da er bei "
            "Vertragsschluss positive Kenntnis vom Schimmelfleck hatte und "
            "vorbehaltlos unterschrieben hat. V kann volle Miete verlangen."
        ),
        "flawed": (
            "M kann die Miete mindern, weil Schimmel ein Sachmangel ist. "
            "§ 536b BGB ist irrelevant, da V eine Zusage gemacht hat."
        ),
        "domain": "statute_recall",
    },
    {
        "id": "miet_02",
        "description": (
            "Toilettenspülung bei F fällt am Freitagabend aus, Wasser läuft "
            "ununterbrochen und droht überzulaufen. V ist auf Segeltrip und nicht "
            "erreichbar. Kein Hausmeister. F beauftragt am Samstagmorgen den "
            "Notdienst K, bezahlt 300 € und fordert Erstattung von V. V weigert sich."
        ),
        "statutes": "§ 536a Abs. 2 Nr. 2 BGB, § 535 BGB, § 683 BGB",
        "correct": (
            "F hat nach § 536a Abs. 2 Nr. 2 BGB das Selbstbeseitigungsrecht bei "
            "Gefahr im Verzug. V war nicht erreichbar, der Schaden drohend. "
            "F kann 300 € Erstattung verlangen."
        ),
        "flawed": (
            "F hätte länger auf V warten müssen. Das Selbstbeseitigungsrecht "
            "besteht nur bei völliger Unerreichbarkeit über mehrere Tage."
        ),
        "domain": "logic_reasoning",
    },
    {
        "id": "miet_03",
        "description": (
            "Vermieter V kündigt dem Mieter M ordentlich wegen Eigenbedarfs (§ 573 BGB). "
            "M ist 78 Jahre alt und schwer krank. Er beruft sich auf § 574 BGB "
            "(Sozialklausel). V bestreitet die Härte."
        ),
        "statutes": "§ 573 Abs. 2 Nr. 2 BGB, § 574 BGB, § 574a BGB",
        "correct": (
            "Die Kündigung ist formell wirksam (Eigenbedarf). Jedoch kann M gemäß "
            "§ 574 BGB Widerspruch einlegen. Das Gericht wägt ab: Alter 78, schwere "
            "Krankheit begründen in der Regel eine unzumutbare Härte. Kündigung "
            "wird voraussichtlich ausgesetzt oder Fortsetzung angeordnet."
        ),
        "flawed": (
            "Eigenbedarf schlägt immer die Sozialklausel. M muss ausziehen. "
            "§ 574 BGB gilt nur bei Vermietern ohne eigentlichen Bedarf."
        ),
        "domain": "logic_reasoning",
    },
    # --- Sales law (Kaufrecht) ---
    {
        "id": "kauf_01",
        "description": (
            "K kauft bei V einen Gebrauchtwagen für 8.000 €. Nach 3 Monaten "
            "stellt sich heraus, dass der Kilometerzähler manipuliert war. "
            "Tatsächlich 180.000 km statt 80.000 km. V beruft sich auf "
            "'gekauft wie gesehen' Klausel."
        ),
        "statutes": "§ 434 BGB, § 437 BGB, § 444 BGB, § 476 BGB",
        "correct": (
            "Die 'gekauft wie gesehen' Klausel schließt Gewährleistung nicht aus, "
            "wenn arglistige Täuschung vorliegt (§ 444 BGB). Kilometerzähler-"
            "manipulation ist arglistige Täuschung → Klausel unwirksam. "
            "K kann Rücktritt (§ 437 Nr. 2 BGB) oder Schadensersatz verlangen."
        ),
        "flawed": (
            "'Gekauft wie gesehen' schließt alle Mängel aus. K hatte Gelegenheit "
            "zur Besichtigung und muss das Fahrzeug so akzeptieren wie es ist."
        ),
        "domain": "statute_recall",
    },
    {
        "id": "kauf_02",
        "description": (
            "Unternehmerin U bestellt bei Lieferant L 500 Stühle für ein Büro. "
            "Lieferdatum: 15. März. Am 20. März noch keine Lieferung. U setzt "
            "Nachfrist bis 25. März. L liefert nicht. U kauft Stühle teurer "
            "woanders und fordert Mehrkosten von 3.000 € von L."
        ),
        "statutes": "§ 280 Abs. 1, 3 BGB, § 281 BGB, § 323 BGB, § 286 BGB",
        "correct": (
            "L ist durch Überschreiten des Liefertermins in Verzug (§ 286 BGB). "
            "Nach erfolglosem Ablauf der Nachfrist kann U nach § 281 BGB "
            "Schadensersatz statt der Leistung verlangen. Die 3.000 € "
            "Mehrkosten sind als Deckungsschaden erstattungsfähig."
        ),
        "flawed": (
            "U muss zunächst Klage erheben und kann erst nach rechtskräftigem Urteil "
            "Ersatz verlangen. Die Nachfrist allein reicht nicht aus."
        ),
        "domain": "style_gutachtenstil",
    },
    # --- Tort law (Deliktsrecht) ---
    {
        "id": "delikt_01",
        "description": (
            "A fährt nachts betrunken (1,5 Promille) mit seinem Auto und fährt "
            "das geparkte Auto des B zu. Schaden 5.000 €. A hat keine "
            "Kfz-Haftpflichtversicherung. B verlangt Schadensersatz."
        ),
        "statutes": "§ 823 Abs. 1 BGB, § 827 BGB, § 828 BGB, § 7 StVG",
        "correct": (
            "A haftet nach § 823 Abs. 1 BGB (Eigentumsverletzung). Trunkenheit "
            "schließt die Deliktsfähigkeit nicht aus, da sie selbst herbeigeführt "
            "wurde (actio libera in causa). Zudem haftet A nach § 7 StVG als "
            "Halter/Fahrer. Vollständiger Schadensersatz 5.000 €."
        ),
        "flawed": (
            "A ist wegen Trunkenheit schuldunfähig (§ 827 BGB) und haftet nicht. "
            "B muss seinen Schaden selbst tragen oder die eigene Versicherung "
            "bemühen."
        ),
        "domain": "logic_reasoning",
    },
]
|
||||
|
||||
|
||||
# ORPO Preference Pairs: (preferred=correct, dispreferred=flawed)
|
||||
def build_orpo_pairs(
    cases: List[Dict[str, str]], tokenizer: Any, max_length: int = 128
) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Turn exam cases into tokenized ORPO preference pairs.

    For every case a shared prompt is built from its "description" and
    "statutes" and then completed once with the "correct" answer (preferred)
    and once with the "flawed" answer (dispreferred).

    Args:
        cases: Case dicts providing the keys "description", "statutes",
            "correct" and "flawed".
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            max_length=..., padding="max_length", return_tensors="pt")`` whose
            result is indexed with "input_ids" and "attention_mask".
        max_length: Length every sequence is truncated / padded to.

    Returns:
        One ``(input_win, labels_win, input_lose, labels_lose)`` tuple per
        case. Each labels tensor is a copy of the matching input ids with
        every position whose attention mask is 0 replaced by -100.
    """

    def _encode(text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        # Tokenize one text and derive labels that mask out the padding.
        encoded = tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"]
        targets = token_ids.clone()
        targets[encoded["attention_mask"] == 0] = -100
        return token_ids, targets

    pairs: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = []
    for case in cases:
        prompt = (
            f"Fall: {case['description']}\n"
            f"Normen: {case['statutes']}\n"
            "Analysiere im Gutachtenstil (FIRT):\n"
        )
        preferred_text = prompt + case["correct"]
        rejected_text = prompt + case["flawed"]

        win_ids, win_labels = _encode(preferred_text)
        lose_ids, lose_labels = _encode(rejected_text)
        pairs.append((win_ids, win_labels, lose_ids, lose_labels))
    return pairs
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# STEP 1: HYPERPARAMETER SWEEP
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@dataclass
class HParamConfig:
    """One candidate hyperparameter combination for the sweep and later steps."""

    # Learning rate handed to AdamW and to the FCES config.
    lr: float
    # Population size handed to the FCES config.
    population_size: int
    # Lower / upper rank bounds forwarded to QLoRAConfig (min_rank / max_rank).
    rank_min: int
    rank_max: int
    # Forwarded to QLoRAConfig as explained_variance_threshold.
    explained_variance: float
|
||||
|
||||
|
||||
def run_hparam_sweep(
    model_name: str,
    dataset: List[Tuple[torch.Tensor, ...]],
    steps: int = 15,
) -> HParamConfig:
    """Try a fixed list of hyperparameter candidates and return the best one.

    Each candidate fine-tunes a freshly loaded copy of ``model_name`` with
    AdamW for ``steps`` steps on the preferred ("win") side of the dataset,
    extracts one adapter and is scored as
    ``final_loss * (1 - avg_explained_variance)`` (lower is better). All
    results plus the winner are written to ``step1_best_config.json`` in
    ``OUTPUT_DIR``.

    Args:
        model_name: Identifier passed to ``AutoModelForCausalLM.from_pretrained``.
        dataset: Tuples of ``(input_win, labels_win, input_lose, labels_lose)``
            tensors; batches are taken round-robin.
        steps: Number of optimizer steps per candidate.

    Returns:
        The candidate with the lowest score.

    Raises:
        ValueError: If ``dataset`` is empty or ``steps`` is smaller than 1.
        RuntimeError: If no candidate produced a comparable score (for
            example because every loss was NaN).
    """
    # Validate up front: otherwise an empty dataset fails with a
    # ZeroDivisionError and steps < 1 with an IndexError deep inside the loop.
    if not dataset:
        raise ValueError("run_hparam_sweep needs a non-empty dataset")
    if steps < 1:
        raise ValueError("run_hparam_sweep needs steps >= 1")

    from transformers import AutoModelForCausalLM

    configs: List[HParamConfig] = [
        HParamConfig(
            lr=1e-4, population_size=8, rank_min=8, rank_max=32, explained_variance=0.95
        ),
        HParamConfig(
            lr=5e-5,
            population_size=16,
            rank_min=16,
            rank_max=64,
            explained_variance=0.95,
        ),
        HParamConfig(
            lr=2e-4, population_size=4, rank_min=4, rank_max=16, explained_variance=0.90
        ),
        HParamConfig(
            lr=1e-4, population_size=8, rank_min=8, rank_max=64, explained_variance=0.97
        ),
        HParamConfig(
            lr=8e-5,
            population_size=12,
            rank_min=8,
            rank_max=32,
            explained_variance=0.95,
        ),
    ]

    best_config: Optional[HParamConfig] = None
    best_score = float("inf")
    results = []

    print("\n" + "=" * 60)
    print("STEP 1: HYPERPARAMETER SWEEP")
    print("=" * 60)

    for i, cfg in enumerate(configs):
        print(
            f"\n[Sweep {i+1}/{len(configs)}] lr={cfg.lr}, pop={cfg.population_size}, "
            f"rank=[{cfg.rank_min},{cfg.rank_max}]"
        )

        # Every candidate starts from the same pretrained weights.
        model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
        optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

        extractor = ParasiticQLoRAExtractor(
            QLoRAConfig(
                min_rank=cfg.rank_min,
                max_rank=cfg.rank_max,
                explained_variance_threshold=cfg.explained_variance,
                extraction_interval=steps,
                interesting_point_detection=False,
            )
        )
        extractor.snapshot_base(model)

        model.train()
        losses = []
        t0 = time.perf_counter()

        for step in range(steps):
            batch = dataset[step % len(dataset)]
            # Only the preferred side is trained on here.
            input_win, labels_win, _input_lose, _labels_lose = [
                t.to(DEVICE) for t in batch
            ]

            optimizer.zero_grad()
            outputs = model(input_win, labels=labels_win)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            losses.append(float(loss.item()))

        wall_time = time.perf_counter() - t0

        # Extract one adapter to measure quality.
        adapter = extractor.extract_adapters(
            model, steps, {"loss": losses[-1], "optimizer": "AdamW"}
        )
        # Coerce to a plain float so the value is JSON-serializable below.
        avg_ev = float(adapter.avg_explained_variance)
        final_loss = losses[-1]
        score = final_loss * (1.0 - avg_ev)  # lower is better

        print(
            f" -> final_loss={final_loss:.4f} | avg_ev={avg_ev:.3f} | "
            f"score={score:.4f} | time={wall_time:.1f}s"
        )
        results.append(
            {
                "config": asdict(cfg),
                "score": score,
                "final_loss": final_loss,
                "avg_ev": avg_ev,
            }
        )

        # A NaN score never compares as smaller, so diverged runs are ignored.
        if score < best_score:
            best_score = score
            best_config = cfg

        # The optimizer keeps references to the parameters (and the extractor
        # and adapter may hold tensors too), so drop all of them before
        # emptying the cache; deleting only the model would keep its memory
        # alive while the next candidate is loaded.
        del model, optimizer, extractor, adapter
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    if best_config is None:
        # Fail with a clear message before touching the output file instead
        # of crashing inside asdict(None) after the file was truncated.
        raise RuntimeError(
            "Hyperparameter sweep produced no comparable score "
            "(all candidate losses were NaN?)"
        )

    out_path = os.path.join(OUTPUT_DIR, "step1_best_config.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump({"best": asdict(best_config), "all_results": results}, f, indent=2)

    print(f"\n✅ Step 1 done. Best: {best_config} → saved to {out_path}")
    return best_config
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# STEP 2: BUILD ADAPTER LIBRARY ON SMALL MODELS
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def build_adapter_library(
    model_name: str,
    best_cfg: HParamConfig,
    dataset: List[Tuple[torch.Tensor, ...]],
    steps: int = 50,
) -> List[ExpertAdapter]:
    """Fine-tune ``model_name`` with AdamW and collect adapters along the way.

    The model is trained for ``steps`` steps on the preferred ("win") side of
    the dataset. Whenever the extractor's ``should_extract`` says so, an
    adapter is extracted, tagged and profiled by the manifold aligner. At the
    end the model is saved to ``step4_warm_checkpoint`` and the adapter
    library to ``step2_adapter_library.pt``, both inside ``OUTPUT_DIR``.

    Args:
        model_name: Identifier passed to ``AutoModelForCausalLM.from_pretrained``.
        best_cfg: Hyperparameters supplying the learning rate and the
            extractor's rank bounds / explained-variance threshold.
        dataset: Tuples of ``(input_win, labels_win, input_lose, labels_lose)``
            tensors; batches are taken round-robin.
        steps: Number of training steps.

    Returns:
        A list copy of the extractor's adapter library.
    """
    from transformers import AutoModelForCausalLM

    banner = "=" * 60
    print("\n" + banner)
    print("STEP 2: BUILD ADAPTER LIBRARY")
    print(banner)

    model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
    aligner = ExpertManifoldAligner(model)

    qlora_config = QLoRAConfig(
        min_rank=best_cfg.rank_min,
        max_rank=best_cfg.rank_max,
        explained_variance_threshold=best_cfg.explained_variance,
        # Aim for roughly ten extraction points, but never more often than
        # every five steps.
        extraction_interval=max(5, steps // 10),
        interesting_point_detection=True,
    )
    extractor = ParasiticQLoRAExtractor(qlora_config)
    extractor.snapshot_base(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=best_cfg.lr)
    model.train()

    print(
        f"Training {model_name} for {steps} steps, extracting adapters every "
        f"{extractor.config.extraction_interval} steps..."
    )

    for step in range(steps):
        batch = dataset[step % len(dataset)]
        # All four tensors are moved to the device; only the win side is used.
        input_win, labels_win, _input_lose, _labels_lose = (
            t.to(DEVICE) for t in batch
        )

        optimizer.zero_grad()
        loss = model(input_win, labels=labels_win).loss
        loss.backward()
        optimizer.step()

        aligner.track_step(model)

        loss_value = float(loss.item())
        if extractor.should_extract(step, loss_value):
            metrics: Dict[str, Any] = {"loss": loss_value, "optimizer": "AdamW"}
            adapter = extractor.extract_adapters(model, step, metrics)
            aligner.tag_adapter(adapter)
            profile = aligner.profile_adapter(adapter)
            print(
                f" Step {step:3d} | loss={loss_value:.4f} | "
                f"tags={adapter.domain_tags} | "
                f"statute={profile['statute_recall']:.2f} "
                f"logic={profile['logic_reasoning']:.2f} "
                f"style={profile['style_gutachtenstil']:.2f}"
            )

        if step % 10 == 0:
            print(f" Step {step}/{steps} | loss={loss_value:.4f}")

    # Persist the trained weights under the warm-start checkpoint name.
    ckpt_dir = os.path.join(OUTPUT_DIR, "step4_warm_checkpoint")
    os.makedirs(ckpt_dir, exist_ok=True)
    model.save_pretrained(ckpt_dir)
    print(f" Checkpoint saved to {ckpt_dir}")

    # Persist the collected adapters.
    lib_path = os.path.join(OUTPUT_DIR, "step2_adapter_library.pt")
    extractor.save_library(lib_path)
    print(
        f"\n✅ Step 2 done. Library: {len(extractor.adapter_library)} adapters → {lib_path}"
    )

    del model
    if DEVICE == "cuda":
        torch.cuda.empty_cache()

    return list(extractor.adapter_library)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# STEP 3: PRETRAIN MOE GATE
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def pretrain_moe_gate(
    adapter_library: List[ExpertAdapter],
    in_features: int = 512,
    gate_steps: int = 200,
) -> None:
    """Pre-train a ``LearnableGate`` on synthetic domain-routing targets.

    Every adapter is mapped to one of three domains based on its
    ``domain_tags`` (first match in the order statute_recall,
    logic_reasoning, style_gutachtenstil; 0 if none matches). Each training
    step visits the adapters round-robin, draws a batch of random inputs whose
    domain-specific third of the feature vector is shifted by +2.0, and pulls
    the gate's softmax output towards a distribution that is uniform over all
    adapters of that domain (KL divergence). The trained state dict is saved
    to ``step3_gate_weights.pt`` in ``OUTPUT_DIR``.

    Args:
        adapter_library: Adapters to route between; an empty list skips the
            step without writing anything.
        in_features: Size of the gate's input vector.
        gate_steps: Number of optimizer steps.
    """
    print("\n" + "=" * 60)
    print("STEP 3: PRETRAIN MoE GATE")
    print("=" * 60)

    num_adapters = len(adapter_library)
    if num_adapters == 0:
        print(" No adapters in library, skipping.")
        return

    gate = LearnableGate(in_features, num_adapters).to(DEVICE)
    gate_optimizer = torch.optim.AdamW(gate.parameters(), lr=1e-3)

    # Synthetic supervision: the first known domain tag found on an adapter
    # decides which input pattern should be routed to it (default: domain 0).
    domain_order = ("statute_recall", "logic_reasoning", "style_gutachtenstil")
    adapter_targets = [
        next(
            (idx for idx, tag in enumerate(domain_order) if tag in adapter.domain_tags),
            0,
        )
        for adapter in adapter_library
    ]

    # The target distribution depends only on the domain, so build it once
    # per domain instead of once per sample: uniform over all adapters that
    # share the domain, zero elsewhere.
    domain_targets: Dict[int, torch.Tensor] = {}
    for dom in set(adapter_targets):
        members = [i for i, t in enumerate(adapter_targets) if t == dom]
        target = torch.zeros(num_adapters, device=DEVICE)
        target[members] = 1.0 / len(members)
        domain_targets[dom] = target

    print(f" Training gate for {num_adapters} adapters, {gate_steps} steps...")
    losses = []
    batch_size = 8
    block = in_features // 3  # width of each domain's "signature" slice

    for step in range(gate_steps):
        # Adapters are visited round-robin (not randomly); every sample of a
        # batch therefore belongs to the same domain.
        dom = adapter_targets[step % num_adapters]

        # Gaussian noise with the domain's slice of features amplified.
        x_batch = torch.randn(batch_size, in_features, device=DEVICE)
        x_batch[:, dom * block : (dom + 1) * block] += 2.0  # [B, in_features]
        y_batch = torch.stack([domain_targets[dom]] * batch_size)  # [B, num_adapters]

        gate_optimizer.zero_grad()
        logits = gate(x_batch)  # [B, num_adapters]
        loss = F.kl_div(
            F.log_softmax(logits, dim=-1),
            y_batch,
            reduction="batchmean",
        )
        loss.backward()
        gate_optimizer.step()
        losses.append(float(loss.item()))

        if step % 50 == 0:
            print(f" Gate step {step}/{gate_steps} | loss={loss.item():.4f}")

    gate_path = os.path.join(OUTPUT_DIR, "step3_gate_weights.pt")
    torch.save(gate.state_dict(), gate_path)

    # Average over the steps that actually ran; dividing by a fixed 20
    # under-reports the loss whenever fewer than 20 steps were executed.
    recent = losses[-20:]
    final_loss = sum(recent) / len(recent) if recent else float("nan")
    print(f"\n✅ Step 3 done. Gate loss={final_loss:.4f} → {gate_path}")
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# STEP 4: FCES WARM-START CHECKPOINT
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def build_fces_warm_start(
    model_name: str,
    best_cfg: HParamConfig,
    dataset: List[Tuple[torch.Tensor, ...]],
    warm_steps: int = 30,
) -> None:
    """Run ``warm_steps`` warm-up steps and save a checkpoint plus metadata.

    If the compiled ``fces_native`` module can be imported, the warm-up uses
    ``FCESOptimizer``; otherwise it falls back to plain AdamW. Because the
    FCES population itself is not serialized here, the following is stored
    instead:

    1. The model weights as they are after the last warm-up step
       (``step4_warm_checkpoint`` in ``OUTPUT_DIR``).
    2. Per-step metadata of the active controller (id, fitness, loss) when
       FCES was used.
    3. The hyperparameters and a resume hint for RunPod, as JSON
       (``step4_warm_metadata.json``).

    Args:
        model_name: Identifier passed to ``AutoModelForCausalLM.from_pretrained``.
        best_cfg: Hyperparameters supplying learning rate and population size.
        dataset: Tuples of ``(input_win, labels_win, input_lose, labels_lose)``
            tensors; batches are taken round-robin, only the win side is used.
        warm_steps: Number of warm-up steps.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("STEP 4: FCES WARM-START")
    print(banner)

    # Prefer the native FCES optimizer; fall back to AdamW when the extension
    # has not been compiled.
    try:
        import fces_native
    except ImportError:
        print(" [WARN] fces_native not compiled. Using AdamW warm-start instead.")
        fces_available = False
    else:
        fces_available = True

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
    model.train()

    controller_log: List[Dict[str, Any]] = []
    best_loss = float("inf")
    best_step = 0

    def _win_batch(step_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Round-robin batch selection; the lose side is moved but unused.
        batch = dataset[step_idx % len(dataset)]
        input_win, labels_win, _, _ = [t.to(DEVICE) for t in batch]
        return input_win, labels_win

    if fces_available:
        cfg = fces_native.FCESConfig()
        cfg.lr = best_cfg.lr
        cfg.population_size = best_cfg.population_size
        cfg.total_steps = warm_steps
        optimizer_fces = fces_native.FCESOptimizer(list(model.parameters()), cfg)

        print(f" Running FCES warm-start for {warm_steps} steps...")
        for step in range(warm_steps):
            input_win, labels_win = _win_batch(step)

            optimizer_fces.zero_grad()
            loss = model(input_win, labels=labels_win).loss
            loss.backward()
            optimizer_fces.step()
            loss_value = float(loss.item())
            optimizer_fces.update_fitness(loss_value)

            ctrl_id = optimizer_fces.get_active_controller_id()
            ctrl_fit = optimizer_fces.get_active_controller_fitness()
            controller_log.append(
                {
                    "step": step,
                    "loss": loss_value,
                    "controller_id": int(ctrl_id),
                    "controller_fitness": float(ctrl_fit),
                }
            )

            if loss_value < best_loss:
                best_loss, best_step = loss_value, step

            if step % 10 == 0:
                print(
                    f" Step {step:3d} | loss={loss_value:.4f} | "
                    f"ctrl_id={ctrl_id} | fitness={ctrl_fit:.4f}"
                )
    else:
        # Fallback: plain AdamW warm-up.
        optimizer_adam = torch.optim.AdamW(model.parameters(), lr=best_cfg.lr)
        for step in range(warm_steps):
            input_win, labels_win = _win_batch(step)

            optimizer_adam.zero_grad()
            loss = model(input_win, labels=labels_win).loss
            loss.backward()
            optimizer_adam.step()
            loss_value = float(loss.item())

            if loss_value < best_loss:
                best_loss, best_step = loss_value, step
            if step % 10 == 0:
                print(f" Step {step:3d} | loss={loss_value:.4f}")

    # Write the warm checkpoint.
    ckpt_dir = os.path.join(OUTPUT_DIR, "step4_warm_checkpoint")
    os.makedirs(ckpt_dir, exist_ok=True)
    model.save_pretrained(ckpt_dir)

    # Write the warm-start metadata next to it.
    resume_hint = (
        f"Load from {ckpt_dir}. FCES already explored {warm_steps} steps. "
        f"Resume from step {warm_steps} with lr={best_cfg.lr}, "
        f"population_size={best_cfg.population_size}."
    )
    meta = {
        "model_name": model_name,
        "warm_steps": warm_steps,
        "best_loss": best_loss,
        "best_step": best_step,
        "fces_available": fces_available,
        "best_hparams": asdict(best_cfg),
        "controller_log": controller_log,
        "runpod_hint": resume_hint,
    }
    meta_path = os.path.join(OUTPUT_DIR, "step4_warm_metadata.json")
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    print(
        f"\n✅ Step 4 done. Warm checkpoint (best_loss={best_loss:.4f} @ step {best_step})"
    )
    print(f" Checkpoint: {ckpt_dir}")
    print(f" Metadata: {meta_path}")

    del model
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# STEP 5: TOKENIZE DATASET + ORPO PAIRS
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def build_full_dataset(model_name: str) -> None:
    """Tokenize all exam cases as ORPO pairs and save them as one ``.pt`` file.

    The pairs are built at sequence lengths 64, 128 and 256 and stored under
    the keys ``max_len_64`` / ``max_len_128`` / ``max_len_256`` with the
    leading dimension of every tensor squeezed away. A ``metadata`` entry
    records the model name, per-case information, the set of domains and the
    tokenizer's vocabulary size. The result is written to
    ``step5_dataset.pt`` in ``OUTPUT_DIR``.

    Args:
        model_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
    """
    from transformers import AutoTokenizer

    print("\n" + "=" * 60)
    print("STEP 5: TOKENIZE DATASET + BUILD ORPO PAIRS")
    print("=" * 60)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Padding to max_length needs a pad token; reuse EOS when none is set.
        tokenizer.pad_token = tokenizer.eos_token

    # Build at multiple sequence lengths for flexible training.
    payload: Dict[str, Any] = {}
    for max_len in [64, 128, 256]:
        pairs = build_orpo_pairs(KLAUSUREN_CASES, tokenizer, max_length=max_len)
        payload[f"max_len_{max_len}"] = [
            (iw.squeeze(0), lw.squeeze(0), il.squeeze(0), ll.squeeze(0))
            for (iw, lw, il, ll) in pairs
        ]
        print(f" max_len={max_len}: {len(pairs)} pairs built")

    # Per-case metadata (no tensors).
    case_meta = [
        {
            "id": c["id"],
            "domain": c["domain"],
            "statutes": c["statutes"],
            "description_len": len(c["description"]),
        }
        for c in KLAUSUREN_CASES
    ]
    payload["metadata"] = {
        "model_name": model_name,
        "num_cases": len(KLAUSUREN_CASES),
        "cases": case_meta,
        # Sorted so the saved artifact is identical across runs; iterating a
        # set of strings directly depends on per-process hash randomization.
        "domains": sorted({c["domain"] for c in KLAUSUREN_CASES}),
        "vocab_size": tokenizer.vocab_size,
    }

    out_path = os.path.join(OUTPUT_DIR, "step5_dataset.pt")
    torch.save(payload, out_path)
    size_mb = os.path.getsize(out_path) / 1024 / 1024
    print(f"\n✅ Step 5 done. Dataset ({size_mb:.2f} MB) → {out_path}")
    print(f" Domains: {payload['metadata']['domains']}")
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# MAIN PIPELINE
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def main() -> None:
    """Command-line entry point: run the five preparation steps in order.

    Parses ``--model``, ``--sweep-steps``, ``--adapter-steps``,
    ``--warm-steps`` and ``--skip-steps``, builds one shared tokenized
    dataset (max length 64), executes every step that is not skipped and
    finally writes ``prep_summary.json`` and prints an artifact overview.
    Artifact sizes include the contents of directories such as the warm
    checkpoint.
    """
    import argparse

    from transformers import AutoTokenizer

    parser = argparse.ArgumentParser(description="Kaggle Preparation Pipeline")
    parser.add_argument("--model", type=str, default="EleutherAI/pythia-70m")
    parser.add_argument(
        "--sweep-steps", type=int, default=15, help="Steps per hyperparameter sweep run"
    )
    parser.add_argument(
        "--adapter-steps", type=int, default=50, help="Steps for adapter library build"
    )
    parser.add_argument(
        "--warm-steps", type=int, default=30, help="Steps for FCES warm-start"
    )
    parser.add_argument(
        "--skip-steps",
        type=str,
        default="",
        help="Comma-separated step numbers to skip, e.g. '1,3'",
    )
    args = parser.parse_args()

    skip = {int(s) for s in args.skip_steps.split(",") if s.strip()}
    t_total = time.perf_counter()

    print("\n" + "=" * 70)
    print(" KAGGLE PREPARATION PIPELINE — FCES-native")
    print(f" Model: {args.model} | Device: {DEVICE}")
    print("=" * 70)

    # Build the shared tokenized dataset used by steps 1, 2 and 4.
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    shared_dataset = build_orpo_pairs(KLAUSUREN_CASES, tokenizer, max_length=64)

    # -- Step 1: hyperparameter sweep ---------------------------------------
    if 1 not in skip:
        best_cfg = run_hparam_sweep(args.model, shared_dataset, steps=args.sweep_steps)
    else:
        print("\n[SKIP] Step 1: Using default config")
        best_cfg = HParamConfig(
            lr=1e-4, population_size=8, rank_min=8, rank_max=32, explained_variance=0.95
        )

    # -- Step 2: adapter library --------------------------------------------
    if 2 not in skip:
        adapter_library = build_adapter_library(
            args.model, best_cfg, shared_dataset, steps=args.adapter_steps
        )
    else:
        print("\n[SKIP] Step 2: Loading existing library")
        lib_path = os.path.join(OUTPUT_DIR, "step2_adapter_library.pt")
        adapter_library = (
            ParasiticQLoRAExtractor.load_library(lib_path)
            if os.path.exists(lib_path)
            else []
        )

    # -- Step 3: MoE gate pre-training --------------------------------------
    if 3 not in skip:
        # Use the model's hidden size as gate input width; fall back to 512.
        try:
            from transformers import AutoConfig

            config = AutoConfig.from_pretrained(args.model)
            in_features = getattr(config, "hidden_size", 512)
        except Exception as exc:
            # Best effort: keep going with the default, but say so instead of
            # hiding the failure.
            print(f" [WARN] Could not read model config ({exc}); using in_features=512")
            in_features = 512
        pretrain_moe_gate(adapter_library, in_features=in_features, gate_steps=200)
    else:
        print("\n[SKIP] Step 3: Gate pre-training skipped")

    # -- Step 4: FCES warm-start --------------------------------------------
    if 4 not in skip:
        build_fces_warm_start(
            args.model, best_cfg, shared_dataset, warm_steps=args.warm_steps
        )
    else:
        print("\n[SKIP] Step 4: Warm-start skipped")

    # -- Step 5: full dataset tokenization ----------------------------------
    if 5 not in skip:
        build_full_dataset(args.model)
    else:
        print("\n[SKIP] Step 5: Dataset tokenization skipped")

    # -- Summary --------------------------------------------------------------
    total_time = time.perf_counter() - t_total
    output_files = [f for f in os.listdir(OUTPUT_DIR) if not f.startswith(".")]
    # Measure every artifact once; directories (e.g. the warm checkpoint) are
    # summed recursively so the model weights are not left out of the total.
    artifact_sizes = {
        name: _path_size_bytes(os.path.join(OUTPUT_DIR, name)) for name in output_files
    }
    total_size_mb = sum(artifact_sizes.values()) / 1024 / 1024

    summary = {
        "model": args.model,
        "device": DEVICE,
        "total_wall_time_s": round(total_time, 1),
        "output_files": output_files,
        "total_size_mb": round(total_size_mb, 2),
        "runpod_savings_hint": (
            "Upload /output/ as Kaggle Dataset. "
            "RunPod: load step1_best_config.json for hparams, "
            "step2_adapter_library.pt for warm adapter library, "
            "step3_gate_weights.pt for MoE router, "
            "step4_warm_checkpoint/ as starting model weights, "
            "step5_dataset.pt for pre-tokenized training data. "
            "Estimated RunPod time savings: 60-80% vs cold start."
        ),
    }

    summary_path = os.path.join(OUTPUT_DIR, "prep_summary.json")
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    print("\n" + "=" * 70)
    print(" KAGGLE PREP COMPLETE")
    print("=" * 70)
    print(f" Total time : {total_time:.1f}s")
    print(f" Artifacts : {len(output_files)} files, {total_size_mb:.1f} MB")
    print(f" Output dir : {OUTPUT_DIR}")
    print("\n RunPod Ladeanweisung:")
    print(" kaggle datasets download sven/fces-prep -p /data/")
    for fname in sorted(output_files):
        # Mark directories with a trailing slash in the listing.
        is_dir = os.path.isdir(os.path.join(OUTPUT_DIR, fname))
        label = fname + "/" if is_dir else fname
        size = artifact_sizes[fname] / 1024
        print(f" /data/{label:<35} ({size:.0f} KB)")
    print("=" * 70)


def _path_size_bytes(path: str) -> int:
    """Return the size of a file, or the summed file sizes below a directory."""
    if os.path.isfile(path):
        return os.path.getsize(path)
    total = 0
    for dirpath, _dirnames, filenames in os.walk(path):
        for filename in filenames:
            file_path = os.path.join(dirpath, filename)
            # Skip entries that are not regular files (e.g. broken symlinks).
            if os.path.isfile(file_path):
                total += os.path.getsize(file_path)
    return total
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user