435 lines
15 KiB
Python
435 lines
15 KiB
Python
import argparse
|
|
import os
|
|
import sys
|
|
import time
|
|
from typing import Any, Dict, List, Tuple
|
|
|
|
# Ensure python folder is in path
|
|
python_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "python")
|
|
sys.path.insert(0, python_dir)
|
|
|
|
import torch # noqa: E402
|
|
import fces_native # noqa: E402
|
|
import dspy # noqa: E402
|
|
import torch.nn.functional as F # noqa: E402
|
|
from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402
|
|
from parasitic_qlora import ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402
|
|
|
|
# ==============================================================================
|
|
# 1. DSPY SIGNATURE & SYSTEM DESIGN
|
|
# ==============================================================================
|
|
|
|
|
|
class LegalFIRT(dspy.Signature):  # type: ignore[misc]
    """Flaw Induced Reasoning Traces (FIRT) for German & EU Tenancy / M&A Law."""

    # Declarative signature with two inputs (case facts, statutes) and three
    # outputs that form a draft -> flaws -> refined-opinion chain.
    # NOTE(review): DSPy presumably feeds the class docstring and every `desc`
    # string into the prompt, so treat them as runtime text rather than as
    # documentation -- confirm before rewording any of them.

    # --- Inputs -------------------------------------------------------------
    # Free-text facts of the case to analyse.
    case_description: str = dspy.InputField(
        desc="Facts of the German/EU legal case with potential hidden flaws or contradictions."
    )
    # Statutory provisions the analysis should rely on.
    relevant_statutes: str = dspy.InputField(
        desc="Relevant articles from the German Civil Code (BGB) or EU Directives."
    )

    # --- Outputs (declared in reasoning order: T0 -> T1 -> T2) ---------------
    # T0: first-pass assessment.
    t0_draft: str = dspy.OutputField(
        desc="Initial legal assessment or draft resolution based on first-pass reasoning."
    )
    # T1: critique of the T0 draft.
    t1_flaws: str = dspy.OutputField(
        desc="Identified legal flaws, logical bottlenecks, or contradictions (e.g. § 536b BGB exclusions)."
    )
    # T2: final opinion after addressing the T1 findings.
    t2_refined_opinion: str = dspy.OutputField(
        desc="Final, correct, and authoritative legal opinion addressing all identified flaws."
    )
|
|
|
|
|
|
# Legal cases for pre- and post-training benchmarking.
# Each entry is a dict with two string keys, both read by `evaluate_model`
# when it builds the generation prompt:
#   "description" - the fact pattern of the case (German-language text)
#   "statutes"    - the statutory provisions the answer should rely on
# The strings are runtime prompt text; do not translate or reformat them.
eval_cases: List[Dict[str, str]] = [
    {
        "description": "Mieter M besichtigt im Januar eine Wohnung des Vermieters V. Schimmelfleck im Schlafzimmer. V sagt er wird es 'irgendwann mal wegmachen'. M unterschreibt Mietvertrag vorbehaltlos und zieht am 1. Februar ein. Im März mindert M die Miete selbstständig um 20% wegen des Schimmels. V verlangt volle Miete.",
        "statutes": "§ 535 Abs. 2 BGB (Mietzahlung), § 536 Abs. 1 BGB (Mietminderung bei Sachmangel), § 536b BGB (Kenntnis des Mieters bei Vertragsschluss).",
    },
    {
        "description": "Toilettenspülung bei F fällt am Freitagabend aus, Wasser läuft ununterbrochen und droht überzulaufen. V ist auf Segeltrip und nicht erreichbar. Kein Hausmeister. F beauftragt am Samstagmorgen den Notdienst K, bezahlt 300 € und fordert Erstattung von V. V weigert sich.",
        "statutes": "§ 536a Abs. 2 Nr. 2 BGB (Selbstbeseitigungsrecht bei Gefahr im Verzug), § 535 BGB.",
    },
]
|
|
|
|
# ==============================================================================
|
|
# 2. ORPO LOSS WITH SOFT LABELS
|
|
# ==============================================================================
|
|
|
|
|
|
# Upper bound (as a negative number) applied to mean log-probabilities before
# forming log-odds, so that log(1 - p) stays finite when p rounds to 1.
_LOG_PROB_EPS = 1e-6


def _mean_next_token_log_prob(
    seq_logits: torch.Tensor, seq_labels: torch.Tensor
) -> torch.Tensor:
    """Return the per-sequence mean log-probability of the next-token targets.

    The logits at position ``t`` are scored against the label at ``t + 1``
    (the same shift the SFT cross-entropy uses). Positions whose shifted
    label is ``-100`` are excluded from the average.
    """
    pred_logits = seq_logits[..., :-1, :]
    targets = seq_labels[..., 1:]
    valid = targets != -100
    # -100 is not a valid gather index; substitute 0 and mask the result out.
    safe_targets = targets.masked_fill(~valid, 0)
    token_log_probs = torch.gather(
        F.log_softmax(pred_logits, dim=-1), dim=-1, index=safe_targets.unsqueeze(-1)
    ).squeeze(-1)
    # clamp(min=1) avoids 0/0 for sequences with no supervised position.
    return (token_log_probs * valid).sum(dim=-1) / valid.sum(dim=-1).clamp(min=1)


def _log_odds(mean_log_prob: torch.Tensor) -> torch.Tensor:
    """Return ``log(p / (1 - p))`` for ``p = exp(mean_log_prob)``, kept finite."""
    # Keep p strictly below 1 so that log(1 - p) cannot become -inf.
    bounded = mean_log_prob.clamp(max=-_LOG_PROB_EPS)
    # log(1 - exp(x)) == log(-expm1(x)); expm1 stays accurate for x close to 0.
    return bounded - torch.log(-torch.expm1(bounded))


def compute_orpo_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    logits_lose: torch.Tensor,
    labels_lose: torch.Tensor,
    lambda_orpo: float = 0.1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the combined SFT cross-entropy and odds-ratio preference (ORPO) loss.

    Args:
        logits: Logits of the preferred ("win") sequence; indexed as
            ``[..., position, vocab]``.
        labels: Target token ids for ``logits``; ``-100`` marks ignored positions.
        logits_lose: Logits of the dispreferred ("lose") sequence.
        labels_lose: Target token ids for ``logits_lose``; ``-100`` is ignored.
        lambda_orpo: Weight of the odds-ratio term in the total loss.

    Returns:
        ``(total_loss, sft_loss)`` where
        ``total_loss = sft_loss + lambda_orpo * orpo_loss``.
    """
    # 1. SFT loss: next-token cross entropy on the preferred sequence.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    sft_loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )

    # 2. Length-normalised log-likelihood of each sequence. Both use the same
    #    one-position shift as the SFT term, so the odds ratio scores the
    #    next-token predictions rather than P(token_t | tokens <= t).
    mean_log_prob_win = _mean_next_token_log_prob(logits, labels)
    mean_log_prob_lose = _mean_next_token_log_prob(logits_lose, labels_lose)

    # 3. Log odds ratio between preferred and dispreferred sequences.
    log_odds_ratio = _log_odds(mean_log_prob_win) - _log_odds(mean_log_prob_lose)
    orpo_loss = -F.logsigmoid(log_odds_ratio).mean()

    total_loss = sft_loss + lambda_orpo * orpo_loss
    return total_loss, sft_loss
|
|
|
|
|
|
# ==============================================================================
|
|
# 3. EVALUATION RUNNER
|
|
# ==============================================================================
|
|
|
|
|
|
def evaluate_model(
    model: AutoModelForCausalLM, tokenizer: AutoTokenizer, device: str
) -> List[Dict[str, Any]]:
    """Generate a zero-shot FIRT answer for every entry in ``eval_cases``.

    Puts ``model`` into eval mode (it is NOT switched back to train mode
    afterwards) and decodes greedily with up to 256 new tokens per case.

    Args:
        model: Causal language model used for generation.
        tokenizer: Tokenizer matching ``model``.
        device: Device the tokenized prompt is moved to before generation.

    Returns:
        One dict per case with ``"case_id"`` (1-based int) and ``"output"``
        (the decoded sequence returned by ``generate``).
    """
    model.eval()
    results: List[Dict[str, Any]] = []
    for case_id, case in enumerate(eval_cases, start=1):
        prompt = (
            f"Case: {case['description']}\nStatutes: {case['statutes']}\n"
            "Analyze and resolve using Flaw Induced Reasoning Traces (FIRT):\n"
            "T0 Draft:\n"
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            # Greedy decoding. No `temperature` is passed: it only applies when
            # sampling and contradicts do_sample=False.
            outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        results.append({"case_id": case_id, "output": generated_text})
    return results
|
|
|
|
|
|
# ==============================================================================
|
|
# 4. TRAINING PIPELINE (ADAMW VS FCES)
|
|
# ==============================================================================
|
|
|
|
|
|
def train_run(
    optimizer_name: str,
    model_name: str,
    steps: int,
    device: str,
    dataset: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Train a fresh copy of ``model_name`` and evaluate it before and after.

    Args:
        optimizer_name: ``"AdamW"`` or ``"FCES"``.
        model_name: Identifier passed to the tokenizer/model ``from_pretrained``.
        steps: Number of optimisation steps; batches are reused cyclically.
        device: Device the model and every batch are moved to.
        dataset: Batches of ``(input_win, labels_win, input_lose, labels_lose)``.

    Returns:
        ``(logs, post_eval, pre_eval)`` -- note that the post-training
        evaluation comes BEFORE the pre-training one. ``logs`` holds one dict
        per step with ``step``, ``loss``, ``sft_loss``, ``wall_clock_time`` and
        the cumulative ``tokens_processed``.

    Raises:
        ValueError: If ``optimizer_name`` is unknown, or if ``dataset`` is
            empty while ``steps > 0``.
    """
    # Fail fast, before the model is loaded and the pre-evaluation is run.
    if optimizer_name not in ("AdamW", "FCES"):
        raise ValueError("Invalid optimizer name")
    if steps > 0 and not dataset:
        raise ValueError("dataset must contain at least one batch when steps > 0")

    print(f"\n[START] Initializing training run with {optimizer_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    # Initialize Parasitic QLoRA Extractor.
    # NOTE(review): the extractor is a project-local helper; its semantics are
    # not visible from this file.
    extractor = ParasiticQLoRAExtractor(
        QLoRAConfig(
            min_rank=8,
            max_rank=32,
            explained_variance_threshold=0.95,
            extraction_interval=5,
            interesting_point_detection=True,
        )
    )
    extractor.snapshot_base(model)

    # 1. Pre-Training Evaluation
    print(f"[{optimizer_name}] Running Pre-Training Evaluation...")
    pre_eval = evaluate_model(model, tokenizer, device)

    # 2. Setup Optimizer (evaluate_model left the model in eval mode).
    model.train()
    if optimizer_name == "AdamW":
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    else:  # "FCES" -- the name was validated above
        cfg = fces_native.FCESConfig()
        cfg.lr = 1e-4
        cfg.population_size = 8  # Balanced for resource footprint
        cfg.total_steps = steps
        optimizer = fces_native.FCESOptimizer(list(model.parameters()), cfg)

    # 3. Training loop
    logs: List[Dict[str, Any]] = []
    start_time = time.perf_counter()
    tokens_processed = 0

    for step in range(steps):
        # Cyclically fetch batches from the pre-tokenized dataset.
        batch = dataset[step % len(dataset)]
        input_win, labels_win, input_lose, labels_lose = [t.to(device) for t in batch]

        optimizer.zero_grad()

        # Forward pass for win (preferred)
        logits_win = model(input_win, labels=labels_win).logits

        # Forward pass for lose (dispreferred). Gradients must be enabled here:
        # under no_grad the odds-ratio term could only raise the preferred
        # sequence and could never lower the dispreferred one.
        logits_lose = model(input_lose).logits

        # Compute combined ORPO loss
        loss, sft_loss = compute_orpo_loss(
            logits_win, labels_win, logits_lose, labels_lose
        )

        loss.backward()
        optimizer.step()

        # Read the scalars once per step instead of calling .item() repeatedly.
        loss_value = float(loss.item())
        sft_value = float(sft_loss.item())

        if optimizer_name == "FCES":
            optimizer.update_fitness(loss_value)

        # Call parasitic extractor
        if extractor.should_extract(step, loss_value):
            metrics = {
                "loss": loss_value,
                "sft_loss": sft_value,
                "optimizer": optimizer_name,
                "spectral_rank": getattr(optimizer, "last_spectral_rank_", 0.0),
            }
            extractor.extract_adapters(model, step, metrics)

        # Tracking metrics
        elapsed = time.perf_counter() - start_time
        batch_tokens = int((input_win != tokenizer.pad_token_id).sum().item())
        tokens_processed += batch_tokens

        logs.append(
            {
                "step": step,
                "loss": loss_value,
                "sft_loss": sft_value,
                "wall_clock_time": elapsed,
                "tokens_processed": tokens_processed,
            }
        )

        if step % 5 == 0:
            print(
                f"[{optimizer_name}] Step {step}/{steps} | Loss: {loss_value:.4f} | "
                f"Time: {elapsed:.2f}s | Tokens: {tokens_processed}"
            )

    # Save the extracted adapter library
    library_path = f"parasitic_adapters_{optimizer_name.lower()}_step{steps}.pt"
    extractor.save_library(library_path)

    # 4. Post-Training Evaluation
    print(f"[{optimizer_name}] Running Post-Training Evaluation...")
    post_eval = evaluate_model(model, tokenizer, device)

    # Cleanup memory: the optimizer and extractor are dropped together with the
    # model, because they may still hold references to its parameters.
    del model, optimizer, extractor
    if device == "cuda":
        torch.cuda.empty_cache()

    return logs, post_eval, pre_eval
|
|
|
|
|
|
# ==============================================================================
|
|
# 5. DATA PREPARATION (MOCK RACF DATASET FOR PRE-FLIGHT TESTS)
|
|
# ==============================================================================
|
|
|
|
|
|
def build_mock_dataset(
    model_name: str, size: int = 20
) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Build ``size`` synthetic preference pairs for local dry-runs.

    Every sample pairs a preferred ("win") text with a dispreferred ("lose")
    text. Both are tokenized with the tokenizer of ``model_name``, truncated
    and padded to 64 tokens.

    Returns:
        A list of ``(win_ids, win_labels, lose_ids, lose_labels)`` tuples. The
        label tensors are copies of the input ids with ``-100`` written
        wherever the attention mask is 0 (i.e. padding).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def _encode_with_labels(text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        # Tokenize one text and derive labels that ignore padded positions.
        encoded = tokenizer(
            text,
            truncation=True,
            max_length=64,
            padding="max_length",
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"]
        targets = token_ids.clone()
        targets[encoded["attention_mask"] == 0] = -100
        return token_ids, targets

    samples = []
    for idx in range(size):
        # Winning path (correct FIRT opinion)
        preferred_text = (
            f"Case {idx}: Tenant wins due to Self-beseitigungsrecht § 536a BGB. "
            "Unresolved flaws: None. Authority: Refined opinion correct."
        )
        # Losing path (flawed / contradictory reasoning)
        rejected_text = (
            f"Case {idx}: Tenant loses but reasoning is faulty. "
            "Unresolved flaws: Ignored positive knowledge under § 536b BGB."
        )

        win_ids, win_labels = _encode_with_labels(preferred_text)
        lose_ids, lose_labels = _encode_with_labels(rejected_text)
        samples.append((win_ids, win_labels, lose_ids, lose_labels))
    return samples
|
|
|
|
|
|
# ==============================================================================
|
|
# 6. MAIN CONTROLLER
|
|
# ==============================================================================
|
|
|
|
|
|
def main() -> None:
    """Run the AdamW-vs-FCES benchmark end to end.

    Parses ``--model`` and ``--steps``, builds the mock dataset, trains once
    per optimizer, prints a per-step comparison, pushes one telemetry entry per
    step to SurrealDB and MariaDB, and writes ``benchmark_results.json``.
    """
    import json

    cli = argparse.ArgumentParser(description="FIRT ORPO FCES vs AdamW Trainer")
    cli.add_argument(
        "--model",
        type=str,
        default="EleutherAI/pythia-70m",
        help="Model name (e.g. google/gemma-2-2b-it)",
    )
    cli.add_argument("--steps", type=int, default=10, help="Number of training steps")
    args = cli.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    banner_rule = "=" * 80
    print(banner_rule)
    print(" FCES VS ADAMW COMPARATIVE TRAINING BENCHMARK ")
    print(banner_rule)
    print(f"Device: {device} | Base Model: {args.model} | Steps: {args.steps}")

    # Pre-tokenized dataset shared by both runs.
    print("\n[INFO] Synthesizing target training tokens...")
    dataset = build_mock_dataset(args.model)

    # 1. AdamW baseline (also supplies the pre-training evaluation).
    adam_logs, adam_post, pre_eval = train_run(
        "AdamW", args.model, args.steps, device, dataset
    )

    # 2. FCES run; its pre-training evaluation is discarded.
    fces_logs, fces_post, _ = train_run("FCES", args.model, args.steps, device, dataset)

    # 3. Per-step comparison table and telemetry entries.
    table_rule = "-" * 75
    print("\n" + "-" * 30 + " FINAL RUN SUMMARY " + "-" * 30)
    print(f"{'Step':<10}{'AdamW Loss':<20}{'FCES Loss':<20}{'Speedup / Savings':<25}")
    print(table_rule)

    telemetry_entries = []
    report_every = max(1, args.steps // 5)
    final_index = args.steps - 1
    for idx, (adam_row, fces_row) in enumerate(zip(adam_logs, fces_logs)):
        step_val = adam_row["step"]
        adam_loss = adam_row["loss"]
        fces_loss = fces_row["loss"]
        adam_elapsed = adam_row["wall_clock_time"]
        # Positive values mean FCES reached this step sooner than AdamW.
        time_saved_pct = (
            (adam_elapsed - fces_row["wall_clock_time"]) / adam_elapsed * 100
        )

        # Print roughly five evenly spaced rows plus the final step.
        if idx == final_index or idx % report_every == 0:
            print(
                f"{step_val:<10}{adam_loss:<20.4f}{fces_loss:<20.4f}"
                f"{time_saved_pct:+.2f}% wall-clock time"
            )

        telemetry_entries.append(
            (
                "INFO",
                "comparative_step",
                f"Step {step_val} | AdamW Loss: {adam_loss:.4f} | FCES Loss: {fces_loss:.4f}",
            )
        )

    print(table_rule)
    for heading, evaluation in (
        ("\n>>> PRE-TRAINING OPINION (CASE 1):", pre_eval),
        ("\n>>> POST-TRAINING ADAMW OPINION (CASE 1):", adam_post),
        ("\n>>> POST-TRAINING FCES OPINION (CASE 1):", fces_post),
    ):
        print(heading)
        print(evaluation[0]["output"])
    print(banner_rule)

    # Push telemetry to both stores.
    push_to_surrealdb(telemetry_entries)
    push_to_mariadb(telemetry_entries)

    # Write the comparative summary artifact.
    results_summary = {
        "model": args.model,
        "steps": args.steps,
        "adam_metrics": adam_logs,
        "fces_metrics": fces_logs,
        "pre_eval": pre_eval,
        "adam_post_eval": adam_post,
        "fces_post_eval": fces_post,
    }
    with open("benchmark_results.json", "w", encoding="utf-8") as out_file:
        json.dump(results_summary, out_file, indent=4)
    print("[INFO] Stored comprehensive benchmark metrics to benchmark_results.json")
|
|
|
|
|
|
# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|