# -*- coding: utf-8 -*-
"""Kaggle Preparation Pipeline for FCES-native.

Runs all 5 preparation steps on a Kaggle T4/P100 before expensive RunPod
A100 time is used. All artifacts are stored under /kaggle/working/ and can
be published as a Kaggle Dataset.

Running on Kaggle:
    !git clone https://git.zky.de/sven/FCES-native
    !cd FCES-native && pip install -r requirements.txt
    !cd FCES-native && python python/setup.py build_ext --inplace
    !cd FCES-native && python kaggle_prep.py

Produced artifacts:
    /output/step1_best_config.json    ← best hyperparameters
    /output/step2_adapter_library.pt  ← pre-built expert adapter library
    /output/step3_gate_weights.pt     ← pre-trained MoE gate weights
    /output/step4_warm_checkpoint/    ← FCES warm-start checkpoint
    /output/step5_dataset.pt          ← tokenized ORPO dataset
    /output/prep_summary.json         ← summary of all steps
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from dataclasses import asdict, dataclass
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
# Ensure python/ is importable
|
|
_root = os.path.dirname(os.path.abspath(__file__))
|
|
sys.path.insert(0, os.path.join(_root, "python"))
|
|
|
|
from parasitic_qlora import ExpertAdapter, ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402
|
|
from expert_manifold_alignment import ExpertManifoldAligner # noqa: E402
|
|
from adapter_moe_router import LearnableGate # noqa: E402
|
|
|
|
# Artifacts go to the Kaggle working directory when /kaggle exists; otherwise
# they land in a "kaggle_output" folder next to this script.
OUTPUT_DIR = (
    "/kaggle/working/output"
    if os.path.exists("/kaggle")
    else os.path.join(_root, "kaggle_output")
)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Prefer the GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"[Kaggle Prep] Device: {DEVICE} | Output: {OUTPUT_DIR}")
|
|
|
|
|
|
# ==============================================================================
|
|
# LEGAL KLAUSUREN CASES (expanded for Kaggle prep)
|
|
# ==============================================================================
|
|
|
|
# Each case provides the facts ("description"), the relevant statutes, a
# preferred ("correct") and a dispreferred ("flawed") analysis, and a domain
# tag. The German text is training data and must not be altered.
KLAUSUREN_CASES: List[Dict[str, str]] = [
    # --- Tenancy law (Mietrecht) ---
    {
        "id": "miet_01",
        "description": (
            "Mieter M besichtigt im Januar eine Wohnung des Vermieters V. "
            "Schimmelfleck im Schlafzimmer. V sagt er wird es 'irgendwann mal wegmachen'. "
            "M unterschreibt Mietvertrag vorbehaltlos und zieht am 1. Februar ein. "
            "Im März mindert M die Miete selbstständig um 20% wegen des Schimmels. "
            "V verlangt volle Miete."
        ),
        "statutes": "§ 535 Abs. 2 BGB, § 536 Abs. 1 BGB, § 536b BGB",
        "correct": (
            "M hat gemäß § 536b BGB das Minderungsrecht verloren, da er bei "
            "Vertragsschluss positive Kenntnis vom Schimmelfleck hatte und "
            "vorbehaltlos unterschrieben hat. V kann volle Miete verlangen."
        ),
        "flawed": (
            "M kann die Miete mindern, weil Schimmel ein Sachmangel ist. "
            "§ 536b BGB ist irrelevant, da V eine Zusage gemacht hat."
        ),
        "domain": "statute_recall",
    },
    {
        "id": "miet_02",
        "description": (
            "Toilettenspülung bei F fällt am Freitagabend aus, Wasser läuft "
            "ununterbrochen und droht überzulaufen. V ist auf Segeltrip und nicht "
            "erreichbar. Kein Hausmeister. F beauftragt am Samstagmorgen den "
            "Notdienst K, bezahlt 300 € und fordert Erstattung von V. V weigert sich."
        ),
        "statutes": "§ 536a Abs. 2 Nr. 2 BGB, § 535 BGB, § 683 BGB",
        "correct": (
            "F hat nach § 536a Abs. 2 Nr. 2 BGB das Selbstbeseitigungsrecht bei "
            "Gefahr im Verzug. V war nicht erreichbar, der Schaden drohend. "
            "F kann 300 € Erstattung verlangen."
        ),
        "flawed": (
            "F hätte länger auf V warten müssen. Das Selbstbeseitigungsrecht "
            "besteht nur bei völliger Unerreichbarkeit über mehrere Tage."
        ),
        "domain": "logic_reasoning",
    },
    {
        "id": "miet_03",
        "description": (
            "Vermieter V kündigt dem Mieter M ordentlich wegen Eigenbedarfs (§ 573 BGB). "
            "M ist 78 Jahre alt und schwer krank. Er beruft sich auf § 574 BGB "
            "(Sozialklausel). V bestreitet die Härte."
        ),
        "statutes": "§ 573 Abs. 2 Nr. 2 BGB, § 574 BGB, § 574a BGB",
        "correct": (
            "Die Kündigung ist formell wirksam (Eigenbedarf). Jedoch kann M gemäß "
            "§ 574 BGB Widerspruch einlegen. Das Gericht wägt ab: Alter 78, schwere "
            "Krankheit begründen in der Regel eine unzumutbare Härte. Kündigung "
            "wird voraussichtlich ausgesetzt oder Fortsetzung angeordnet."
        ),
        "flawed": (
            "Eigenbedarf schlägt immer die Sozialklausel. M muss ausziehen. "
            "§ 574 BGB gilt nur bei Vermietern ohne eigentlichen Bedarf."
        ),
        "domain": "logic_reasoning",
    },
    # --- Sales law (Kaufrecht) ---
    {
        "id": "kauf_01",
        "description": (
            "K kauft bei V einen Gebrauchtwagen für 8.000 €. Nach 3 Monaten "
            "stellt sich heraus, dass der Kilometerzähler manipuliert war. "
            "Tatsächlich 180.000 km statt 80.000 km. V beruft sich auf "
            "'gekauft wie gesehen' Klausel."
        ),
        "statutes": "§ 434 BGB, § 437 BGB, § 444 BGB, § 476 BGB",
        "correct": (
            "Die 'gekauft wie gesehen' Klausel schließt Gewährleistung nicht aus, "
            "wenn arglistige Täuschung vorliegt (§ 444 BGB). Kilometerzähler-"
            "manipulation ist arglistige Täuschung → Klausel unwirksam. "
            "K kann Rücktritt (§ 437 Nr. 2 BGB) oder Schadensersatz verlangen."
        ),
        "flawed": (
            "'Gekauft wie gesehen' schließt alle Mängel aus. K hatte Gelegenheit "
            "zur Besichtigung und muss das Fahrzeug so akzeptieren wie es ist."
        ),
        "domain": "statute_recall",
    },
    {
        "id": "kauf_02",
        "description": (
            "Unternehmerin U bestellt bei Lieferant L 500 Stühle für ein Büro. "
            "Lieferdatum: 15. März. Am 20. März noch keine Lieferung. U setzt "
            "Nachfrist bis 25. März. L liefert nicht. U kauft Stühle teurer "
            "woanders und fordert Mehrkosten von 3.000 € von L."
        ),
        "statutes": "§ 280 Abs. 1, 3 BGB, § 281 BGB, § 323 BGB, § 286 BGB",
        "correct": (
            "L ist durch Überschreiten des Liefertermins in Verzug (§ 286 BGB). "
            "Nach erfolglosem Ablauf der Nachfrist kann U nach § 281 BGB "
            "Schadensersatz statt der Leistung verlangen. Die 3.000 € "
            "Mehrkosten sind als Deckungsschaden erstattungsfähig."
        ),
        "flawed": (
            "U muss zunächst Klage erheben und kann erst nach rechtskräftigem Urteil "
            "Ersatz verlangen. Die Nachfrist allein reicht nicht aus."
        ),
        "domain": "style_gutachtenstil",
    },
    # --- Tort law (Deliktsrecht) ---
    {
        "id": "delikt_01",
        "description": (
            "A fährt nachts betrunken (1,5 Promille) mit seinem Auto und fährt "
            "das geparkte Auto des B zu. Schaden 5.000 €. A hat keine "
            "Kfz-Haftpflichtversicherung. B verlangt Schadensersatz."
        ),
        "statutes": "§ 823 Abs. 1 BGB, § 827 BGB, § 828 BGB, § 7 StVG",
        "correct": (
            "A haftet nach § 823 Abs. 1 BGB (Eigentumsverletzung). Trunkenheit "
            "schließt die Deliktsfähigkeit nicht aus, da sie selbst herbeigeführt "
            "wurde (actio libera in causa). Zudem haftet A nach § 7 StVG als "
            "Halter/Fahrer. Vollständiger Schadensersatz 5.000 €."
        ),
        "flawed": (
            "A ist wegen Trunkenheit schuldunfähig (§ 827 BGB) und haftet nicht. "
            "B muss seinen Schaden selbst tragen oder die eigene Versicherung "
            "bemühen."
        ),
        "domain": "logic_reasoning",
    },
]
|
|
|
|
|
|
# ORPO Preference Pairs: (preferred=correct, dispreferred=flawed)
|
|
def build_orpo_pairs(
    cases: List[Dict[str, str]], tokenizer: Any, max_length: int = 128
) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Turn the exam cases into ORPO preference pairs.

    For every case the shared prompt is concatenated once with the
    ``correct`` answer (preferred) and once with the ``flawed`` answer
    (dispreferred). Both texts are tokenized to exactly ``max_length``
    tokens (truncated / padded).

    Args:
        cases: Dicts with at least the keys ``description``, ``statutes``,
            ``correct`` and ``flawed``.
        tokenizer: Callable invoked with ``truncation``, ``max_length``,
            ``padding`` and ``return_tensors`` keyword arguments; its result
            must expose ``input_ids`` and ``attention_mask``.
        max_length: Target sequence length passed to the tokenizer.

    Returns:
        One tuple ``(input_win, labels_win, input_lose, labels_lose)`` per
        case. Labels are a copy of the input ids with every position whose
        attention mask is 0 replaced by -100.
    """

    def _ids_and_labels(text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        # Tokenize one text and derive labels that ignore padded positions.
        encoded = tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"]
        labels = token_ids.clone()
        labels[encoded["attention_mask"] == 0] = -100
        return token_ids, labels

    pairs: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = []
    for entry in cases:
        header = (
            f"Fall: {entry['description']}\n"
            f"Normen: {entry['statutes']}\n"
            "Analysiere im Gutachtenstil (FIRT):\n"
        )
        win_ids, win_labels = _ids_and_labels(header + entry["correct"])
        lose_ids, lose_labels = _ids_and_labels(header + entry["flawed"])
        pairs.append((win_ids, win_labels, lose_ids, lose_labels))
    return pairs
|
|
|
|
|
|
# ==============================================================================
|
|
# STEP 1: HYPERPARAMETER SWEEP
|
|
# ==============================================================================
|
|
|
|
|
|
@dataclass
class HParamConfig:
    """One hyperparameter combination evaluated by the sweep (step 1)."""

    # Learning rate handed to the optimizer.
    lr: float
    # Population size; forwarded to fces_native.FCESConfig in the warm-start step.
    population_size: int
    # Forwarded to QLoRAConfig as min_rank.
    rank_min: int
    # Forwarded to QLoRAConfig as max_rank.
    rank_max: int
    # Forwarded to QLoRAConfig as explained_variance_threshold.
    explained_variance: float
|
|
|
|
|
|
def run_hparam_sweep(
    model_name: str,
    dataset: List[Tuple[torch.Tensor, ...]],
    steps: int = 15,
) -> HParamConfig:
    """Grid-search the critical hyperparameters with short AdamW runs.

    For every candidate config a fresh model is loaded, trained for ``steps``
    steps on the preferred ("win") side of the ORPO pairs, and one adapter is
    extracted. The score is ``final_loss * (1 - avg_explained_variance)``;
    the config with the lowest score wins. All results are written to
    ``step1_best_config.json`` in ``OUTPUT_DIR``.

    Args:
        model_name: Hugging Face model id to load for every sweep run.
        dataset: ORPO pairs; only the first two tensors of each tuple
            (win input ids and win labels) are used here.
        steps: Number of optimizer steps per config (must be >= 1).

    Returns:
        The best-scoring ``HParamConfig``.

    Raises:
        ValueError: If ``steps`` < 1 or ``dataset`` is empty.
        RuntimeError: If no config produced a usable (finite) score.
    """
    # Fail fast, before any model download, instead of hitting an IndexError
    # on losses[-1] or a ZeroDivisionError on `step % len(dataset)` later.
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    if not dataset:
        raise ValueError("dataset must contain at least one ORPO pair")

    from transformers import AutoModelForCausalLM

    configs: List[HParamConfig] = [
        HParamConfig(
            lr=1e-4, population_size=8, rank_min=8, rank_max=32, explained_variance=0.95
        ),
        HParamConfig(
            lr=5e-5,
            population_size=16,
            rank_min=16,
            rank_max=64,
            explained_variance=0.95,
        ),
        HParamConfig(
            lr=2e-4, population_size=4, rank_min=4, rank_max=16, explained_variance=0.90
        ),
        HParamConfig(
            lr=1e-4, population_size=8, rank_min=8, rank_max=64, explained_variance=0.97
        ),
        HParamConfig(
            lr=8e-5,
            population_size=12,
            rank_min=8,
            rank_max=32,
            explained_variance=0.95,
        ),
    ]

    best_config: Optional[HParamConfig] = None
    best_score = float("inf")
    results: List[Dict[str, Any]] = []

    print("\n" + "=" * 60)
    print("STEP 1: HYPERPARAMETER SWEEP")
    print("=" * 60)

    for i, cfg in enumerate(configs):
        print(
            f"\n[Sweep {i+1}/{len(configs)}] lr={cfg.lr}, pop={cfg.population_size}, "
            f"rank=[{cfg.rank_min},{cfg.rank_max}]"
        )

        model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
        optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

        extractor = ParasiticQLoRAExtractor(
            QLoRAConfig(
                min_rank=cfg.rank_min,
                max_rank=cfg.rank_max,
                explained_variance_threshold=cfg.explained_variance,
                extraction_interval=steps,
                interesting_point_detection=False,
            )
        )
        extractor.snapshot_base(model)

        model.train()
        losses: List[float] = []
        t0 = time.perf_counter()

        for step in range(steps):
            batch = dataset[step % len(dataset)]
            # Only the preferred side is trained on; the rejected tensors
            # are not needed, so they are not copied to the device.
            input_win = batch[0].to(DEVICE)
            labels_win = batch[1].to(DEVICE)

            optimizer.zero_grad()
            loss = model(input_win, labels=labels_win).loss
            loss.backward()
            optimizer.step()
            losses.append(float(loss.item()))

        wall_time = time.perf_counter() - t0

        # Extract one adapter to measure quality
        final_loss = losses[-1]
        adapter = extractor.extract_adapters(
            model, steps, {"loss": final_loss, "optimizer": "AdamW"}
        )
        avg_ev = adapter.avg_explained_variance
        score = final_loss * (1.0 - avg_ev)  # lower is better

        print(
            f" -> final_loss={final_loss:.4f} | avg_ev={avg_ev:.3f} | "
            f"score={score:.4f} | time={wall_time:.1f}s"
        )
        results.append(
            {
                "config": asdict(cfg),
                "score": score,
                "final_loss": final_loss,
                "avg_ev": avg_ev,
            }
        )

        if score < best_score:
            best_score = score
            best_config = cfg

        # Drop every object that can keep the parameters alive (the optimizer
        # references them directly) so empty_cache() can really release the
        # memory before the next config loads its own model.
        del model, optimizer, extractor, adapter, loss, input_win, labels_win
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    if best_config is None:
        # Every score was NaN/inf (e.g. diverged loss); `score < best_score`
        # is then never true. Report that instead of failing in asdict(None).
        raise RuntimeError(
            "Hyperparameter sweep produced no finite score; "
            f"results: {results}"
        )

    out_path = os.path.join(OUTPUT_DIR, "step1_best_config.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump({"best": asdict(best_config), "all_results": results}, f, indent=2)

    print(f"\n✅ Step 1 done. Best: {best_config} → saved to {out_path}")
    return best_config
|
|
|
|
|
|
# ==============================================================================
|
|
# STEP 2: BUILD ADAPTER LIBRARY ON SMALL MODELS
|
|
# ==============================================================================
|
|
|
|
|
|
def build_adapter_library(
    model_name: str,
    best_cfg: HParamConfig,
    dataset: List[Tuple[torch.Tensor, ...]],
    steps: int = 50,
) -> List[ExpertAdapter]:
    """Train a small model with AdamW and collect adapters along the way.

    The model is trained on the preferred ("win") side of the ORPO pairs.
    Whenever ``extractor.should_extract`` says so, an adapter is extracted,
    tagged and profiled by the ``ExpertManifoldAligner``. Afterwards the
    trained model is saved to ``step4_warm_checkpoint/`` and the adapter
    library to ``step2_adapter_library.pt`` (both inside ``OUTPUT_DIR``).

    Args:
        model_name: Hugging Face model id to train.
        best_cfg: Hyperparameters chosen in step 1 (lr, rank bounds,
            explained-variance threshold).
        dataset: ORPO pairs; only win input ids and win labels are used.
        steps: Number of training steps.

    Returns:
        A list copy of the extractor's adapter library.

    Raises:
        ValueError: If ``steps`` > 0 but ``dataset`` is empty.
    """
    # Fail fast instead of a ZeroDivisionError on `step % len(dataset)`.
    if steps > 0 and not dataset:
        raise ValueError("dataset must contain at least one ORPO pair")

    from transformers import AutoModelForCausalLM

    print("\n" + "=" * 60)
    print("STEP 2: BUILD ADAPTER LIBRARY")
    print("=" * 60)

    model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
    aligner = ExpertManifoldAligner(model)

    extractor = ParasiticQLoRAExtractor(
        QLoRAConfig(
            min_rank=best_cfg.rank_min,
            max_rank=best_cfg.rank_max,
            explained_variance_threshold=best_cfg.explained_variance,
            extraction_interval=max(5, steps // 10),  # Extract ~10 checkpoints
            interesting_point_detection=True,
        )
    )
    extractor.snapshot_base(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=best_cfg.lr)
    model.train()

    print(
        f"Training {model_name} for {steps} steps, extracting adapters every "
        f"{extractor.config.extraction_interval} steps..."
    )

    for step in range(steps):
        batch = dataset[step % len(dataset)]
        # The rejected side of the pair is unused here, so it stays on the host.
        input_win = batch[0].to(DEVICE)
        labels_win = batch[1].to(DEVICE)

        optimizer.zero_grad()
        loss = model(input_win, labels=labels_win).loss
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call forces a device sync.
        loss_val = float(loss.item())

        aligner.track_step(model)

        if extractor.should_extract(step, loss_val):
            metrics: Dict[str, Any] = {
                "loss": loss_val,
                "optimizer": "AdamW",
            }
            adapter = extractor.extract_adapters(model, step, metrics)
            aligner.tag_adapter(adapter)
            profile = aligner.profile_adapter(adapter)
            print(
                f" Step {step:3d} | loss={loss_val:.4f} | "
                f"tags={adapter.domain_tags} | "
                f"statute={profile['statute_recall']:.2f} "
                f"logic={profile['logic_reasoning']:.2f} "
                f"style={profile['style_gutachtenstil']:.2f}"
            )

        if step % 10 == 0:
            print(f" Step {step}/{steps} | loss={loss_val:.4f}")

    # Save checkpoint for warm-start (same directory that step 4 writes to)
    ckpt_dir = os.path.join(OUTPUT_DIR, "step4_warm_checkpoint")
    os.makedirs(ckpt_dir, exist_ok=True)
    model.save_pretrained(ckpt_dir)
    print(f" Checkpoint saved to {ckpt_dir}")

    # Save adapter library
    lib_path = os.path.join(OUTPUT_DIR, "step2_adapter_library.pt")
    extractor.save_library(lib_path)
    print(
        f"\n✅ Step 2 done. Library: {len(extractor.adapter_library)} adapters → {lib_path}"
    )

    library = list(extractor.adapter_library)

    # The optimizer references the parameters, and the aligner was built from
    # the model (it may hold a reference too), so release them together with
    # the model before asking CUDA to free cached blocks.
    del model, optimizer, aligner
    if DEVICE == "cuda":
        torch.cuda.empty_cache()

    return library
|
|
|
|
|
|
# ==============================================================================
|
|
# STEP 3: PRETRAIN MOE GATE
|
|
# ==============================================================================
|
|
|
|
|
|
def pretrain_moe_gate(
    adapter_library: List[ExpertAdapter],
    in_features: int = 512,
    gate_steps: int = 200,
) -> None:
    """Pre-train the ``LearnableGate`` on synthetic domain-routing labels.

    Every adapter is assigned a domain index from its ``domain_tags``
    (statute_recall=0, logic_reasoning=1, style_gutachtenstil=2, default 0).
    Synthetic inputs are Gaussian noise with the domain's third of the
    feature vector shifted by +2.0; the target is a uniform distribution
    over all adapters sharing that domain. The gate is trained with KL
    divergence and its state dict is saved to ``step3_gate_weights.pt``.

    Args:
        adapter_library: Adapters to route between; an empty list skips
            the step.
        in_features: Input width of the gate.
        gate_steps: Number of optimizer steps.
    """
    print("\n" + "=" * 60)
    print("STEP 3: PRETRAIN MoE GATE")
    print("=" * 60)

    num_adapters = len(adapter_library)
    if num_adapters == 0:
        print(" No adapters in library, skipping.")
        return

    gate = LearnableGate(in_features, num_adapters).to(DEVICE)
    gate_optimizer = torch.optim.AdamW(gate.parameters(), lr=1e-3)

    # Domain index per adapter: first matching tag in priority order, else 0.
    domain_priority = ("statute_recall", "logic_reasoning", "style_gutachtenstil")
    adapter_targets = [
        next(
            (idx for idx, tag in enumerate(domain_priority) if tag in adapter.domain_tags),
            0,
        )
        for adapter in adapter_library
    ]

    # The routing target depends only on the domain, so build each one once:
    # uniform probability mass over all adapters that share the domain.
    domain_targets: Dict[int, torch.Tensor] = {}
    for dom in set(adapter_targets):
        members = [i for i, t in enumerate(adapter_targets) if t == dom]
        target = torch.zeros(num_adapters, device=DEVICE)
        target[members] = 1.0 / len(members)
        domain_targets[dom] = target

    print(f" Training gate for {num_adapters} adapters, {gate_steps} steps...")
    losses: List[float] = []

    batch_size = 8
    band = in_features // 3  # width of each domain's "signature" slice

    for step in range(gate_steps):
        # Pick a random adapter per sample. The original used
        # `step % num_adapters`, which made every sample in a batch share
        # one domain despite the stated intent of random selection.
        picks = torch.randint(num_adapters, (batch_size,)).tolist()

        x_batch = torch.randn(batch_size, in_features, device=DEVICE)  # [B, in_features]
        for row, adapter_idx in enumerate(picks):
            dom = adapter_targets[adapter_idx]
            # Amplify features in the domain's "direction"
            x_batch[row, dom * band : (dom + 1) * band] += 2.0

        y_batch = torch.stack(
            [domain_targets[adapter_targets[adapter_idx]] for adapter_idx in picks]
        )  # [B, num_adapters]

        gate_optimizer.zero_grad()
        logits = gate(x_batch)  # [B, num_adapters]
        loss = F.kl_div(
            F.log_softmax(logits, dim=-1),
            y_batch,
            reduction="batchmean",
        )
        loss.backward()
        gate_optimizer.step()
        loss_val = float(loss.item())
        losses.append(loss_val)

        if step % 50 == 0:
            print(f" Gate step {step}/{gate_steps} | loss={loss_val:.4f}")

    gate_path = os.path.join(OUTPUT_DIR, "step3_gate_weights.pt")
    torch.save(gate.state_dict(), gate_path)

    # Average over the steps that actually ran (up to the last 20) instead of
    # always dividing by 20, which under-reported the loss for short runs.
    tail = losses[-20:]
    final_loss = sum(tail) / len(tail) if tail else float("nan")
    print(f"\n✅ Step 3 done. Gate loss={final_loss:.4f} → {gate_path}")
|
|
|
|
|
|
# ==============================================================================
|
|
# STEP 4: FCES WARM-START CHECKPOINT
|
|
# ==============================================================================
|
|
|
|
|
|
def build_fces_warm_start(
    model_name: str,
    best_cfg: HParamConfig,
    dataset: List[Tuple[torch.Tensor, ...]],
    warm_steps: int = 30,
) -> None:
    """Run a short warm-start training and save checkpoint plus metadata.

    If the compiled ``fces_native`` extension can be imported, the model is
    trained with ``FCESOptimizer`` and the active controller id / fitness are
    logged per step; otherwise plain AdamW is used as a fallback.

    Because FCESOptimizer has no full population serialization (TODO in
    C++), the following is stored instead:
      1. Model weights after the last step (``step4_warm_checkpoint/``).
      2. Per-step controller metadata: id, fitness, step, loss.
      3. The chosen hyperparameters as JSON (``step4_warm_metadata.json``).

    Args:
        model_name: Hugging Face model id to load.
        best_cfg: Hyperparameters from step 1 (lr, population_size are used).
        dataset: ORPO pairs; only win input ids and win labels are used.
        warm_steps: Number of warm-start optimizer steps.

    Raises:
        ValueError: If ``warm_steps`` > 0 but ``dataset`` is empty.
    """
    # Fail fast instead of a ZeroDivisionError on `step % len(dataset)`.
    if warm_steps > 0 and not dataset:
        raise ValueError("dataset must contain at least one ORPO pair")

    print("\n" + "=" * 60)
    print("STEP 4: FCES WARM-START")
    print("=" * 60)

    # Try to import fces_native; fall back to AdamW warm-start if not available
    try:
        import fces_native
    except ImportError:
        print(" [WARN] fces_native not compiled. Using AdamW warm-start instead.")
        fces_native = None
    fces_available = fces_native is not None

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
    model.train()

    controller_log: List[Dict[str, Any]] = []
    best_loss = float("inf")
    best_step = 0

    if fces_available:
        cfg = fces_native.FCESConfig()
        cfg.lr = best_cfg.lr
        cfg.population_size = best_cfg.population_size
        cfg.total_steps = warm_steps
        optimizer: Any = fces_native.FCESOptimizer(list(model.parameters()), cfg)
        print(f" Running FCES warm-start for {warm_steps} steps...")
    else:
        # AdamW warm-start as fallback
        optimizer = torch.optim.AdamW(model.parameters(), lr=best_cfg.lr)

    # One shared loop for both optimizers; only the FCES bookkeeping differs.
    for step in range(warm_steps):
        batch = dataset[step % len(dataset)]
        input_win = batch[0].to(DEVICE)
        labels_win = batch[1].to(DEVICE)

        optimizer.zero_grad()
        loss = model(input_win, labels=labels_win).loss
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call forces a device sync.
        loss_val = float(loss.item())

        if fces_available:
            optimizer.update_fitness(loss_val)
            ctrl_id = optimizer.get_active_controller_id()
            ctrl_fit = optimizer.get_active_controller_fitness()
            controller_log.append(
                {
                    "step": step,
                    "loss": loss_val,
                    "controller_id": int(ctrl_id),
                    "controller_fitness": float(ctrl_fit),
                }
            )

        if loss_val < best_loss:
            best_loss = loss_val
            best_step = step

        if step % 10 == 0:
            if fces_available:
                print(
                    f" Step {step:3d} | loss={loss_val:.4f} | "
                    f"ctrl_id={ctrl_id} | fitness={ctrl_fit:.4f}"
                )
            else:
                print(f" Step {step:3d} | loss={loss_val:.4f}")

    # Save warm checkpoint
    ckpt_dir = os.path.join(OUTPUT_DIR, "step4_warm_checkpoint")
    os.makedirs(ckpt_dir, exist_ok=True)
    model.save_pretrained(ckpt_dir)

    # Name the optimizer that really ran; the text is unchanged for FCES.
    explorer = "FCES" if fces_available else "AdamW (FCES fallback)"

    # Save warm-start metadata
    meta = {
        "model_name": model_name,
        "warm_steps": warm_steps,
        # inf (no step ran / no finite loss seen) would be serialized as the
        # non-standard JSON token `Infinity`; store null instead.
        "best_loss": best_loss if best_loss != float("inf") else None,
        "best_step": best_step,
        "fces_available": fces_available,
        "best_hparams": asdict(best_cfg),
        "controller_log": controller_log,
        "runpod_hint": (
            f"Load from {ckpt_dir}. {explorer} already explored {warm_steps} steps. "
            f"Resume from step {warm_steps} with lr={best_cfg.lr}, "
            f"population_size={best_cfg.population_size}."
        ),
    }
    meta_path = os.path.join(OUTPUT_DIR, "step4_warm_metadata.json")
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    print(
        f"\n✅ Step 4 done. Warm checkpoint (best_loss={best_loss:.4f} @ step {best_step})"
    )
    print(f" Checkpoint: {ckpt_dir}")
    print(f" Metadata: {meta_path}")

    # The optimizer references the parameters; drop it with the model so
    # empty_cache() can actually return the memory.
    del model, optimizer
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
|
|
|
|
|
|
# ==============================================================================
|
|
# STEP 5: TOKENIZE DATASET + ORPO PAIRS
|
|
# ==============================================================================
|
|
|
|
|
|
def build_full_dataset(model_name: str) -> None:
    """Tokenize all exam cases as ORPO pairs and save them as one ``.pt`` file.

    The pairs are built at sequence lengths 64, 128 and 256 and stored under
    the keys ``max_len_<n>`` (each tensor squeezed to 1-D), together with a
    ``metadata`` entry describing the cases. The result is written to
    ``step5_dataset.pt`` inside ``OUTPUT_DIR``.

    Args:
        model_name: Hugging Face model id whose tokenizer is used.
    """
    from transformers import AutoTokenizer

    print("\n" + "=" * 60)
    print("STEP 5: TOKENIZE DATASET + BUILD ORPO PAIRS")
    print("=" * 60)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Padding to max_length needs a pad token; reuse EOS when none is set.
        tokenizer.pad_token = tokenizer.eos_token

    # Build at multiple sequence lengths for flexible training
    datasets: Dict[str, Any] = {}
    for max_len in [64, 128, 256]:
        pairs = build_orpo_pairs(KLAUSUREN_CASES, tokenizer, max_length=max_len)
        # Drop the leading batch dimension of size 1 from every tensor.
        datasets[f"max_len_{max_len}"] = [
            (iw.squeeze(0), lw.squeeze(0), il.squeeze(0), ll.squeeze(0))
            for (iw, lw, il, ll) in pairs
        ]
        print(f" max_len={max_len}: {len(pairs)} pairs built")

    # Save metadata
    case_meta = [
        {
            "id": c["id"],
            "domain": c["domain"],
            "statutes": c["statutes"],
            "description_len": len(c["description"]),
        }
        for c in KLAUSUREN_CASES
    ]
    datasets["metadata"] = {
        "model_name": model_name,
        "num_cases": len(KLAUSUREN_CASES),
        "cases": case_meta,
        # Sorted so the saved artifact is reproducible; plain set iteration
        # order for strings varies between interpreter runs.
        "domains": sorted({c["domain"] for c in KLAUSUREN_CASES}),
        "vocab_size": tokenizer.vocab_size,
    }

    out_path = os.path.join(OUTPUT_DIR, "step5_dataset.pt")
    torch.save(datasets, out_path)
    size_mb = os.path.getsize(out_path) / 1024 / 1024
    print(f"\n✅ Step 5 done. Dataset ({size_mb:.2f} MB) → {out_path}")
    print(f" Domains: {datasets['metadata']['domains']}")
|
|
|
|
|
|
# ==============================================================================
|
|
# MAIN PIPELINE
|
|
# ==============================================================================
|
|
|
|
|
|
def main() -> None:
    """Parse CLI arguments and run the five preparation steps in order.

    Steps listed in ``--skip-steps`` are skipped: step 1 then falls back to
    a default ``HParamConfig`` and step 2 tries to load an existing adapter
    library from ``OUTPUT_DIR``. Finally a ``prep_summary.json`` is written
    and an overview of all artifacts is printed.
    """
    import argparse

    from transformers import AutoTokenizer

    parser = argparse.ArgumentParser(description="Kaggle Preparation Pipeline")
    parser.add_argument("--model", type=str, default="EleutherAI/pythia-70m")
    parser.add_argument(
        "--sweep-steps", type=int, default=15, help="Steps per hyperparameter sweep run"
    )
    parser.add_argument(
        "--adapter-steps", type=int, default=50, help="Steps for adapter library build"
    )
    parser.add_argument(
        "--warm-steps", type=int, default=30, help="Steps for FCES warm-start"
    )
    parser.add_argument(
        "--skip-steps",
        type=str,
        default="",
        help="Comma-separated step numbers to skip, e.g. '1,3'",
    )
    args = parser.parse_args()

    try:
        skip = {int(s) for s in args.skip_steps.split(",") if s.strip()}
    except ValueError:
        # Report a usage error instead of a raw traceback.
        parser.error(
            f"--skip-steps expects comma-separated integers, got {args.skip_steps!r}"
        )
    t_total = time.perf_counter()

    print("\n" + "=" * 70)
    print(" KAGGLE PREPARATION PIPELINE — FCES-native")
    print(f" Model: {args.model} | Device: {DEVICE}")
    print("=" * 70)

    # Build shared tokenized dataset (used across all steps)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    shared_dataset = build_orpo_pairs(KLAUSUREN_CASES, tokenizer, max_length=64)

    # ── Step 1: Hyperparameter Sweep ─────────────────────────────────────────
    if 1 not in skip:
        best_cfg = run_hparam_sweep(args.model, shared_dataset, steps=args.sweep_steps)
    else:
        print("\n[SKIP] Step 1: Using default config")
        best_cfg = HParamConfig(
            lr=1e-4, population_size=8, rank_min=8, rank_max=32, explained_variance=0.95
        )

    # ── Step 2: Adapter Library ───────────────────────────────────────────────
    if 2 not in skip:
        adapter_library = build_adapter_library(
            args.model, best_cfg, shared_dataset, steps=args.adapter_steps
        )
    else:
        print("\n[SKIP] Step 2: Loading existing library")
        lib_path = os.path.join(OUTPUT_DIR, "step2_adapter_library.pt")
        adapter_library = (
            ParasiticQLoRAExtractor.load_library(lib_path)
            if os.path.exists(lib_path)
            else []
        )

    # ── Step 3: MoE Gate Pre-training ────────────────────────────────────────
    if 3 not in skip:
        # Detect in_features from model config
        try:
            from transformers import AutoConfig

            config = AutoConfig.from_pretrained(args.model)
            in_features = getattr(config, "hidden_size", 512)
        except Exception as exc:
            # Best effort: keep going with the default, but say why.
            print(f" [WARN] Could not read model config ({exc}); using in_features=512")
            in_features = 512
        pretrain_moe_gate(adapter_library, in_features=in_features, gate_steps=200)
    else:
        print("\n[SKIP] Step 3: Gate pre-training skipped")

    # ── Step 4: FCES Warm-Start ───────────────────────────────────────────────
    if 4 not in skip:
        build_fces_warm_start(
            args.model, best_cfg, shared_dataset, warm_steps=args.warm_steps
        )
    else:
        print("\n[SKIP] Step 4: Warm-start skipped")

    # ── Step 5: Full Dataset Tokenization ─────────────────────────────────────
    if 5 not in skip:
        build_full_dataset(args.model)
    else:
        print("\n[SKIP] Step 5: Dataset tokenization skipped")

    # ── Summary ───────────────────────────────────────────────────────────────
    def _artifact_size(path: str) -> int:
        # Size in bytes of a file, or of all files below a directory.
        if os.path.isfile(path):
            return os.path.getsize(path)
        return sum(
            os.path.getsize(os.path.join(dirpath, fname))
            for dirpath, _dirnames, filenames in os.walk(path)
            for fname in filenames
        )

    total_time = time.perf_counter() - t_total
    output_files = [f for f in os.listdir(OUTPUT_DIR) if not f.startswith(".")]
    # Directories (e.g. step4_warm_checkpoint/) are included in the total;
    # previously only plain files were counted.
    sizes = {f: _artifact_size(os.path.join(OUTPUT_DIR, f)) for f in output_files}
    total_size_mb = sum(sizes.values()) / 1024 / 1024

    summary = {
        "model": args.model,
        "device": DEVICE,
        "total_wall_time_s": round(total_time, 1),
        "output_files": output_files,
        "total_size_mb": round(total_size_mb, 2),
        "runpod_savings_hint": (
            "Upload /output/ as Kaggle Dataset. "
            "RunPod: load step1_best_config.json for hparams, "
            "step2_adapter_library.pt for warm adapter library, "
            "step3_gate_weights.pt for MoE router, "
            "step4_warm_checkpoint/ as starting model weights, "
            "step5_dataset.pt for pre-tokenized training data. "
            "Estimated RunPod time savings: 60-80% vs cold start."
        ),
    }

    summary_path = os.path.join(OUTPUT_DIR, "prep_summary.json")
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    print("\n" + "=" * 70)
    print(" KAGGLE PREP COMPLETE")
    print("=" * 70)
    print(f" Total time : {total_time:.1f}s")
    print(f" Artifacts : {len(output_files)} files, {total_size_mb:.1f} MB")
    print(f" Output dir : {OUTPUT_DIR}")
    print("\n RunPod Ladeanweisung:")
    print(" kaggle datasets download sven/fces-prep -p /data/")
    for fname in sorted(output_files):
        # Mark directories with a trailing slash so they are recognizable.
        label = fname + "/" if os.path.isdir(os.path.join(OUTPUT_DIR, fname)) else fname
        size = sizes[fname] / 1024
        print(f" /data/{label:<35} ({size:.0f} KB)")
    print("=" * 70)
|
|
|
|
|
|
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|