feat: add comparative FIRT-ORPO benchmark runner for AdamW vs FCES
This commit is contained in:
407
benchmark_fces_vs_adam.py
Normal file
407
benchmark_fces_vs_adam.py
Normal file
@@ -0,0 +1,407 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
# Ensure python folder is in path
|
||||
python_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "python")
|
||||
sys.path.insert(0, python_dir)
|
||||
|
||||
import torch # noqa: E402
|
||||
import fces_native # noqa: E402
|
||||
import dspy # noqa: E402
|
||||
import torch.nn.functional as F # noqa: E402
|
||||
from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402
|
||||
|
||||
# ==============================================================================
|
||||
# 1. DSPY SIGNATURE & SYSTEM DESIGN
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class LegalFIRT(dspy.Signature): # type: ignore[misc]
|
||||
"""Flaw Induced Reasoning Traces (FIRT) for German & EU Tenancy / M&A Law."""
|
||||
|
||||
case_description: str = dspy.InputField(
|
||||
desc="Facts of the German/EU legal case with potential hidden flaws or contradictions."
|
||||
)
|
||||
relevant_statutes: str = dspy.InputField(
|
||||
desc="Relevant articles from the German Civil Code (BGB) or EU Directives."
|
||||
)
|
||||
t0_draft: str = dspy.OutputField(
|
||||
desc="Initial legal assessment or draft resolution based on first-pass reasoning."
|
||||
)
|
||||
t1_flaws: str = dspy.OutputField(
|
||||
desc="Identified legal flaws, logical bottlenecks, or contradictions (e.g. § 536b BGB exclusions)."
|
||||
)
|
||||
t2_refined_opinion: str = dspy.OutputField(
|
||||
desc="Final, correct, and authoritative legal opinion addressing all identified flaws."
|
||||
)
|
||||
|
||||
|
||||
# Legal Cases for Pre- and Post-Training Benchmarking
|
||||
eval_cases: List[Dict[str, str]] = [
|
||||
{
|
||||
"description": "Mieter M besichtigt im Januar eine Wohnung des Vermieters V. Schimmelfleck im Schlafzimmer. V sagt er wird es 'irgendwann mal wegmachen'. M unterschreibt Mietvertrag vorbehaltlos und zieht am 1. Februar ein. Im März mindert M die Miete selbstständig um 20% wegen des Schimmels. V verlangt volle Miete.",
|
||||
"statutes": "§ 535 Abs. 2 BGB (Mietzahlung), § 536 Abs. 1 BGB (Mietminderung bei Sachmangel), § 536b BGB (Kenntnis des Mieters bei Vertragsschluss).",
|
||||
},
|
||||
{
|
||||
"description": "Toilettenspülung bei F fällt am Freitagabend aus, Wasser läuft ununterbrochen und droht überzulaufen. V ist auf Segeltrip und nicht erreichbar. Kein Hausmeister. F beauftragt am Samstagmorgen den Notdienst K, bezahlt 300 € und fordert Erstattung von V. V weigert sich.",
|
||||
"statutes": "§ 536a Abs. 2 Nr. 2 BGB (Selbstbeseitigungsrecht bei Gefahr im Verzug), § 535 BGB.",
|
||||
},
|
||||
]
|
||||
|
||||
# ==============================================================================
|
||||
# 2. ORPO LOSS WITH SOFT LABELS
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def compute_orpo_loss(
|
||||
logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
logits_lose: torch.Tensor,
|
||||
labels_lose: torch.Tensor,
|
||||
lambda_orpo: float = 0.1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Computes the combined Cross-Entropy (SFT) and Odds Ratio Preference (ORPO) Loss
|
||||
|
||||
using soft probability log-likelihoods.
|
||||
"""
|
||||
# 1. SFT Loss (Cross Entropy) on the preferred target sequence
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
sft_loss = F.cross_entropy(
|
||||
shift_logits.view(-1, shift_logits.size(-1)),
|
||||
shift_labels.view(-1),
|
||||
ignore_index=-100,
|
||||
)
|
||||
|
||||
# 2. Average log probabilities for preferred (win) and dispreferred (lose) sequences
|
||||
# Logits win log_softmax
|
||||
log_probs_win = F.log_softmax(logits, dim=-1)
|
||||
# Gather actual label log_probs
|
||||
loss_mask_win = labels != -100
|
||||
labels_clamped = labels.clone()
|
||||
labels_clamped[~loss_mask_win] = 0
|
||||
seq_log_probs_win = torch.gather(
|
||||
log_probs_win, dim=-1, index=labels_clamped.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
seq_log_probs_win = (seq_log_probs_win * loss_mask_win).sum(
|
||||
dim=-1
|
||||
) / loss_mask_win.sum(dim=-1).clamp(min=1)
|
||||
|
||||
# Logits lose log_softmax
|
||||
log_probs_lose = F.log_softmax(logits_lose, dim=-1)
|
||||
loss_mask_lose = labels_lose != -100
|
||||
labels_lose_clamped = labels_lose.clone()
|
||||
labels_lose_clamped[~loss_mask_lose] = 0
|
||||
seq_log_probs_lose = torch.gather(
|
||||
log_probs_lose, dim=-1, index=labels_lose_clamped.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
seq_log_probs_lose = (seq_log_probs_lose * loss_mask_lose).sum(
|
||||
dim=-1
|
||||
) / loss_mask_lose.sum(dim=-1).clamp(min=1)
|
||||
|
||||
# 3. Odds Ratio Calculation
|
||||
log_odds_win = seq_log_probs_win - torch.log1p(-torch.exp(seq_log_probs_win))
|
||||
log_odds_lose = seq_log_probs_lose - torch.log1p(-torch.exp(seq_log_probs_lose))
|
||||
|
||||
odds_ratio = log_odds_win - log_odds_lose
|
||||
orpo_loss = -F.logsigmoid(odds_ratio).mean()
|
||||
|
||||
# Total combined loss
|
||||
total_loss = sft_loss + lambda_orpo * orpo_loss
|
||||
return total_loss, sft_loss
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# 3. EVALUATION RUNNER
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def evaluate_model(
|
||||
model: AutoModelForCausalLM, tokenizer: AutoTokenizer, device: str
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Evaluates the model on legal exam cases (zero-shot) to measure T2 opinion quality."""
|
||||
model.eval()
|
||||
results = []
|
||||
for idx, case in enumerate(eval_cases):
|
||||
prompt = (
|
||||
f"Case: {case['description']}\nStatutes: {case['statutes']}\n"
|
||||
"Analyze and resolve using Flaw Induced Reasoning Traces (FIRT):\n"
|
||||
"T0 Draft:\n"
|
||||
)
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(
|
||||
**inputs, max_new_tokens=256, temperature=0.1, do_sample=False
|
||||
)
|
||||
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
results.append({"case_id": idx + 1, "output": generated_text})
|
||||
return results
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# 4. TRAINING PIPELINE (ADAMW VS FCES)
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def train_run(
|
||||
optimizer_name: str,
|
||||
model_name: str,
|
||||
steps: int,
|
||||
device: str,
|
||||
dataset: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
|
||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, str]], List[Dict[str, str]]]:
|
||||
"""Runs the training loop and performs pre/post evaluations."""
|
||||
print(f"\n[START] Initializing training run with {optimizer_name}...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
|
||||
|
||||
# 1. Pre-Training Evaluation
|
||||
print(f"[{optimizer_name}] Running Pre-Training Evaluation...")
|
||||
pre_eval = evaluate_model(model, tokenizer, device)
|
||||
|
||||
# 2. Setup Optimizer
|
||||
model.train()
|
||||
if optimizer_name == "AdamW":
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
||||
elif optimizer_name == "FCES":
|
||||
cfg = fces_native.FCESConfig()
|
||||
cfg.lr = 1e-4
|
||||
cfg.population_size = 8 # Balanced for resource footprint
|
||||
cfg.total_steps = steps
|
||||
optimizer = fces_native.FCESOptimizer(list(model.parameters()), cfg)
|
||||
else:
|
||||
raise ValueError("Invalid optimizer name")
|
||||
|
||||
# 3. Training loop
|
||||
logs = []
|
||||
start_time = time.perf_counter()
|
||||
tokens_processed = 0
|
||||
|
||||
for step in range(steps):
|
||||
# Cyclically fetch batches from pre-tokenized dataset
|
||||
batch = dataset[step % len(dataset)]
|
||||
input_win, labels_win, input_lose, labels_lose = [t.to(device) for t in batch]
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Forward pass for win (preferred)
|
||||
outputs_win = model(input_win, labels=labels_win)
|
||||
logits_win = outputs_win.logits
|
||||
|
||||
# Forward pass for lose (dispreferred)
|
||||
with torch.no_grad():
|
||||
outputs_lose = model(input_lose)
|
||||
logits_lose = outputs_lose.logits
|
||||
|
||||
# Compute combined ORPO loss
|
||||
loss, sft_loss = compute_orpo_loss(
|
||||
logits_win, labels_win, logits_lose, labels_lose
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if optimizer_name == "FCES":
|
||||
optimizer.update_fitness(float(loss.item()))
|
||||
|
||||
# Tracking metrics
|
||||
elapsed = time.perf_counter() - start_time
|
||||
batch_tokens = int((input_win != tokenizer.pad_token_id).sum().item())
|
||||
tokens_processed += batch_tokens
|
||||
|
||||
logs.append(
|
||||
{
|
||||
"step": step,
|
||||
"loss": float(loss.item()),
|
||||
"sft_loss": float(sft_loss.item()),
|
||||
"wall_clock_time": elapsed,
|
||||
"tokens_processed": tokens_processed,
|
||||
}
|
||||
)
|
||||
|
||||
if step % 5 == 0:
|
||||
print(
|
||||
f"[{optimizer_name}] Step {step}/{steps} | Loss: {loss.item():.4f} | "
|
||||
f"Time: {elapsed:.2f}s | Tokens: {tokens_processed}"
|
||||
)
|
||||
|
||||
# 4. Post-Training Evaluation
|
||||
print(f"[{optimizer_name}] Running Post-Training Evaluation...")
|
||||
post_eval = evaluate_model(model, tokenizer, device)
|
||||
|
||||
# Cleanup memory
|
||||
del model
|
||||
if device == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return logs, post_eval, pre_eval
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# 5. DATA PREPARATION (MOCK RACF DATASET FOR PRE-FLIGHT TESTS)
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def build_mock_dataset(
|
||||
model_name: str, size: int = 20
|
||||
) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
|
||||
"""Synthesizes mock pre-tokenized target sequences representing
|
||||
|
||||
FIRT trajectories (Top-k vs Bottom-l preferences) for local dry-runs.
|
||||
"""
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
dataset = []
|
||||
for idx in range(size):
|
||||
# Winning path (Correct FIRT opinions)
|
||||
text_win = (
|
||||
f"Case {idx}: Tenant wins due to Self-beseitigungsrecht § 536a BGB. "
|
||||
"Unresolved flaws: None. Authority: Refined opinion correct."
|
||||
)
|
||||
# Losing path (Flawed/contradictory reasoning)
|
||||
text_lose = (
|
||||
f"Case {idx}: Tenant loses but reasoning is faulty. "
|
||||
"Unresolved flaws: Ignored positive knowledge under § 536b BGB."
|
||||
)
|
||||
|
||||
enc_win = tokenizer(
|
||||
text_win,
|
||||
truncation=True,
|
||||
max_length=64,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
enc_lose = tokenizer(
|
||||
text_lose,
|
||||
truncation=True,
|
||||
max_length=64,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
labels_win = enc_win["input_ids"].clone()
|
||||
labels_win[enc_win["attention_mask"] == 0] = -100
|
||||
|
||||
labels_lose = enc_lose["input_ids"].clone()
|
||||
labels_lose[enc_lose["attention_mask"] == 0] = -100
|
||||
|
||||
dataset.append(
|
||||
(
|
||||
enc_win["input_ids"],
|
||||
labels_win,
|
||||
enc_lose["input_ids"],
|
||||
labels_lose,
|
||||
)
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# 6. MAIN CONTROLLER
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="FIRT ORPO FCES vs AdamW Trainer")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="EleutherAI/pythia-70m",
|
||||
help="Model name (e.g. google/gemma-2-2b-it)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps", type=int, default=10, help="Number of training steps"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print("=" * 80)
|
||||
print(
|
||||
" FCES VS ADAMW COMPARATIVE TRAINING BENCHMARK "
|
||||
)
|
||||
print("=" * 80)
|
||||
print(f"Device: {device} | Base Model: {args.model} | Steps: {args.steps}")
|
||||
|
||||
# Initialize pre-tokenized dataset
|
||||
print("\n[INFO] Synthesizing target training tokens...")
|
||||
dataset = build_mock_dataset(args.model)
|
||||
|
||||
# 1. Run AdamW Baseline
|
||||
adam_logs, adam_post, pre_eval = train_run(
|
||||
"AdamW", args.model, args.steps, device, dataset
|
||||
)
|
||||
|
||||
# 2. Run FCES Optimizer
|
||||
fces_logs, fces_post, _ = train_run("FCES", args.model, args.steps, device, dataset)
|
||||
|
||||
# 3. Telemetry Log Synthesis & Push
|
||||
telemetry_entries = []
|
||||
print("\n" + "-" * 30 + " FINAL RUN SUMMARY " + "-" * 30)
|
||||
print(f"{'Step':<10}{'AdamW Loss':<20}{'FCES Loss':<20}{'Speedup / Savings':<25}")
|
||||
print("-" * 75)
|
||||
|
||||
for idx in range(min(len(adam_logs), len(fces_logs))):
|
||||
step_val = adam_logs[idx]["step"]
|
||||
adam_loss = adam_logs[idx]["loss"]
|
||||
fces_loss = fces_logs[idx]["loss"]
|
||||
time_saved_pct = (
|
||||
(adam_logs[idx]["wall_clock_time"] - fces_logs[idx]["wall_clock_time"])
|
||||
/ adam_logs[idx]["wall_clock_time"]
|
||||
* 100
|
||||
)
|
||||
|
||||
if idx % (max(1, args.steps // 5)) == 0 or idx == args.steps - 1:
|
||||
print(
|
||||
f"{step_val:<10}{adam_loss:<20.4f}{fces_loss:<20.4f}"
|
||||
f"{time_saved_pct:+.2f}% wall-clock time"
|
||||
)
|
||||
|
||||
telemetry_entries.append(
|
||||
(
|
||||
"INFO",
|
||||
"comparative_step",
|
||||
f"Step {step_val} | AdamW Loss: {adam_loss:.4f} | FCES Loss: {fces_loss:.4f}",
|
||||
)
|
||||
)
|
||||
|
||||
print("-" * 75)
|
||||
print("\n>>> PRE-TRAINING OPINION (CASE 1):")
|
||||
print(pre_eval[0]["output"])
|
||||
print("\n>>> POST-TRAINING ADAMW OPINION (CASE 1):")
|
||||
print(adam_post[0]["output"])
|
||||
print("\n>>> POST-TRAINING FCES OPINION (CASE 1):")
|
||||
print(fces_post[0]["output"])
|
||||
print("=" * 80)
|
||||
|
||||
# Push telemetries
|
||||
push_to_surrealdb(telemetry_entries)
|
||||
push_to_mariadb(telemetry_entries)
|
||||
|
||||
# Write comparative summary artifact
|
||||
import json
|
||||
|
||||
results_summary = {
|
||||
"model": args.model,
|
||||
"steps": args.steps,
|
||||
"adam_metrics": adam_logs,
|
||||
"fces_metrics": fces_logs,
|
||||
"pre_eval": pre_eval,
|
||||
"adam_post_eval": adam_post,
|
||||
"fces_post_eval": fces_post,
|
||||
}
|
||||
with open("benchmark_results.json", "w", encoding="utf-8") as f:
|
||||
json.dump(results_summary, f, indent=4)
|
||||
print("[INFO] Stored comprehensive benchmark metrics to benchmark_results.json")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user