feat: implement parasitic QLoRA adapter extraction and unit tests

This commit is contained in:
AI-anonymous
2026-05-20 15:03:34 +02:00
parent a1c123e590
commit 7e2e86d98c
4 changed files with 636 additions and 0 deletions

View File

@@ -14,6 +14,7 @@ import dspy # noqa: E402
import torch.nn.functional as F # noqa: E402
from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402
from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402
from parasitic_qlora import ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402
# ==============================================================================
# 1. DSPY SIGNATURE & SYSTEM DESIGN
@@ -162,6 +163,18 @@ def train_run(
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
# Initialize the parasitic QLoRA extractor and call snapshot_base on the freshly loaded model
extractor = ParasiticQLoRAExtractor(
QLoRAConfig(
min_rank=8,
max_rank=32,
explained_variance_threshold=0.95,
extraction_interval=5,
interesting_point_detection=True,
)
)
extractor.snapshot_base(model)
# 1. Pre-Training Evaluation
print(f"[{optimizer_name}] Running Pre-Training Evaluation...")
pre_eval = evaluate_model(model, tokenizer, device)
@@ -211,6 +224,16 @@ def train_run(
if optimizer_name == "FCES":
optimizer.update_fitness(float(loss.item()))
# Extract adapters only when the extractor's should_extract check passes for this step and loss
if extractor.should_extract(step, float(loss.item())):
metrics = {
"loss": float(loss.item()),
"sft_loss": float(sft_loss.item()),
"optimizer": optimizer_name,
"spectral_rank": getattr(optimizer, "last_spectral_rank_", 0.0),
}
extractor.extract_adapters(model, step, metrics)
# Tracking metrics
elapsed = time.perf_counter() - start_time
batch_tokens = int((input_win != tokenizer.pad_token_id).sum().item())
@@ -232,6 +255,10 @@ def train_run(
f"Time: {elapsed:.2f}s | Tokens: {tokens_processed}"
)
# Save the extracted adapter library
library_path = f"parasitic_adapters_{optimizer_name.lower()}_step{steps}.pt"
extractor.save_library(library_path)
# 4. Post-Training Evaluation
print(f"[{optimizer_name}] Running Post-Training Evaluation...")
post_eval = evaluate_model(model, tokenizer, device)