feat: implement parasitic QLoRA adapter extraction and unit tests
This commit is contained in:
@@ -14,6 +14,7 @@ import dspy # noqa: E402
|
||||
import torch.nn.functional as F # noqa: E402
|
||||
from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402
|
||||
from parasitic_qlora import ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402
|
||||
|
||||
# ==============================================================================
|
||||
# 1. DSPY SIGNATURE & SYSTEM DESIGN
|
||||
@@ -162,6 +163,18 @@ def train_run(
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
|
||||
|
||||
# Initialize Parasitic QLoRA Extractor
|
||||
extractor = ParasiticQLoRAExtractor(
|
||||
QLoRAConfig(
|
||||
min_rank=8,
|
||||
max_rank=32,
|
||||
explained_variance_threshold=0.95,
|
||||
extraction_interval=5,
|
||||
interesting_point_detection=True,
|
||||
)
|
||||
)
|
||||
extractor.snapshot_base(model)
|
||||
|
||||
# 1. Pre-Training Evaluation
|
||||
print(f"[{optimizer_name}] Running Pre-Training Evaluation...")
|
||||
pre_eval = evaluate_model(model, tokenizer, device)
|
||||
@@ -211,6 +224,16 @@ def train_run(
|
||||
if optimizer_name == "FCES":
|
||||
optimizer.update_fitness(float(loss.item()))
|
||||
|
||||
# Call parasitic extractor
|
||||
if extractor.should_extract(step, float(loss.item())):
|
||||
metrics = {
|
||||
"loss": float(loss.item()),
|
||||
"sft_loss": float(sft_loss.item()),
|
||||
"optimizer": optimizer_name,
|
||||
"spectral_rank": getattr(optimizer, "last_spectral_rank_", 0.0),
|
||||
}
|
||||
extractor.extract_adapters(model, step, metrics)
|
||||
|
||||
# Tracking metrics
|
||||
elapsed = time.perf_counter() - start_time
|
||||
batch_tokens = int((input_win != tokenizer.pad_token_id).sum().item())
|
||||
@@ -232,6 +255,10 @@ def train_run(
|
||||
f"Time: {elapsed:.2f}s | Tokens: {tokens_processed}"
|
||||
)
|
||||
|
||||
# Save the extracted adapter library
|
||||
library_path = f"parasitic_adapters_{optimizer_name.lower()}_step{steps}.pt"
|
||||
extractor.save_library(library_path)
|
||||
|
||||
# 4. Post-Training Evaluation
|
||||
print(f"[{optimizer_name}] Running Post-Training Evaluation...")
|
||||
post_eval = evaluate_model(model, tokenizer, device)
|
||||
|
||||
Reference in New Issue
Block a user