import argparse
import os
import sys
import time
from typing import Any, Dict, List, Tuple

# Ensure python folder is in path
python_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "python")
sys.path.insert(0, python_dir)

import torch  # noqa: E402
import fces_native  # noqa: E402
import dspy  # noqa: E402
import torch.nn.functional as F  # noqa: E402
from send_telemetry import push_to_mariadb, push_to_surrealdb  # noqa: E402
from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: E402
from parasitic_qlora import ParasiticQLoRAExtractor, QLoRAConfig  # noqa: E402
from expert_manifold_alignment import ExpertManifoldAligner  # noqa: E402

# ==============================================================================
# 1. DSPY SIGNATURE & SYSTEM DESIGN
# ==============================================================================


# NOTE(review): the class docstring and the ``desc`` strings are presumably read
# by dspy at runtime to build the prompt - treat them as behavior, not as
# documentation, and keep their wording stable.
class LegalFIRT(dspy.Signature):  # type: ignore[misc]
    """Flaw Induced Reasoning Traces (FIRT) for German & EU Tenancy / M&A Law."""

    # Inputs: the case facts and the statutes to reason over.
    case_description: str = dspy.InputField(
        desc="Facts of the German/EU legal case with potential hidden flaws or contradictions."
    )
    relevant_statutes: str = dspy.InputField(
        desc="Relevant articles from the German Civil Code (BGB) or EU Directives."
    )
    # Outputs: the three trace stages (draft -> flaws -> refined opinion).
    t0_draft: str = dspy.OutputField(
        desc="Initial legal assessment or draft resolution based on first-pass reasoning."
    )
    t1_flaws: str = dspy.OutputField(
        desc="Identified legal flaws, logical bottlenecks, or contradictions (e.g. § 536b BGB exclusions)."
    )
    t2_refined_opinion: str = dspy.OutputField(
        desc="Final, correct, and authoritative legal opinion addressing all identified flaws."
    )


# Legal Cases for Pre- and Post-Training Benchmarking
# Each entry provides the case facts ("description") and the applicable
# provisions ("statutes"); both are interpolated into the evaluation prompt.
eval_cases: List[Dict[str, str]] = [
    {
        "description": "Mieter M besichtigt im Januar eine Wohnung des Vermieters V. Schimmelfleck im Schlafzimmer. V sagt er wird es 'irgendwann mal wegmachen'. M unterschreibt Mietvertrag vorbehaltlos und zieht am 1. Februar ein. Im März mindert M die Miete selbstständig um 20% wegen des Schimmels. V verlangt volle Miete.",
        "statutes": "§ 535 Abs. 2 BGB (Mietzahlung), § 536 Abs. 1 BGB (Mietminderung bei Sachmangel), § 536b BGB (Kenntnis des Mieters bei Vertragsschluss).",
    },
    {
        "description": "Toilettenspülung bei F fällt am Freitagabend aus, Wasser läuft ununterbrochen und droht überzulaufen. V ist auf Segeltrip und nicht erreichbar. Kein Hausmeister. F beauftragt am Samstagmorgen den Notdienst K, bezahlt 300 € und fordert Erstattung von V. V weigert sich.",
        "statutes": "§ 536a Abs. 2 Nr. 2 BGB (Selbstbeseitigungsrecht bei Gefahr im Verzug), § 535 BGB.",
    },
]

# ==============================================================================
# 2. ORPO LOSS WITH SOFT LABELS
# ==============================================================================


def _mean_token_log_prob(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return the length-normalised log-probability of ``labels`` under ``logits``.

    The logits at position ``t`` score the token at position ``t + 1``, so both
    tensors are shifted by one before gathering - the same alignment the SFT
    cross-entropy term uses. Positions labelled ``-100`` are excluded from the
    average.

    Args:
        logits: Scores indexed as ``[..., position, vocab]``.
        labels: Target token ids indexed as ``[..., position]``; ``-100`` marks
            positions to ignore.

    Returns:
        One mean log-probability per sequence (the position axis is reduced).
    """
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]

    keep = shift_labels != -100
    # ``gather`` needs valid indices everywhere; masked slots are zeroed out
    # again by ``keep`` below, so the placeholder id 0 never contributes.
    safe_labels = shift_labels.masked_fill(~keep, 0)

    token_log_probs = torch.gather(
        F.log_softmax(shift_logits, dim=-1), dim=-1, index=safe_labels.unsqueeze(-1)
    ).squeeze(-1)

    # ``clamp(min=1)`` avoids 0/0 for sequences without any scored token.
    return (token_log_probs * keep).sum(dim=-1) / keep.sum(dim=-1).clamp(min=1)


def _log_odds(seq_log_prob: torch.Tensor) -> torch.Tensor:
    """Return ``log(p / (1 - p))`` for ``p = exp(seq_log_prob)``."""
    # Keep p strictly below 1: at p == 1 (e.g. a fully masked sequence whose
    # mean log-prob is 0) ``log1p(-p)`` is -inf and the loss turns into NaN.
    eps = torch.finfo(seq_log_prob.dtype).eps
    bounded = seq_log_prob.clamp(max=-eps)
    return bounded - torch.log1p(-torch.exp(bounded))


def compute_orpo_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    logits_lose: torch.Tensor,
    labels_lose: torch.Tensor,
    lambda_orpo: float = 0.1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the combined cross-entropy (SFT) and odds-ratio (ORPO) loss.

    Args:
        logits: Model scores for the preferred ("win") sequence, indexed as
            ``[..., position, vocab]``.
        labels: Target ids for the preferred sequence; ``-100`` is ignored.
        logits_lose: Model scores for the dispreferred ("lose") sequence.
        labels_lose: Target ids for the dispreferred sequence; ``-100`` is
            ignored.
        lambda_orpo: Weight of the odds-ratio term in the total loss.

    Returns:
        ``(total_loss, sft_loss)`` where
        ``total_loss = sft_loss + lambda_orpo * orpo_loss``.
    """
    # 1. SFT Loss (Cross Entropy) on the preferred target sequence
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    sft_loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )

    # 2. Average log probabilities for preferred (win) and dispreferred (lose)
    #    sequences, using the same next-token alignment as the SFT term.
    seq_log_probs_win = _mean_token_log_prob(logits, labels)
    seq_log_probs_lose = _mean_token_log_prob(logits_lose, labels_lose)

    # 3. Odds Ratio Calculation
    odds_ratio = _log_odds(seq_log_probs_win) - _log_odds(seq_log_probs_lose)
    orpo_loss = -F.logsigmoid(odds_ratio).mean()

    # Total combined loss
    total_loss = sft_loss + lambda_orpo * orpo_loss
    return total_loss, sft_loss
# ==============================================================================
# 3. EVALUATION RUNNER
# ==============================================================================


def evaluate_model(
    model: AutoModelForCausalLM, tokenizer: AutoTokenizer, device: str
) -> List[Dict[str, Any]]:
    """Generate a zero-shot answer for every entry in ``eval_cases``.

    The model is switched to eval mode and left there; callers that continue
    training must call ``model.train()`` afterwards.

    Args:
        model: Causal language model used for generation.
        tokenizer: Tokenizer matching ``model``.
        device: Device the tokenized prompt is moved to.

    Returns:
        One dict per case with ``"case_id"`` (1-based int) and ``"output"``
        (the decoded sequence returned by ``generate``).
    """
    model.eval()
    results: List[Dict[str, Any]] = []
    for case_id, case in enumerate(eval_cases, start=1):
        prompt = (
            f"Case: {case['description']}\nStatutes: {case['statutes']}\n"
            "Analyze and resolve using Flaw Induced Reasoning Traces (FIRT):\n"
            "T0 Draft:\n"
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            # Greedy decoding: sampling knobs such as ``temperature`` have no
            # effect without sampling, so none are passed.
            outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        results.append({"case_id": case_id, "output": generated_text})
    return results


# ==============================================================================
# 4. TRAINING PIPELINE (ADAMW VS FCES)
# ==============================================================================


def train_run(
    optimizer_name: str,
    model_name: str,
    steps: int,
    device: str,
    dataset: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Train a fresh copy of ``model_name`` and evaluate it before and after.

    Args:
        optimizer_name: ``"AdamW"`` or ``"FCES"``.
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.
        steps: Number of optimisation steps; batches are taken from
            ``dataset`` cyclically.
        device: Device the model and batches are moved to.
        dataset: Batches of ``(input_win, labels_win, input_lose, labels_lose)``.

    Returns:
        ``(logs, post_eval, pre_eval)`` - note that the post-training
        evaluation comes before the pre-training one. ``logs`` holds one dict
        per step with ``step``, ``loss``, ``sft_loss``, ``wall_clock_time`` and
        ``tokens_processed``.

    Raises:
        ValueError: If ``optimizer_name`` is unknown, or if ``dataset`` is
            empty while ``steps`` is positive.
    """
    # Fail fast, before any model weights are loaded.
    if optimizer_name not in ("AdamW", "FCES"):
        raise ValueError("Invalid optimizer name")
    if steps > 0 and not dataset:
        raise ValueError("dataset must contain at least one batch")
    use_fces = optimizer_name == "FCES"

    print(f"\n[START] Initializing training run with {optimizer_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    # Initialize Parasitic QLoRA Extractor
    extractor = ParasiticQLoRAExtractor(
        QLoRAConfig(
            min_rank=8,
            max_rank=32,
            explained_variance_threshold=0.95,
            extraction_interval=5,
            interesting_point_detection=True,
        )
    )
    extractor.snapshot_base(model)

    # Initialize Expert Manifold Aligner
    aligner = ExpertManifoldAligner(model)

    # 1. Pre-Training Evaluation
    print(f"[{optimizer_name}] Running Pre-Training Evaluation...")
    pre_eval = evaluate_model(model, tokenizer, device)

    # 2. Setup Optimizer (evaluate_model left the model in eval mode)
    model.train()
    if use_fces:
        cfg = fces_native.FCESConfig()
        cfg.lr = 1e-4
        cfg.population_size = 8  # Balanced for resource footprint
        cfg.total_steps = steps
        optimizer = fces_native.FCESOptimizer(list(model.parameters()), cfg)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # 3. Training loop
    logs: List[Dict[str, Any]] = []
    start_time = time.perf_counter()
    tokens_processed = 0

    for step in range(steps):
        # Cyclically fetch batches from pre-tokenized dataset
        batch = dataset[step % len(dataset)]
        input_win, labels_win, input_lose, labels_lose = (t.to(device) for t in batch)

        optimizer.zero_grad()

        # Forward pass for win (preferred). Labels are not passed to the model:
        # the loss is computed by compute_orpo_loss, so an internal loss would
        # be wasted work.
        logits_win = model(input_win).logits

        # Forward pass for lose (dispreferred). Gradients are kept so the
        # odds-ratio term can push the rejected sequence's likelihood down.
        logits_lose = model(input_lose).logits

        # Compute combined ORPO loss
        loss, sft_loss = compute_orpo_loss(
            logits_win, labels_win, logits_lose, labels_lose
        )
        loss.backward()
        optimizer.step()

        # Read the scalars once; every ``.item()`` call copies to the host.
        loss_value = float(loss.item())
        sft_value = float(sft_loss.item())

        if use_fces:
            optimizer.update_fitness(loss_value)

        # Track per-step weight delta for manifold alignment
        aligner.track_step(model)

        # Call parasitic extractor
        if extractor.should_extract(step, loss_value):
            metrics: Dict[str, Any] = {
                "loss": loss_value,
                "sft_loss": sft_value,
                "optimizer": optimizer_name,
                "spectral_rank": getattr(optimizer, "last_spectral_rank_", 0.0),
            }
            if use_fces:
                metrics["fces_fitness"] = optimizer.get_active_controller_fitness()
                metrics["fces_controller_id"] = optimizer.get_active_controller_id()

            adapter = extractor.extract_adapters(model, step, metrics)
            aligner.tag_adapter(adapter)
            profile = aligner.profile_adapter(adapter)
            print(
                f"[{optimizer_name}] Adapter '{adapter.adapter_id}' | "
                f"tags={adapter.domain_tags} | "
                f"statute={profile['statute_recall']:.2f} "
                f"logic={profile['logic_reasoning']:.2f} "
                f"style={profile['style_gutachtenstil']:.2f}"
            )

        # Tracking metrics (only tokens of the preferred sequence are counted)
        elapsed = time.perf_counter() - start_time
        batch_tokens = int((input_win != tokenizer.pad_token_id).sum().item())
        tokens_processed += batch_tokens
        logs.append(
            {
                "step": step,
                "loss": loss_value,
                "sft_loss": sft_value,
                "wall_clock_time": elapsed,
                "tokens_processed": tokens_processed,
            }
        )

        if step % 5 == 0:
            print(
                f"[{optimizer_name}] Step {step}/{steps} | Loss: {loss_value:.4f} | "
                f"Time: {elapsed:.2f}s | Tokens: {tokens_processed}"
            )

    # Save the extracted adapter library
    library_path = f"parasitic_adapters_{optimizer_name.lower()}_step{steps}.pt"
    extractor.save_library(library_path)

    # 4. Post-Training Evaluation
    print(f"[{optimizer_name}] Running Post-Training Evaluation...")
    post_eval = evaluate_model(model, tokenizer, device)

    # Cleanup memory. The optimizer, extractor and aligner may hold references
    # to the parameters as well, so drop them before emptying the CUDA cache.
    del model, optimizer, extractor, aligner
    if device == "cuda":
        torch.cuda.empty_cache()

    return logs, post_eval, pre_eval


# ==============================================================================
# 5. DATA PREPARATION (MOCK RACF DATASET FOR PRE-FLIGHT TESTS)
# ==============================================================================


def build_mock_dataset(
    model_name: str, size: int = 20, max_length: int = 64
) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Synthesize mock pre-tokenized preference pairs for local dry-runs.

    Args:
        model_name: Identifier of the tokenizer to load.
        size: Number of (win, lose) pairs to generate.
        max_length: Length every sequence is padded or truncated to.

    Returns:
        ``size`` tuples ``(input_win, labels_win, input_lose, labels_lose)``.
        The labels are copies of the input ids in which positions with
        attention mask 0 are set to ``-100``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def encode(text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize ``text`` to fixed length and derive padding-masked labels."""
        enc = tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )
        labels = enc["input_ids"].clone()
        labels[enc["attention_mask"] == 0] = -100
        return enc["input_ids"], labels

    dataset = []
    for idx in range(size):
        # Winning path (Correct FIRT opinions)
        text_win = (
            f"Case {idx}: Tenant wins due to Self-beseitigungsrecht § 536a BGB. "
            "Unresolved flaws: None. Authority: Refined opinion correct."
        )
        # Losing path (Flawed/contradictory reasoning)
        text_lose = (
            f"Case {idx}: Tenant loses but reasoning is faulty. "
            "Unresolved flaws: Ignored positive knowledge under § 536b BGB."
        )
        input_win, labels_win = encode(text_win)
        input_lose, labels_lose = encode(text_lose)
        dataset.append((input_win, labels_win, input_lose, labels_lose))
    return dataset


# ==============================================================================
# 6. MAIN CONTROLLER
# ==============================================================================


def main() -> None:
    """Run the AdamW and FCES training runs and report the comparison.

    Parses ``--model`` and ``--steps``, trains once per optimizer on the mock
    dataset, prints a per-step comparison table and the case-1 outputs, pushes
    the telemetry entries, and writes ``benchmark_results.json``.
    """
    import json  # used only for the summary artifact written at the end

    parser = argparse.ArgumentParser(description="FIRT ORPO FCES vs AdamW Trainer")
    parser.add_argument(
        "--model",
        type=str,
        default="EleutherAI/pythia-70m",
        help="Model name (e.g. google/gemma-2-2b-it)",
    )
    parser.add_argument(
        "--steps", type=int, default=10, help="Number of training steps"
    )
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("=" * 80)
    print(" FCES VS ADAMW COMPARATIVE TRAINING BENCHMARK ")
    print("=" * 80)
    print(f"Device: {device} | Base Model: {args.model} | Steps: {args.steps}")

    # Initialize pre-tokenized dataset
    print("\n[INFO] Synthesizing target training tokens...")
    dataset = build_mock_dataset(args.model)

    # 1. Run AdamW Baseline
    adam_logs, adam_post, pre_eval = train_run(
        "AdamW", args.model, args.steps, device, dataset
    )

    # 2. Run FCES Optimizer (its pre-training evaluation is discarded)
    fces_logs, fces_post, _ = train_run("FCES", args.model, args.steps, device, dataset)

    # 3. Telemetry Log Synthesis & Push
    telemetry_entries = []
    print("\n" + "-" * 30 + " FINAL RUN SUMMARY " + "-" * 30)
    print(f"{'Step':<10}{'AdamW Loss':<20}{'FCES Loss':<20}{'Speedup / Savings':<25}")
    print("-" * 75)

    # Print roughly five evenly spaced rows plus the final step.
    print_every = max(1, args.steps // 5)
    for idx, (adam_entry, fces_entry) in enumerate(zip(adam_logs, fces_logs)):
        step_val = adam_entry["step"]
        adam_loss = adam_entry["loss"]
        fces_loss = fces_entry["loss"]
        adam_time = adam_entry["wall_clock_time"]
        # Guard the ratio against a zero baseline time.
        time_saved_pct = (
            (adam_time - fces_entry["wall_clock_time"]) / adam_time * 100
            if adam_time > 0
            else 0.0
        )

        if idx % print_every == 0 or idx == args.steps - 1:
            print(
                f"{step_val:<10}{adam_loss:<20.4f}{fces_loss:<20.4f}"
                f"{time_saved_pct:+.2f}% wall-clock time"
            )

        # NOTE(review): telemetry is recorded for every step, not only for the
        # printed rows - confirm this matches the intended granularity.
        telemetry_entries.append(
            (
                "INFO",
                "comparative_step",
                f"Step {step_val} | AdamW Loss: {adam_loss:.4f} | FCES Loss: {fces_loss:.4f}",
            )
        )

    print("-" * 75)
    print("\n>>> PRE-TRAINING OPINION (CASE 1):")
    print(pre_eval[0]["output"])
    print("\n>>> POST-TRAINING ADAMW OPINION (CASE 1):")
    print(adam_post[0]["output"])
    print("\n>>> POST-TRAINING FCES OPINION (CASE 1):")
    print(fces_post[0]["output"])
    print("=" * 80)

    # Push telemetries
    push_to_surrealdb(telemetry_entries)
    push_to_mariadb(telemetry_entries)

    # Write comparative summary artifact
    results_summary = {
        "model": args.model,
        "steps": args.steps,
        "adam_metrics": adam_logs,
        "fces_metrics": fces_logs,
        "pre_eval": pre_eval,
        "adam_post_eval": adam_post,
        "fces_post_eval": fces_post,
    }
    with open("benchmark_results.json", "w", encoding="utf-8") as f:
        # The file is UTF-8, so keep the German text readable instead of
        # escaping every non-ASCII character.
        json.dump(results_summary, f, indent=4, ensure_ascii=False)
    print("[INFO] Stored comprehensive benchmark metrics to benchmark_results.json")


if __name__ == "__main__":
    main()