# File: FCES-native/benchmark_fces_vs_adam.py
# (Repository listing metadata: 435 lines, 15 KiB, Python)

import argparse
import os
import sys
import time
from typing import Any, Dict, List, Tuple
# Ensure python folder is in path
python_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "python")
sys.path.insert(0, python_dir)
import torch # noqa: E402
import fces_native # noqa: E402
import dspy # noqa: E402
import torch.nn.functional as F # noqa: E402
from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402
from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402
from parasitic_qlora import ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402
# ==============================================================================
# 1. DSPY SIGNATURE & SYSTEM DESIGN
# ==============================================================================
class LegalFIRT(dspy.Signature):  # type: ignore[misc]
    """Flaw Induced Reasoning Traces (FIRT) for German & EU Tenancy / M&A Law."""

    # Declarative signature with two inputs (case facts, statutes) and three
    # outputs that form a draft -> flaws -> refined-opinion trace.
    # NOTE(review): dspy presumably uses the class docstring and the `desc`
    # strings as prompt text, so treat wording changes here as behavior
    # changes -- confirm against the dspy version in use.
    # This class is not referenced elsewhere in the visible part of the file.

    # --- Inputs ---------------------------------------------------------
    case_description: str = dspy.InputField(
        desc="Facts of the German/EU legal case with potential hidden flaws or contradictions."
    )
    relevant_statutes: str = dspy.InputField(
        desc="Relevant articles from the German Civil Code (BGB) or EU Directives."
    )

    # --- Outputs (T0 -> T1 -> T2 reasoning stages) ------------------------
    t0_draft: str = dspy.OutputField(
        desc="Initial legal assessment or draft resolution based on first-pass reasoning."
    )
    t1_flaws: str = dspy.OutputField(
        desc="Identified legal flaws, logical bottlenecks, or contradictions (e.g. § 536b BGB exclusions)."
    )
    t2_refined_opinion: str = dspy.OutputField(
        desc="Final, correct, and authoritative legal opinion addressing all identified flaws."
    )
# Legal cases used for the pre- and post-training qualitative benchmark.
# Each entry is a dict with two string keys that evaluate_model() reads when
# it builds the generation prompt:
#   "description" - the facts of the case (German text)
#   "statutes"    - the statutes the prompt cites alongside the facts
# The 1-based position in this list becomes the reported "case_id".
eval_cases: List[Dict[str, str]] = [
    {
        "description": "Mieter M besichtigt im Januar eine Wohnung des Vermieters V. Schimmelfleck im Schlafzimmer. V sagt er wird es 'irgendwann mal wegmachen'. M unterschreibt Mietvertrag vorbehaltlos und zieht am 1. Februar ein. Im März mindert M die Miete selbstständig um 20% wegen des Schimmels. V verlangt volle Miete.",
        "statutes": "§ 535 Abs. 2 BGB (Mietzahlung), § 536 Abs. 1 BGB (Mietminderung bei Sachmangel), § 536b BGB (Kenntnis des Mieters bei Vertragsschluss).",
    },
    {
        "description": "Toilettenspülung bei F fällt am Freitagabend aus, Wasser läuft ununterbrochen und droht überzulaufen. V ist auf Segeltrip und nicht erreichbar. Kein Hausmeister. F beauftragt am Samstagmorgen den Notdienst K, bezahlt 300 € und fordert Erstattung von V. V weigert sich.",
        "statutes": "§ 536a Abs. 2 Nr. 2 BGB (Selbstbeseitigungsrecht bei Gefahr im Verzug), § 535 BGB.",
    },
]
# ==============================================================================
# 2. ORPO LOSS WITH SOFT LABELS
# ==============================================================================
def _mean_label_log_prob(
    seq_logits: torch.Tensor, seq_labels: torch.Tensor
) -> torch.Tensor:
    """Return the mean log-probability of the non-ignored label tokens.

    Applies the causal shift: the logits at position ``t`` are scored against
    the label at position ``t + 1``. Positions whose label is ``-100`` are
    excluded from both the sum and the token count.
    """
    shifted_logits = seq_logits[..., :-1, :]
    shifted_labels = seq_labels[..., 1:]
    keep = shifted_labels != -100
    # gather() cannot index with -100, so point ignored slots at token 0;
    # their contribution is zeroed by the mask below.
    gather_index = shifted_labels.clone()
    gather_index[~keep] = 0
    token_log_probs = torch.gather(
        F.log_softmax(shifted_logits, dim=-1),
        dim=-1,
        index=gather_index.unsqueeze(-1),
    ).squeeze(-1)
    # clamp(min=1) avoids 0/0 for a sequence with no supervised tokens.
    return (token_log_probs * keep).sum(dim=-1) / keep.sum(dim=-1).clamp(min=1)


def _log_odds(mean_log_prob: torch.Tensor) -> torch.Tensor:
    """Return ``log(p / (1 - p))`` for ``p = exp(mean_log_prob)``.

    The log-probability is bounded just below zero so that ``1 - p`` never
    becomes exactly 0; otherwise ``log1p(-1)`` is ``-inf`` and the difference
    of two such terms is NaN.
    """
    eps = torch.finfo(mean_log_prob.dtype).eps
    bounded = mean_log_prob.clamp(max=-eps)
    return bounded - torch.log1p(-torch.exp(bounded))


def compute_orpo_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    logits_lose: torch.Tensor,
    labels_lose: torch.Tensor,
    lambda_orpo: float = 0.1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the combined SFT cross-entropy and odds-ratio preference loss.

    Args:
        logits: Logits for the preferred sequence, indexed as
            ``(..., position, vocab)``.
        labels: Target token ids for the preferred sequence, indexed as
            ``(..., position)``; ``-100`` marks ignored positions.
        logits_lose: Logits for the dispreferred sequence (same layout).
        labels_lose: Target token ids for the dispreferred sequence.
        lambda_orpo: Weight of the odds-ratio term in the total loss.

    Returns:
        ``(total_loss, sft_loss)`` where
        ``total_loss = sft_loss + lambda_orpo * orpo_loss``.
    """
    # 1. SFT loss (cross entropy) on the preferred sequence, next-token shifted.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    sft_loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )

    # 2. Length-normalised log-likelihood of each sequence. The same causal
    #    shift as the SFT term is applied so both terms score the same
    #    (position -> next token) pairs.
    seq_log_probs_win = _mean_label_log_prob(logits, labels)
    seq_log_probs_lose = _mean_label_log_prob(logits_lose, labels_lose)

    # 3. Log odds ratio between preferred and dispreferred sequences.
    odds_ratio = _log_odds(seq_log_probs_win) - _log_odds(seq_log_probs_lose)
    orpo_loss = -F.logsigmoid(odds_ratio).mean()

    total_loss = sft_loss + lambda_orpo * orpo_loss
    return total_loss, sft_loss
# ==============================================================================
# 3. EVALUATION RUNNER
# ==============================================================================
def evaluate_model(
    model: AutoModelForCausalLM, tokenizer: AutoTokenizer, device: str
) -> List[Dict[str, Any]]:
    """Generate a zero-shot FIRT completion for every entry in ``eval_cases``.

    The model is switched to eval mode and left in that mode; callers that
    continue training must call ``model.train()`` again.

    Args:
        model: Causal language model used for generation.
        tokenizer: Tokenizer matching ``model``.
        device: Device the tokenized prompt is moved to.

    Returns:
        One dict per case with an integer ``"case_id"`` (1-based) and the
        decoded ``"output"`` text (prompt plus continuation).
    """
    model.eval()
    results: List[Dict[str, Any]] = []
    for case_id, case in enumerate(eval_cases, start=1):
        prompt = (
            f"Case: {case['description']}\nStatutes: {case['statutes']}\n"
            "Analyze and resolve using Flaw Induced Reasoning Traces (FIRT):\n"
            "T0 Draft:\n"
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            # Greedy decoding: `temperature` only applies when sampling, so it
            # is not passed (combining it with do_sample=False just triggers a
            # configuration warning). pad_token_id is passed explicitly so
            # generate() does not have to guess one.
            outputs = model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        results.append({"case_id": case_id, "output": generated_text})
    return results
# ==============================================================================
# 4. TRAINING PIPELINE (ADAMW VS FCES)
# ==============================================================================
def train_run(
    optimizer_name: str,
    model_name: str,
    steps: int,
    device: str,
    dataset: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
) -> Tuple[List[Dict[str, Any]], List[Dict[str, str]], List[Dict[str, str]]]:
    """Train a fresh copy of ``model_name`` and evaluate it before and after.

    Args:
        optimizer_name: ``"AdamW"`` or ``"FCES"``.
        model_name: Name passed to ``from_pretrained`` for tokenizer and model.
        steps: Number of optimisation steps to run.
        device: Device the model and each batch are moved to.
        dataset: Batches of ``(input_win, labels_win, input_lose, labels_lose)``
            tensors; batches are reused cyclically when ``steps`` exceeds the
            dataset length.

    Returns:
        ``(logs, post_eval, pre_eval)``: the per-step metric dicts, the
        post-training evaluation results and the pre-training evaluation
        results (note the order).

    Raises:
        ValueError: If ``optimizer_name`` is unknown, or if ``steps > 0`` and
            ``dataset`` is empty.
    """
    print(f"\n[START] Initializing training run with {optimizer_name}...")

    # Fail fast, before the expensive model load and pre-training evaluation.
    if optimizer_name not in ("AdamW", "FCES"):
        raise ValueError("Invalid optimizer name")
    if steps > 0 and not dataset:
        raise ValueError("dataset must contain at least one batch")
    is_fces = optimizer_name == "FCES"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    # Initialize Parasitic QLoRA Extractor
    extractor = ParasiticQLoRAExtractor(
        QLoRAConfig(
            min_rank=8,
            max_rank=32,
            explained_variance_threshold=0.95,
            extraction_interval=5,
            interesting_point_detection=True,
        )
    )
    extractor.snapshot_base(model)

    # 1. Pre-Training Evaluation
    print(f"[{optimizer_name}] Running Pre-Training Evaluation...")
    pre_eval = evaluate_model(model, tokenizer, device)

    # 2. Setup Optimizer (evaluate_model left the model in eval mode)
    model.train()
    if is_fces:
        cfg = fces_native.FCESConfig()
        cfg.lr = 1e-4
        cfg.population_size = 8  # Balanced for resource footprint
        cfg.total_steps = steps
        optimizer = fces_native.FCESOptimizer(list(model.parameters()), cfg)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # 3. Training loop
    logs: List[Dict[str, Any]] = []
    start_time = time.perf_counter()
    tokens_processed = 0
    for step in range(steps):
        # Cyclically fetch batches from pre-tokenized dataset
        batch = dataset[step % len(dataset)]
        input_win, labels_win, input_lose, labels_lose = [t.to(device) for t in batch]
        optimizer.zero_grad()

        # Forward pass for win (preferred)
        outputs_win = model(input_win, labels=labels_win)
        logits_win = outputs_win.logits

        # Forward pass for lose (dispreferred).
        # NOTE(review): this pass runs under no_grad, so the odds-ratio term
        # only back-propagates through the preferred sequence. The ORPO
        # formulation differentiates through both sequences -- confirm this
        # is an intentional memory/compute trade-off.
        with torch.no_grad():
            outputs_lose = model(input_lose)
        logits_lose = outputs_lose.logits

        # Compute combined ORPO loss
        loss, sft_loss = compute_orpo_loss(
            logits_win, labels_win, logits_lose, labels_lose
        )
        loss.backward()
        optimizer.step()

        # Read the scalars once per step; every .item() call forces a
        # device-to-host synchronisation.
        loss_value = float(loss.item())
        sft_value = float(sft_loss.item())

        if is_fces:
            optimizer.update_fitness(loss_value)

        # Call parasitic extractor
        if extractor.should_extract(step, loss_value):
            metrics = {
                "loss": loss_value,
                "sft_loss": sft_value,
                "optimizer": optimizer_name,
                "spectral_rank": getattr(optimizer, "last_spectral_rank_", 0.0),
            }
            extractor.extract_adapters(model, step, metrics)

        # Tracking metrics
        elapsed = time.perf_counter() - start_time
        batch_tokens = int((input_win != tokenizer.pad_token_id).sum().item())
        tokens_processed += batch_tokens
        logs.append(
            {
                "step": step,
                "loss": loss_value,
                "sft_loss": sft_value,
                "wall_clock_time": elapsed,
                "tokens_processed": tokens_processed,
            }
        )
        if step % 5 == 0:
            print(
                f"[{optimizer_name}] Step {step}/{steps} | Loss: {loss_value:.4f} | "
                f"Time: {elapsed:.2f}s | Tokens: {tokens_processed}"
            )

    # Save the extracted adapter library
    library_path = f"parasitic_adapters_{optimizer_name.lower()}_step{steps}.pt"
    extractor.save_library(library_path)

    # 4. Post-Training Evaluation
    print(f"[{optimizer_name}] Running Post-Training Evaluation...")
    post_eval = evaluate_model(model, tokenizer, device)

    # Cleanup memory. The optimizer also holds references to the parameters,
    # so drop it together with the model before asking CUDA to release its
    # cached blocks.
    del model, optimizer
    if device == "cuda":
        torch.cuda.empty_cache()
    return logs, post_eval, pre_eval
# ==============================================================================
# 5. DATA PREPARATION (MOCK RACF DATASET FOR PRE-FLIGHT TESTS)
# ==============================================================================
def build_mock_dataset(
    model_name: str, size: int = 20
) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Build a synthetic preference dataset for local dry-runs.

    Every sample pairs a "winning" and a "losing" mock FIRT text, each
    tokenized to a fixed length of 64 with max-length padding.

    Args:
        model_name: Name passed to ``AutoTokenizer.from_pretrained``.
        size: Number of samples to generate.

    Returns:
        A list of ``(win_ids, win_labels, lose_ids, lose_labels)`` tuples.
        Labels are a copy of the input ids with every position whose
        attention mask is 0 replaced by ``-100``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    def _encode_with_labels(text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize ``text`` and derive labels that ignore padded positions."""
        encoded = tokenizer(
            text,
            truncation=True,
            max_length=64,
            padding="max_length",
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"]
        targets = token_ids.clone()
        targets[encoded["attention_mask"] == 0] = -100
        return token_ids, targets

    samples = []
    for case_no in range(size):
        # Preferred trajectory (correct FIRT opinion).
        preferred_text = (
            f"Case {case_no}: Tenant wins due to Self-beseitigungsrecht § 536a BGB. "
            "Unresolved flaws: None. Authority: Refined opinion correct."
        )
        # Dispreferred trajectory (flawed / contradictory reasoning).
        rejected_text = (
            f"Case {case_no}: Tenant loses but reasoning is faulty. "
            "Unresolved flaws: Ignored positive knowledge under § 536b BGB."
        )
        win_ids, win_labels = _encode_with_labels(preferred_text)
        lose_ids, lose_labels = _encode_with_labels(rejected_text)
        samples.append((win_ids, win_labels, lose_ids, lose_labels))
    return samples
# ==============================================================================
# 6. MAIN CONTROLLER
# ==============================================================================
def main() -> None:
    """Run the AdamW and FCES training runs and report the comparison.

    Parses ``--model`` and ``--steps``, builds the mock dataset, trains once
    with each optimizer, prints a per-step comparison table plus the case-1
    generations, writes ``benchmark_results.json`` and finally pushes the
    per-step telemetry entries to SurrealDB and MariaDB.
    """
    import json

    parser = argparse.ArgumentParser(description="FIRT ORPO FCES vs AdamW Trainer")
    parser.add_argument(
        "--model",
        type=str,
        default="EleutherAI/pythia-70m",
        help="Model name (e.g. google/gemma-2-2b-it)",
    )
    parser.add_argument(
        "--steps", type=int, default=10, help="Number of training steps"
    )
    args = parser.parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("=" * 80)
    print(
        " FCES VS ADAMW COMPARATIVE TRAINING BENCHMARK "
    )
    print("=" * 80)
    print(f"Device: {device} | Base Model: {args.model} | Steps: {args.steps}")

    # Initialize pre-tokenized dataset
    print("\n[INFO] Synthesizing target training tokens...")
    dataset = build_mock_dataset(args.model)

    # 1. Run AdamW Baseline (its pre-training evaluation is the one reported)
    adam_logs, adam_post, pre_eval = train_run(
        "AdamW", args.model, args.steps, device, dataset
    )
    # 2. Run FCES Optimizer
    fces_logs, fces_post, _ = train_run("FCES", args.model, args.steps, device, dataset)

    # 3. Telemetry Log Synthesis
    telemetry_entries = []
    print("\n" + "-" * 30 + " FINAL RUN SUMMARY " + "-" * 30)
    print(f"{'Step':<10}{'AdamW Loss':<20}{'FCES Loss':<20}{'Speedup / Savings':<25}")
    print("-" * 75)
    report_every = max(1, args.steps // 5)
    for idx, (adam_entry, fces_entry) in enumerate(zip(adam_logs, fces_logs)):
        step_val = adam_entry["step"]
        adam_loss = adam_entry["loss"]
        fces_loss = fces_entry["loss"]
        if idx % report_every == 0 or idx == args.steps - 1:
            adam_time = adam_entry["wall_clock_time"]
            # Guard the ratio: a zero elapsed time would otherwise raise
            # ZeroDivisionError and abort the whole summary.
            if adam_time > 0:
                time_saved_pct = (
                    (adam_time - fces_entry["wall_clock_time"]) / adam_time * 100
                )
            else:
                time_saved_pct = 0.0
            print(
                f"{step_val:<10}{adam_loss:<20.4f}{fces_loss:<20.4f}"
                f"{time_saved_pct:+.2f}% wall-clock time"
            )
        telemetry_entries.append(
            (
                "INFO",
                "comparative_step",
                f"Step {step_val} | AdamW Loss: {adam_loss:.4f} | FCES Loss: {fces_loss:.4f}",
            )
        )
    print("-" * 75)
    print("\n>>> PRE-TRAINING OPINION (CASE 1):")
    print(pre_eval[0]["output"])
    print("\n>>> POST-TRAINING ADAMW OPINION (CASE 1):")
    print(adam_post[0]["output"])
    print("\n>>> POST-TRAINING FCES OPINION (CASE 1):")
    print(fces_post[0]["output"])
    print("=" * 80)

    # Write the comparative summary artifact BEFORE pushing telemetry, so an
    # unreachable database cannot discard the results of both training runs.
    results_summary = {
        "model": args.model,
        "steps": args.steps,
        "adam_metrics": adam_logs,
        "fces_metrics": fces_logs,
        "pre_eval": pre_eval,
        "adam_post_eval": adam_post,
        "fces_post_eval": fces_post,
    }
    with open("benchmark_results.json", "w", encoding="utf-8") as f:
        json.dump(results_summary, f, indent=4)
    print("[INFO] Stored comprehensive benchmark metrics to benchmark_results.json")

    # Push telemetries
    push_to_surrealdb(telemetry_entries)
    push_to_mariadb(telemetry_entries)
# Script entry point: run the benchmark only when executed directly, not
# when this module is imported.
if __name__ == "__main__":
    main()