feat: implement parasitic QLoRA adapter extraction and unit tests
This commit is contained in:
8
.gitignore
vendored
8
.gitignore
vendored
@@ -9,8 +9,10 @@ out/
|
|||||||
*.so
|
*.so
|
||||||
*.dylib
|
*.dylib
|
||||||
*.dll
|
*.dll
|
||||||
|
*.pyd
|
||||||
*.a
|
*.a
|
||||||
*.lib
|
*.lib
|
||||||
|
|
||||||
*.exe
|
*.exe
|
||||||
|
|
||||||
# IDE
|
# IDE
|
||||||
@@ -44,3 +46,9 @@ Thumbs.db
|
|||||||
|
|
||||||
# libtorch download
|
# libtorch download
|
||||||
libtorch/
|
libtorch/
|
||||||
|
|
||||||
|
# Local cache and artifacts
|
||||||
|
*.pt
|
||||||
|
scratch/
|
||||||
|
telemetry.log
|
||||||
|
telemetry.offset
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import dspy # noqa: E402
|
|||||||
import torch.nn.functional as F # noqa: E402
|
import torch.nn.functional as F # noqa: E402
|
||||||
from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402
|
from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402
|
from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402
|
||||||
|
from parasitic_qlora import ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# 1. DSPY SIGNATURE & SYSTEM DESIGN
|
# 1. DSPY SIGNATURE & SYSTEM DESIGN
|
||||||
@@ -162,6 +163,18 @@ def train_run(
|
|||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
|
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
|
||||||
|
|
||||||
|
# Initialize Parasitic QLoRA Extractor
|
||||||
|
extractor = ParasiticQLoRAExtractor(
|
||||||
|
QLoRAConfig(
|
||||||
|
min_rank=8,
|
||||||
|
max_rank=32,
|
||||||
|
explained_variance_threshold=0.95,
|
||||||
|
extraction_interval=5,
|
||||||
|
interesting_point_detection=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
extractor.snapshot_base(model)
|
||||||
|
|
||||||
# 1. Pre-Training Evaluation
|
# 1. Pre-Training Evaluation
|
||||||
print(f"[{optimizer_name}] Running Pre-Training Evaluation...")
|
print(f"[{optimizer_name}] Running Pre-Training Evaluation...")
|
||||||
pre_eval = evaluate_model(model, tokenizer, device)
|
pre_eval = evaluate_model(model, tokenizer, device)
|
||||||
@@ -211,6 +224,16 @@ def train_run(
|
|||||||
if optimizer_name == "FCES":
|
if optimizer_name == "FCES":
|
||||||
optimizer.update_fitness(float(loss.item()))
|
optimizer.update_fitness(float(loss.item()))
|
||||||
|
|
||||||
|
# Call parasitic extractor
|
||||||
|
if extractor.should_extract(step, float(loss.item())):
|
||||||
|
metrics = {
|
||||||
|
"loss": float(loss.item()),
|
||||||
|
"sft_loss": float(sft_loss.item()),
|
||||||
|
"optimizer": optimizer_name,
|
||||||
|
"spectral_rank": getattr(optimizer, "last_spectral_rank_", 0.0),
|
||||||
|
}
|
||||||
|
extractor.extract_adapters(model, step, metrics)
|
||||||
|
|
||||||
# Tracking metrics
|
# Tracking metrics
|
||||||
elapsed = time.perf_counter() - start_time
|
elapsed = time.perf_counter() - start_time
|
||||||
batch_tokens = int((input_win != tokenizer.pad_token_id).sum().item())
|
batch_tokens = int((input_win != tokenizer.pad_token_id).sum().item())
|
||||||
@@ -232,6 +255,10 @@ def train_run(
|
|||||||
f"Time: {elapsed:.2f}s | Tokens: {tokens_processed}"
|
f"Time: {elapsed:.2f}s | Tokens: {tokens_processed}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Save the extracted adapter library
|
||||||
|
library_path = f"parasitic_adapters_{optimizer_name.lower()}_step{steps}.pt"
|
||||||
|
extractor.save_library(library_path)
|
||||||
|
|
||||||
# 4. Post-Training Evaluation
|
# 4. Post-Training Evaluation
|
||||||
print(f"[{optimizer_name}] Running Post-Training Evaluation...")
|
print(f"[{optimizer_name}] Running Post-Training Evaluation...")
|
||||||
post_eval = evaluate_model(model, tokenizer, device)
|
post_eval = evaluate_model(model, tokenizer, device)
|
||||||
|
|||||||
492
python/parasitic_qlora.py
Normal file
492
python/parasitic_qlora.py
Normal file
@@ -0,0 +1,492 @@
|
|||||||
|
"""Parasitic QLoRA Adapter Extraction Engine.
|
||||||
|
|
||||||
|
Extracts low-rank LoRA adapters from weight deltas during full-parameter
|
||||||
|
fine-tuning at near-zero additional compute cost. Piggybacks on the existing
|
||||||
|
SVD computation in FCES's SpectralSensor.
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
During training, W₀ (base weights) is snapshotted at init.
|
||||||
|
At extraction points: ΔW = W_t - W₀
|
||||||
|
Truncated SVD: ΔW ≈ U_r·Σ_r·V_r^T = B·A (LoRA decomposition)
|
||||||
|
|
||||||
|
Compute overhead: < 0.5% of training FLOPs when using torch.svd_lowrank.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class QLoRAConfig:
    """Immutable settings controlling parasitic LoRA adapter extraction.

    Field order is part of the public (positional) constructor interface
    and must not be changed.
    """

    # --- rank selection -------------------------------------------------
    # Smallest rank considered when choosing a per-layer LoRA rank.
    min_rank: int = 8
    # Upper bound on the per-layer rank (also the truncated-SVD width).
    max_rank: int = 64
    # Cumulative singular-value energy ratio a rank must reach to be chosen.
    explained_variance_threshold: float = 0.95

    # --- extraction scheduling -------------------------------------------
    # Extract every N-th training step.
    extraction_interval: int = 50
    # Enables the plateau / validation-improvement triggers in addition
    # to the periodic one.
    interesting_point_detection: bool = True
    # Number of most recent losses inspected for a plateau.
    loss_plateau_window: int = 10
    # Plateau when (max - min) / mean of the window falls below this.
    loss_plateau_threshold: float = 0.005
    # Minimum gain between two consecutive validation scores that
    # triggers an extraction.
    val_improvement_threshold: float = 0.01

    # --- layer filtering / storage ---------------------------------------
    # Parameters with fewer dimensions than this are not tracked.
    min_weight_dims: int = 2
    # Round the LoRA factors through an absmax INT8 grid before storing.
    quantize_to_int8: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LoRAMatrices:
    """LoRA factor pair describing one layer's weight delta.

    The product ``lora_B @ lora_A`` approximates the (2-D flattened) delta
    of the layer named ``layer_name``.
    """

    # Name of the parameter this decomposition belongs to.
    layer_name: str
    # Left factor, shape d × r.
    lora_B: torch.Tensor
    # Right factor, shape r × k.
    lora_A: torch.Tensor
    # Number of retained components r.
    rank: int
    # Energy ratio captured by the retained components.
    explained_variance: float
    # Singular values of the retained components.
    singular_values: torch.Tensor
    # Shape of the parameter before any flattening to 2-D.
    original_shape: tuple[int, ...]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ExpertAdapter:
    """One LoRA adapter (all layers) captured at a single training step."""

    adapter_id: str
    step: int
    rank_per_layer: dict[str, int] = field(default_factory=dict)
    avg_explained_variance: float = 0.0

    # Per-layer LoRA factors, keyed by parameter name.
    layers: dict[str, LoRAMatrices] = field(default_factory=dict)

    # Metrics recorded when the adapter was extracted.
    loss: float = 0.0
    sft_loss: float = 0.0
    val_score: float = 0.0
    external_val_score: float | None = None
    optimizer_type: str = ""
    spectral_rank: float = 0.0
    extraction_trigger: str = "periodic"
    wall_clock_time: float = 0.0

    # Metadata only populated for FCES runs.
    fces_population_fitness: float | None = None
    fces_controller_id: int | None = None

    # Free-form domain labels.
    domain_tags: list[str] = field(default_factory=list)

    def total_params(self) -> int:
        """Return the element count of every ``lora_B`` and ``lora_A`` combined."""
        return sum(
            entry.lora_B.numel() + entry.lora_A.numel()
            for entry in self.layers.values()
        )

    def compression_ratio(self, original_params: int) -> float:
        """Return ``original_params`` divided by the adapter's parameter count.

        An adapter without any parameters yields ``inf``.
        """
        n_adapter = self.total_params()
        return original_params / n_adapter if n_adapter != 0 else float("inf")

    def to_state_dict(self) -> dict[str, Any]:
        """Build a plain-dict snapshot (tensors moved to CPU) for saving."""
        scalar_keys = (
            "adapter_id",
            "step",
            "loss",
            "sft_loss",
            "val_score",
            "external_val_score",
            "optimizer_type",
            "spectral_rank",
            "extraction_trigger",
            "wall_clock_time",
            "fces_population_fitness",
            "fces_controller_id",
            "domain_tags",
            "avg_explained_variance",
        )
        snapshot: dict[str, Any] = {key: getattr(self, key) for key in scalar_keys}
        snapshot["layers"] = {
            layer_name: {
                "lora_B": entry.lora_B.cpu(),
                "lora_A": entry.lora_A.cpu(),
                "rank": entry.rank,
                "explained_variance": entry.explained_variance,
                "singular_values": entry.singular_values.cpu(),
                "original_shape": entry.original_shape,
            }
            for layer_name, entry in self.layers.items()
        }
        return snapshot
|
||||||
|
|
||||||
|
|
||||||
|
class ParasiticQLoRAExtractor:
    """Extracts LoRA adapters parasitically during full-parameter fine-tuning.

    The weight delta ΔW = W_t - W₀ of every tracked parameter is decomposed
    with a truncated SVD (``torch.svd_lowrank``) into a low-rank ``B @ A``
    pair. Adapters are accumulated in ``adapter_library`` and can be saved
    to / loaded from disk.
    """

    def __init__(self, config: QLoRAConfig | None = None) -> None:
        """Create an extractor; ``config`` defaults to ``QLoRAConfig()``."""
        self.config = config or QLoRAConfig()
        self.base_weights: dict[str, torch.Tensor] = {}
        self.adapter_library: list[ExpertAdapter] = []
        self.loss_history: list[float] = []
        self.val_history: list[float] = []
        self._extraction_count: int = 0
        self._total_extraction_time: float = 0.0
        self._total_training_time: float = 0.0
        # Step of the most recent plateau-triggered extraction; used to
        # avoid re-triggering on every step of the same plateau.
        self._last_plateau_step: int | None = None

    def snapshot_base(self, model: nn.Module) -> None:
        """Snapshot base weights W₀ (as CPU copies). Call once before training.

        Only parameters with at least ``config.min_weight_dims`` dimensions
        are tracked. Any previous snapshot is discarded.
        """
        self.base_weights.clear()
        for name, param in model.named_parameters():
            if param.dim() >= self.config.min_weight_dims:
                self.base_weights[name] = param.detach().clone().cpu()
        print(
            f"[ParasiticQLoRA] Base snapshot: {len(self.base_weights)} "
            f"weight matrices tracked"
        )

    # ------------------------------------------------------------------
    # Trigger logic
    # ------------------------------------------------------------------

    def _is_loss_plateau(self) -> bool:
        """Return True when the recent loss window is (relatively) flat."""
        window = self.config.loss_plateau_window
        if len(self.loss_history) < window:
            return False
        recent = self.loss_history[-window:]
        if not recent:
            return False
        avg_loss = sum(recent) / len(recent)
        if avg_loss <= 0:
            return False
        loss_range = max(recent) - min(recent)
        return loss_range / avg_loss < self.config.loss_plateau_threshold

    def _has_val_improved(self) -> bool:
        """Return True when the last validation score beat the previous one
        by more than ``config.val_improvement_threshold``."""
        if len(self.val_history) < 2:
            return False
        improvement = self.val_history[-1] - self.val_history[-2]
        return improvement > self.config.val_improvement_threshold

    def should_extract(
        self,
        step: int,
        loss: float,
        val_score: float = 0.0,
    ) -> bool:
        """Record ``loss`` / ``val_score`` and decide whether to extract now.

        Triggers on:
            0. Step 0 (always)
            1. Periodic interval (every ``extraction_interval`` steps; a
               non-positive interval disables this trigger instead of
               raising ``ZeroDivisionError``)
            2. Loss plateau detection (at most once per plateau window)
            3. Validation score improvement (only on the call that supplied
               the new validation score)

        Triggers 2 and 3 require ``config.interesting_point_detection``.
        """
        self.loss_history.append(loss)
        new_val_score = val_score > 0
        if new_val_score:
            self.val_history.append(val_score)

        # Always extract at step 0 (pre-training baseline)
        if step == 0:
            return True

        # 1. Periodic extraction
        interval = self.config.extraction_interval
        if interval > 0 and step % interval == 0:
            return True

        if not self.config.interesting_point_detection:
            return False

        # 2. Loss plateau detection, debounced: without this a long plateau
        #    would request a full extraction on every single step.
        if self._is_loss_plateau():
            cooldown = max(1, self.config.loss_plateau_window)
            last = self._last_plateau_step
            if last is None or abs(step - last) >= cooldown:
                self._last_plateau_step = step
                return True

        # 3. Validation improvement. Checked only when this call added a
        #    score; otherwise the same improvement would fire repeatedly.
        return new_val_score and self._has_val_improved()

    def _infer_trigger(self, step: int) -> str:
        """Label the reason for an extraction from the recorded histories."""
        trigger = "periodic"
        if step == 0:
            trigger = "initial"
        elif self._is_loss_plateau():
            trigger = "plateau"
        # A validation improvement takes precedence over every other label.
        if self._has_val_improved():
            trigger = "val_improvement"
        return trigger

    # ------------------------------------------------------------------
    # Extraction
    # ------------------------------------------------------------------

    def _decompose_layer(
        self,
        name: str,
        weight: torch.Tensor,
        base_weight: torch.Tensor,
    ) -> LoRAMatrices | None:
        """Decompose one layer's delta into LoRA factors.

        Returns ``None`` when the layer is skipped (unchanged weights,
        fewer than two dimensions, or no usable rank).
        """
        delta_w = (weight - base_weight.to(weight.device)).float()

        # Skip near-zero deltas (layer hasn't changed)
        delta_norm = delta_w.norm().item()
        if delta_norm < 1e-7:
            return None

        # Reshape to 2D for SVD if needed (e.g., conv layers)
        original_shape = tuple(delta_w.shape)
        if delta_w.dim() > 2:
            delta_w = delta_w.reshape(delta_w.shape[0], -1)
        if delta_w.dim() != 2:
            # 0-D / 1-D tensors (possible when min_weight_dims < 2) have no
            # matrix factorisation.
            return None

        d, k = delta_w.shape
        max_rank = min(self.config.max_rank, d, k)
        if max_rank < 1:
            return None

        # Truncated SVD — O(d·k·r) instead of O(d·k·min(d,k))
        try:
            U, S, V = torch.svd_lowrank(delta_w, q=max_rank)
        except (RuntimeError, ValueError):
            # Fallback to full SVD if lowrank fails
            U_full, S_full, Vh_full = torch.linalg.svd(delta_w, full_matrices=False)
            U = U_full[:, :max_rank]
            S = S_full[:max_rank]
            V = Vh_full[:max_rank, :].T

        captured_energy = S.square()
        if captured_energy.sum().item() < 1e-12:
            return None

        # Explained variance is measured against the FULL delta energy
        # ‖ΔW‖_F², not the energy of the truncated spectrum; otherwise the
        # ratio at max_rank is always 1.0 regardless of what was discarded.
        total_energy = delta_norm**2
        explained_ratios = (
            (torch.cumsum(captured_energy, dim=0) / total_energy)
            .clamp(max=1.0)
            .tolist()
        )

        # Smallest rank >= min_rank reaching the threshold, else max_rank.
        n_components = len(explained_ratios)
        min_rank = max(1, self.config.min_rank)
        rank = max_rank
        for r in range(min_rank, min(max_rank, n_components) + 1):
            if explained_ratios[r - 1] >= self.config.explained_variance_threshold:
                rank = r
                break
        rank = min(rank, n_components)
        explained_var = explained_ratios[rank - 1]

        # LoRA decomposition: B = U_r · √Σ_r, A = √Σ_r · V_r^T
        sqrt_S = torch.sqrt(S[:rank])
        lora_B = U[:, :rank] * sqrt_S.unsqueeze(0)  # d × r
        lora_A = V[:, :rank].T * sqrt_S.unsqueeze(1)  # r × k

        # Optional INT8 quantization
        if self.config.quantize_to_int8:
            lora_B = _quantize_int8(lora_B)
            lora_A = _quantize_int8(lora_A)

        return LoRAMatrices(
            layer_name=name,
            lora_B=lora_B.cpu().half(),
            lora_A=lora_A.cpu().half(),
            rank=rank,
            explained_variance=explained_var,
            singular_values=S[:rank].cpu().clone(),
            original_shape=original_shape,
        )

    def extract_adapters(
        self,
        model: nn.Module,
        step: int,
        metrics: dict[str, Any],
    ) -> ExpertAdapter:
        """Core parasitic extraction. Computes ΔW and decomposes via truncated SVD.

        Algorithm per layer:
            1. ΔW = W_t - W₀
            2. U, Σ, V = svd_lowrank(ΔW, q=max_rank)
            3. Select rank r via explained variance threshold
            4. B = U_r · diag(√Σ_r),  A = diag(√Σ_r) · V_r^T

        Args:
            model: Model whose current parameters are compared with the
                snapshot taken by ``snapshot_base``.
            step: Training step recorded on the adapter.
            metrics: Optional values copied onto the adapter (``loss``,
                ``sft_loss``, ``val_score``, ``external_val_score``,
                ``optimizer``, ``spectral_rank``, ``fces_fitness``,
                ``fces_controller_id``, ``wall_clock_time``).

        Returns:
            The new adapter, which is also appended to ``adapter_library``.
        """
        t0 = time.perf_counter()
        self._extraction_count += 1

        adapter_id = (
            f"{metrics.get('optimizer', 'unknown')}_step{step}_{self._extraction_count}"
        )
        adapter = ExpertAdapter(
            adapter_id=adapter_id,
            step=step,
            loss=metrics.get("loss", 0.0),
            sft_loss=metrics.get("sft_loss", 0.0),
            val_score=metrics.get("val_score", 0.0),
            external_val_score=metrics.get("external_val_score"),
            optimizer_type=metrics.get("optimizer", ""),
            spectral_rank=metrics.get("spectral_rank", 0.0),
            fces_population_fitness=metrics.get("fces_fitness"),
            fces_controller_id=metrics.get("fces_controller_id"),
            wall_clock_time=metrics.get("wall_clock_time", 0.0),
        )

        trigger = self._infer_trigger(step)
        adapter.extraction_trigger = trigger

        total_explained = 0.0
        for name, param in model.named_parameters():
            base_w = self.base_weights.get(name)
            if base_w is None or param.dim() < self.config.min_weight_dims:
                continue
            lora_matrices = self._decompose_layer(name, param.detach(), base_w)
            if lora_matrices is None:
                continue
            adapter.layers[name] = lora_matrices
            adapter.rank_per_layer[name] = lora_matrices.rank
            total_explained += lora_matrices.explained_variance

        n_layers = len(adapter.layers)
        if n_layers > 0:
            adapter.avg_explained_variance = total_explained / n_layers

        self.adapter_library.append(adapter)

        extraction_time = time.perf_counter() - t0
        self._total_extraction_time += extraction_time

        print(
            f"[ParasiticQLoRA] Extracted adapter '{adapter_id}' | "
            f"{n_layers} layers | avg_rank={sum(adapter.rank_per_layer.values()) / max(1, n_layers):.0f} | "
            f"explained_var={adapter.avg_explained_variance:.3f} | "
            f"params={adapter.total_params():,} | "
            f"trigger={trigger} | time={extraction_time:.3f}s"
        )

        return adapter

    # ------------------------------------------------------------------
    # Bookkeeping
    # ------------------------------------------------------------------

    def get_overhead_percentage(self) -> float:
        """Returns the extraction overhead as % of total training time.

        Returns 0.0 until ``update_training_time`` has supplied a positive
        training time.
        """
        if self._total_training_time <= 0:
            return 0.0
        return (self._total_extraction_time / self._total_training_time) * 100.0

    def update_training_time(self, elapsed: float) -> None:
        """Set (not accumulate) the total training time used for overhead."""
        self._total_training_time = elapsed

    def get_best_adapter(self) -> ExpertAdapter | None:
        """Returns the adapter with the lowest loss, or None if empty."""
        if not self.adapter_library:
            return None
        return min(self.adapter_library, key=lambda a: a.loss)

    def get_library_summary(self) -> dict[str, Any]:
        """Returns a summary of the adapter library."""
        if not self.adapter_library:
            return {"count": 0}

        return {
            "count": len(self.adapter_library),
            "total_extraction_time_s": self._total_extraction_time,
            "overhead_pct": self.get_overhead_percentage(),
            "adapters": [
                {
                    "id": a.adapter_id,
                    "step": a.step,
                    "loss": a.loss,
                    "n_layers": len(a.layers),
                    "avg_rank": sum(a.rank_per_layer.values())
                    / max(1, len(a.rank_per_layer)),
                    "explained_var": a.avg_explained_variance,
                    "params": a.total_params(),
                    "trigger": a.extraction_trigger,
                }
                for a in self.adapter_library
            ],
        }

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save_library(self, path: str) -> None:
        """Saves the entire adapter library to disk via ``torch.save``."""
        state = {
            "config": {
                "min_rank": self.config.min_rank,
                "max_rank": self.config.max_rank,
                "explained_variance_threshold": self.config.explained_variance_threshold,
            },
            "adapters": [a.to_state_dict() for a in self.adapter_library],
            "extraction_stats": {
                "total_extractions": self._extraction_count,
                "total_extraction_time_s": self._total_extraction_time,
                "overhead_pct": self.get_overhead_percentage(),
            },
        }
        torch.save(state, path)
        print(
            f"[ParasiticQLoRA] Library saved: {path} ({len(self.adapter_library)} adapters)"
        )

    @staticmethod
    def load_library(path: str) -> list[ExpertAdapter]:
        """Loads adapters from a file written by ``save_library``.

        Every field stored by ``ExpertAdapter.to_state_dict`` is restored;
        optional keys fall back to the dataclass defaults when missing.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects.
        # Only load library files from trusted sources.
        state = torch.load(path, map_location="cpu", weights_only=False)
        adapters: list[ExpertAdapter] = []
        for adapter_state in state["adapters"]:
            adapter = ExpertAdapter(
                adapter_id=adapter_state["adapter_id"],
                step=adapter_state["step"],
                loss=adapter_state["loss"],
                sft_loss=adapter_state["sft_loss"],
                val_score=adapter_state.get("val_score", 0.0),
                external_val_score=adapter_state.get("external_val_score"),
                optimizer_type=adapter_state.get("optimizer_type", ""),
                spectral_rank=adapter_state.get("spectral_rank", 0.0),
                extraction_trigger=adapter_state.get("extraction_trigger", "unknown"),
                wall_clock_time=adapter_state.get("wall_clock_time", 0.0),
                fces_population_fitness=adapter_state.get("fces_population_fitness"),
                fces_controller_id=adapter_state.get("fces_controller_id"),
                domain_tags=list(adapter_state.get("domain_tags") or []),
                avg_explained_variance=adapter_state.get("avg_explained_variance", 0.0),
            )
            for name, layer_state in adapter_state.get("layers", {}).items():
                adapter.layers[name] = LoRAMatrices(
                    layer_name=name,
                    lora_B=layer_state["lora_B"],
                    lora_A=layer_state["lora_A"],
                    rank=layer_state["rank"],
                    explained_variance=layer_state["explained_variance"],
                    singular_values=layer_state["singular_values"],
                    original_shape=tuple(layer_state["original_shape"]),
                )
                adapter.rank_per_layer[name] = layer_state["rank"]
            adapters.append(adapter)
        return adapters
|
||||||
|
|
||||||
|
|
||||||
|
def apply_adapter(
    model: nn.Module,
    adapter: ExpertAdapter,
    scale: float = 1.0,
) -> None:
    """Applies a LoRA adapter to a model in place (adds the low-rank delta).

    W_new = W_base + scale * B @ A

    Parameters whose name is not present in ``adapter.layers`` are left
    untouched.

    Args:
        model: Model whose parameters are modified in place.
        adapter: Adapter providing ``lora_B`` / ``lora_A`` per parameter name.
        scale: Multiplier applied to the reconstructed delta.
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name not in adapter.layers:
                continue
            lm = adapter.layers[name]
            # Reconstruct ΔW = B @ A in float32 so that half/bfloat16 models
            # do not run the product in low precision; the result is cast
            # to the parameter dtype once. For float32 parameters this is
            # identical to multiplying in the parameter dtype.
            lora_B = lm.lora_B.to(param.device, torch.float32)
            lora_A = lm.lora_A.to(param.device, torch.float32)
            delta = (lora_B @ lora_A) * scale
            # Restore the original (e.g. conv) layout if it was flattened.
            if delta.shape != param.shape:
                delta = delta.reshape(param.shape)
            param.data.add_(delta.to(param.dtype))
|
||||||
|
|
||||||
|
|
||||||
|
def _quantize_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Simple absmax INT8 quantization (dequantized back to float for storage)."""
|
||||||
|
abs_max = tensor.abs().max()
|
||||||
|
if abs_max < 1e-10:
|
||||||
|
return tensor
|
||||||
|
scale = 127.0 / abs_max
|
||||||
|
quantized = (tensor * scale).round().clamp(-127, 127)
|
||||||
|
return quantized / scale
|
||||||
109
tests/test_parasitic_qlora.py
Normal file
109
tests/test_parasitic_qlora.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# Ensure python directory is in path
|
||||||
|
sys.path.append(
|
||||||
|
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "python")
|
||||||
|
)
|
||||||
|
|
||||||
|
from parasitic_qlora import ParasiticQLoRAExtractor, QLoRAConfig
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleModel(nn.Module):  # type: ignore[misc]
    """Minimal two-layer fixture exposing two bias-free weight matrices.

    No ``forward`` is defined; the tests only inspect and mutate
    ``fc1.weight`` (32 × 32) and ``fc2.weight`` (16 × 32).
    """

    def __init__(self) -> None:
        super().__init__()
        # Square layer: the target of the low-rank perturbation test.
        self.fc1 = nn.Linear(in_features=32, out_features=32, bias=False)
        # Rectangular layer: the target of the save/load test.
        self.fc2 = nn.Linear(in_features=32, out_features=16, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
class TestParasiticQLoRA(unittest.TestCase):
    """Unit tests for snapshotting, extraction and library persistence."""

    model: SimpleModel
    config: QLoRAConfig
    extractor: ParasiticQLoRAExtractor

    def setUp(self) -> None:
        # Fixed seed so random weights/deltas are reproducible across runs.
        torch.manual_seed(0)
        self.model = SimpleModel()
        self.config = QLoRAConfig(
            min_rank=2,
            max_rank=8,
            explained_variance_threshold=0.9,
            interesting_point_detection=False,
            extraction_interval=1,
        )
        self.extractor = ParasiticQLoRAExtractor(self.config)

    def test_snapshot_and_extraction(self) -> None:
        """A perturbed layer is extracted with factors of the expected shape."""
        # 1. Take base snapshot
        self.extractor.snapshot_base(self.model)
        self.assertEqual(len(self.extractor.base_weights), 2)

        # Verify a tensor copy of each weight matrix is stored
        for name in ["fc1.weight", "fc2.weight"]:
            self.assertIn(name, self.extractor.base_weights)
            self.assertIsInstance(self.extractor.base_weights[name], torch.Tensor)

        # 2. Simulate weight changes (as in fine-tuning)
        # We add a low-rank delta to fc1.weight (rank 2)
        u = torch.randn(32, 2)
        v = torch.randn(2, 32)
        delta_w1 = torch.matmul(u, v) * 0.1

        with torch.no_grad():
            self.model.fc1.weight.add_(delta_w1)

        # 3. Perform extraction
        step = 1
        metrics = {"loss": 0.5, "step": step}
        self.extractor.extract_adapters(self.model, step, metrics)

        # Verify adapter is stored in library
        self.assertEqual(len(self.extractor.adapter_library), 1)
        adapter = self.extractor.adapter_library[0]
        self.assertTrue(adapter.adapter_id.startswith("unknown_step1_"))

        self.assertIn("fc1.weight", adapter.layers)

        # Verify shapes of A and B
        lora_A = adapter.layers["fc1.weight"].lora_A
        lora_B = adapter.layers["fc1.weight"].lora_B
        rank = adapter.layers["fc1.weight"].rank

        self.assertEqual(lora_A.shape, (rank, 32))
        self.assertEqual(lora_B.shape, (32, rank))
        self.assertGreaterEqual(rank, self.config.min_rank)
        self.assertLessEqual(rank, self.config.max_rank)

    def test_save_and_load(self) -> None:
        """A saved library loads back with the same adapters and tensors."""
        import tempfile

        self.extractor.snapshot_base(self.model)

        # Make a change
        with torch.no_grad():
            self.model.fc2.weight.add_(torch.randn(16, 32) * 0.05)

        self.extractor.extract_adapters(self.model, 1, {"loss": 0.2})

        # Save into a temporary directory so nothing is written to the
        # working directory and cleanup happens even if an assertion fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            test_path = os.path.join(tmp_dir, "test_adapters.pt")
            self.extractor.save_library(test_path)
            self.assertTrue(os.path.exists(test_path))

            # Load in a new extractor (load_library is a staticmethod)
            new_extractor = ParasiticQLoRAExtractor(self.config)
            new_extractor.adapter_library = ParasiticQLoRAExtractor.load_library(
                test_path
            )

        self.assertEqual(len(new_extractor.adapter_library), 1)
        orig_adapter = self.extractor.adapter_library[0]
        loaded_adapter = new_extractor.adapter_library[0]
        self.assertEqual(orig_adapter.adapter_id, loaded_adapter.adapter_id)
        self.assertEqual(orig_adapter.loss, loaded_adapter.loss)

        # The stored factors must survive the round trip unchanged.
        self.assertEqual(set(orig_adapter.layers), set(loaded_adapter.layers))
        for name, orig_layer in orig_adapter.layers.items():
            loaded_layer = loaded_adapter.layers[name]
            self.assertEqual(orig_layer.rank, loaded_layer.rank)
            self.assertTrue(torch.equal(orig_layer.lora_A, loaded_layer.lora_A))
            self.assertTrue(torch.equal(orig_layer.lora_B, loaded_layer.lora_B))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user