# Listing metadata: 493 lines, 18 KiB, Python
"""Parasitic QLoRA Adapter Extraction Engine.
|
||
|
||
Extracts low-rank LoRA adapters from weight deltas during full-parameter
|
||
fine-tuning at near-zero additional compute cost. Piggybacks on the existing
|
||
SVD computation in FCES's SpectralSensor.
|
||
|
||
Architecture:
|
||
During training, W₀ (base weights) is snapshotted at init.
|
||
At extraction points: ΔW = W_t - W₀
|
||
Truncated SVD: ΔW ≈ U_r·Σ_r·V_r^T = B·A (LoRA decomposition)
|
||
|
||
Compute overhead: < 0.5% of training FLOPs when using torch.svd_lowrank.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import time
|
||
from dataclasses import dataclass, field
|
||
from typing import Any
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
|
||
|
||
@dataclass(frozen=True)
class QLoRAConfig:
    """Configuration for parasitic QLoRA extraction.

    The dataclass is frozen, so an instance cannot be mutated after it has
    been constructed. The per-field comments describe how the extractor in
    this module reads each value.
    """

    # Lowest rank considered when searching for the explained-variance
    # threshold during extraction.
    min_rank: int = 8
    # Upper bound on the adapter rank; it is further clamped to the layer's
    # matrix dimensions and used as the truncated-SVD width.
    max_rank: int = 64
    # Cumulative explained-variance ratio at which the rank search stops.
    explained_variance_threshold: float = 0.95
    # Periodic trigger: extract whenever ``step % extraction_interval == 0``.
    extraction_interval: int = 50
    # When False, only the step-0 and periodic triggers are evaluated; the
    # loss-plateau and validation-improvement triggers are skipped.
    interesting_point_detection: bool = True
    # Number of most recent loss values inspected for plateau detection.
    loss_plateau_window: int = 10
    # A plateau is declared when (max - min) / mean of the loss window is
    # below this value.
    loss_plateau_threshold: float = 0.005
    # Minimum increase between the two most recent validation scores that
    # triggers an extraction.
    val_improvement_threshold: float = 0.01
    # Parameters with fewer dimensions than this are neither snapshotted nor
    # decomposed.
    min_weight_dims: int = 2
    # When True, the LoRA factors are passed through ``_quantize_int8``
    # before being stored.
    quantize_to_int8: bool = False
|
||
|
||
|
||
@dataclass
class LoRAMatrices:
    """Low-rank decomposition of a single layer's weight delta.

    The delta is approximated by the product ``lora_B @ lora_A``; when the
    original parameter had more than two dimensions the product has to be
    reshaped back to ``original_shape``.
    """

    # Parameter name the factors belong to.
    layer_name: str
    lora_B: torch.Tensor  # d × r
    lora_A: torch.Tensor  # r × k
    # Number of retained singular components (r).
    rank: int
    # Explained-variance ratio reported for the chosen rank.
    explained_variance: float
    # The retained singular values.
    singular_values: torch.Tensor
    # Shape of the weight delta before it was flattened to 2-D for the SVD.
    original_shape: tuple[int, ...]
|
||
|
||
|
||
@dataclass
class ExpertAdapter:
    """A complete LoRA adapter extracted from a training checkpoint.

    Bundles the per-layer low-rank factors with the training metrics that
    were current when the adapter was extracted.
    """

    adapter_id: str
    step: int
    rank_per_layer: dict[str, int] = field(default_factory=dict)
    avg_explained_variance: float = 0.0

    # Per-layer LoRA factors, keyed by parameter name.
    layers: dict[str, LoRAMatrices] = field(default_factory=dict)

    # Metrics recorded at the extraction point.
    loss: float = 0.0
    sft_loss: float = 0.0
    val_score: float = 0.0
    external_val_score: float | None = None
    optimizer_type: str = ""
    spectral_rank: float = 0.0
    extraction_trigger: str = "periodic"
    wall_clock_time: float = 0.0

    # Metadata specific to FCES.
    fces_population_fitness: float | None = None
    fces_controller_id: int | None = None

    # Free-form domain labels.
    domain_tags: list[str] = field(default_factory=list)

    def total_params(self) -> int:
        """Return the number of elements stored in all B and A factors."""
        return sum(
            factors.lora_B.numel() + factors.lora_A.numel()
            for factors in self.layers.values()
        )

    def compression_ratio(self, original_params: int) -> float:
        """Return ``original_params`` divided by the adapter's parameter count.

        An adapter without any parameters yields ``float("inf")``.
        """
        own_params = self.total_params()
        return original_params / own_params if own_params != 0 else float("inf")

    def to_state_dict(self) -> dict[str, Any]:
        """Serialize the adapter into a plain dict suitable for saving.

        All tensors are moved to the CPU. The per-layer entries are stored
        under the ``"layers"`` key, which is the last key of the result.
        """
        serialized_layers: dict[str, dict[str, Any]] = {
            layer_name: {
                "lora_B": factors.lora_B.cpu(),
                "lora_A": factors.lora_A.cpu(),
                "rank": factors.rank,
                "explained_variance": factors.explained_variance,
                "singular_values": factors.singular_values.cpu(),
                "original_shape": factors.original_shape,
            }
            for layer_name, factors in self.layers.items()
        }
        return {
            "adapter_id": self.adapter_id,
            "step": self.step,
            "loss": self.loss,
            "sft_loss": self.sft_loss,
            "val_score": self.val_score,
            "external_val_score": self.external_val_score,
            "optimizer_type": self.optimizer_type,
            "spectral_rank": self.spectral_rank,
            "extraction_trigger": self.extraction_trigger,
            "wall_clock_time": self.wall_clock_time,
            "fces_population_fitness": self.fces_population_fitness,
            "fces_controller_id": self.fces_controller_id,
            "domain_tags": self.domain_tags,
            "avg_explained_variance": self.avg_explained_variance,
            "layers": serialized_layers,
        }
|
||
|
||
|
||
class ParasiticQLoRAExtractor:
    """Extracts LoRA adapters parasitically during full-parameter fine-tuning.

    The key insight: during training we already compute forward/backward passes
    and (in FCES) SVD for spectral sensing. The weight delta ΔW = W_t - W₀ can
    be decomposed into a low-rank LoRA adapter at near-zero marginal cost using
    truncated SVD (torch.svd_lowrank).

    Typical use: call :meth:`snapshot_base` once, then at every step ask
    :meth:`should_extract` and, when it returns True, call
    :meth:`extract_adapters`.
    """

    def __init__(self, config: QLoRAConfig | None = None) -> None:
        """Create an extractor with empty history and an empty adapter library.

        Args:
            config: Extraction settings; ``QLoRAConfig()`` is used when omitted.
        """
        self.config = config if config is not None else QLoRAConfig()
        self.base_weights: dict[str, torch.Tensor] = {}
        self.adapter_library: list[ExpertAdapter] = []
        self.loss_history: list[float] = []
        self.val_history: list[float] = []
        self._extraction_count: int = 0
        self._total_extraction_time: float = 0.0
        self._total_training_time: float = 0.0

    def snapshot_base(self, model: nn.Module) -> None:
        """Snapshot base weights W₀ at training start. Called once.

        Any previous snapshot is discarded. Only parameters with at least
        ``config.min_weight_dims`` dimensions are tracked; copies are kept on
        the CPU.
        """
        self.base_weights.clear()
        for name, param in model.named_parameters():
            if param.dim() >= self.config.min_weight_dims:
                # copy=True guarantees an independent tensor even when the
                # parameter already lives on the CPU, and avoids the extra
                # on-device clone that clone().cpu() would make.
                self.base_weights[name] = param.detach().to("cpu", copy=True)
        print(
            f"[ParasiticQLoRA] Base snapshot: {len(self.base_weights)} "
            f"weight matrices tracked"
        )

    # ------------------------------------------------------------------
    # Trigger logic
    # ------------------------------------------------------------------

    def _loss_plateaued(self) -> bool:
        """Return True when the recent loss window is relatively flat."""
        window = self.config.loss_plateau_window
        # A non-positive window cannot describe "the last N losses".
        if window <= 0 or len(self.loss_history) < window:
            return False
        recent = self.loss_history[-window:]
        mean_loss = sum(recent) / len(recent)
        spread = max(recent) - min(recent)
        return (
            mean_loss > 0
            and spread / mean_loss < self.config.loss_plateau_threshold
        )

    def _val_improved(self) -> bool:
        """Return True when the latest validation score beat the previous one."""
        if len(self.val_history) < 2:
            return False
        gain = self.val_history[-1] - self.val_history[-2]
        return gain > self.config.val_improvement_threshold

    def _classify_trigger(self, step: int) -> str:
        """Label the extraction happening at ``step`` from the recorded history."""
        label = "periodic"
        if step == 0:
            label = "initial"
        elif self._loss_plateaued():
            label = "plateau"
        # A validation improvement takes precedence over the other labels.
        if self._val_improved():
            label = "val_improvement"
        return label

    def should_extract(
        self,
        step: int,
        loss: float,
        val_score: float = 0.0,
    ) -> bool:
        """Determines whether to extract an adapter at this step.

        The loss is always appended to the loss history; the validation score
        is appended only when it is positive.

        Triggers on:
        1. Step 0 (pre-training baseline)
        2. Periodic interval (every N steps; disabled when N <= 0)
        3. Loss plateau detection
        4. Validation score improvement

        Triggers 3 and 4 are evaluated only when
        ``config.interesting_point_detection`` is enabled.
        """
        self.loss_history.append(loss)
        if val_score > 0:
            self.val_history.append(val_score)

        if step == 0:
            return True

        interval = self.config.extraction_interval
        # Guard against a zero interval, which would raise ZeroDivisionError.
        if interval > 0 and step % interval == 0:
            return True

        if not self.config.interesting_point_detection:
            return False

        return self._loss_plateaued() or self._val_improved()

    # ------------------------------------------------------------------
    # Extraction
    # ------------------------------------------------------------------

    def _factorize_delta(
        self,
        delta_w: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float] | None:
        """Decompose one weight delta into LoRA factors.

        Returns:
            ``(lora_B, lora_A, singular_values, explained_variance)`` where
            ``lora_B`` is d × r and ``lora_A`` is r × k, or ``None`` when the
            delta is empty or numerically zero.
        """
        cfg = self.config
        if delta_w.numel() == 0:
            return None

        # SVD needs a matrix: keep the leading dim and flatten the rest
        # (e.g. conv kernels); vectors and scalars become a single column.
        if delta_w.dim() < 2:
            matrix = delta_w.reshape(-1, 1)
        else:
            matrix = delta_w.reshape(delta_w.shape[0], -1)
        matrix = matrix.float()

        # Skip near-zero deltas (layer hasn't changed).
        frob_norm = matrix.norm().item()
        if frob_norm < 1e-7:
            return None

        d, k = matrix.shape
        max_rank = min(cfg.max_rank, d, k)
        if max_rank < 1:
            return None

        # Truncated SVD — O(d·k·r) instead of O(d·k·min(d,k)).
        try:
            U, S, V = torch.svd_lowrank(matrix, q=max_rank)
        except Exception:
            # Deliberate best-effort: fall back to the exact SVD whenever
            # the randomized routine fails.
            U_full, S_full, Vh_full = torch.linalg.svd(matrix, full_matrices=False)
            U = U_full[:, :max_rank]
            S = S_full[:max_rank]
            V = Vh_full[:max_rank, :].T
        S = S[:max_rank]

        captured = S.pow(2)
        if captured.sum().item() < 1e-12:
            return None

        # Explained variance is measured against the energy of the FULL
        # delta, ‖ΔW‖_F² (the sum of all squared singular values). Dividing
        # by the truncated spectrum's own energy would always report 1.0 at
        # max_rank and overstate how well the adapter fits.
        total_energy = frob_norm * frob_norm
        ratios = (
            (torch.cumsum(captured, dim=0) / total_energy).clamp(max=1.0).tolist()
        )

        available = len(ratios)
        lowest = max(1, min(cfg.min_rank, available))
        threshold = cfg.explained_variance_threshold
        # Smallest admissible rank reaching the threshold, else everything
        # the truncated SVD produced.
        rank = next(
            (r for r in range(lowest, available + 1) if ratios[r - 1] >= threshold),
            available,
        )
        explained = ratios[rank - 1]

        # LoRA decomposition: B = U_r · √Σ_r, A = √Σ_r · V_r^T
        sqrt_s = torch.sqrt(S[:rank])
        lora_B = U[:, :rank] * sqrt_s.unsqueeze(0)  # d × r
        lora_A = V[:, :rank].T * sqrt_s.unsqueeze(1)  # r × k

        # Optional INT8 quantization
        if cfg.quantize_to_int8:
            lora_B = _quantize_int8(lora_B)
            lora_A = _quantize_int8(lora_A)

        return lora_B, lora_A, S[:rank], explained

    def extract_adapters(
        self,
        model: nn.Module,
        step: int,
        metrics: dict[str, Any],
    ) -> ExpertAdapter:
        """Core parasitic extraction. Computes ΔW and decomposes via truncated SVD.

        Algorithm per layer:
        1. ΔW = W_t - W₀
        2. U, Σ, V = svd_lowrank(ΔW, q=max_rank)
        3. Select rank r via explained variance threshold (relative to ‖ΔW‖_F²)
        4. B = U_r · diag(√Σ_r), A = diag(√Σ_r) · V_r^T

        Args:
            model: Model whose current parameters are compared to the snapshot.
            step: Training step the adapter is attributed to.
            metrics: Optional values read with ``.get``: ``optimizer``,
                ``loss``, ``sft_loss``, ``val_score``, ``external_val_score``,
                ``spectral_rank``, ``fces_fitness``, ``fces_controller_id``,
                ``wall_clock_time``.

        Returns:
            The new adapter, which is also appended to ``adapter_library``.
            Layers that were not snapshotted or did not change are omitted.
        """
        t0 = time.perf_counter()
        self._extraction_count += 1

        adapter_id = (
            f"{metrics.get('optimizer', 'unknown')}_step{step}_{self._extraction_count}"
        )
        trigger = self._classify_trigger(step)
        adapter = ExpertAdapter(
            adapter_id=adapter_id,
            step=step,
            loss=metrics.get("loss", 0.0),
            sft_loss=metrics.get("sft_loss", 0.0),
            val_score=metrics.get("val_score", 0.0),
            external_val_score=metrics.get("external_val_score"),
            optimizer_type=metrics.get("optimizer", ""),
            spectral_rank=metrics.get("spectral_rank", 0.0),
            fces_population_fitness=metrics.get("fces_fitness"),
            fces_controller_id=metrics.get("fces_controller_id"),
            wall_clock_time=metrics.get("wall_clock_time", 0.0),
            extraction_trigger=trigger,
        )

        total_explained = 0.0
        for name, param in model.named_parameters():
            base_w = self.base_weights.get(name)
            if base_w is None or param.dim() < self.config.min_weight_dims:
                continue

            delta_w = param.detach() - base_w.to(param.device)
            factors = self._factorize_delta(delta_w)
            if factors is None:
                continue

            lora_B, lora_A, sigma, explained_var = factors
            rank = int(sigma.numel())
            adapter.layers[name] = LoRAMatrices(
                layer_name=name,
                lora_B=lora_B.cpu().half(),
                lora_A=lora_A.cpu().half(),
                rank=rank,
                explained_variance=explained_var,
                singular_values=sigma.cpu(),
                # Plain tuple keeps the saved state free of torch.Size objects.
                original_shape=tuple(delta_w.shape),
            )
            adapter.rank_per_layer[name] = rank
            total_explained += explained_var

        n_layers = len(adapter.layers)
        if n_layers > 0:
            adapter.avg_explained_variance = total_explained / n_layers

        self.adapter_library.append(adapter)

        extraction_time = time.perf_counter() - t0
        self._total_extraction_time += extraction_time

        print(
            f"[ParasiticQLoRA] Extracted adapter '{adapter_id}' | "
            f"{n_layers} layers | avg_rank={sum(adapter.rank_per_layer.values()) / max(1, n_layers):.0f} | "
            f"explained_var={adapter.avg_explained_variance:.3f} | "
            f"params={adapter.total_params():,} | "
            f"trigger={trigger} | time={extraction_time:.3f}s"
        )

        return adapter

    # ------------------------------------------------------------------
    # Bookkeeping
    # ------------------------------------------------------------------

    def get_overhead_percentage(self) -> float:
        """Returns the extraction overhead as % of total training time.

        Returns 0.0 while no positive training time has been recorded.
        """
        if self._total_training_time <= 0:
            return 0.0
        return (self._total_extraction_time / self._total_training_time) * 100.0

    def update_training_time(self, elapsed: float) -> None:
        """Record the total training time used for the overhead calculation.

        The stored value is replaced, not accumulated.
        """
        self._total_training_time = elapsed

    def get_best_adapter(self) -> ExpertAdapter | None:
        """Returns the adapter with the lowest loss, or None if none exist."""
        if not self.adapter_library:
            return None
        return min(self.adapter_library, key=lambda a: a.loss)

    def get_library_summary(self) -> dict[str, Any]:
        """Returns a summary of the adapter library.

        An empty library yields ``{"count": 0}``.
        """
        if not self.adapter_library:
            return {"count": 0}

        return {
            "count": len(self.adapter_library),
            "total_extraction_time_s": self._total_extraction_time,
            "overhead_pct": self.get_overhead_percentage(),
            "adapters": [
                {
                    "id": a.adapter_id,
                    "step": a.step,
                    "loss": a.loss,
                    "n_layers": len(a.layers),
                    "avg_rank": sum(a.rank_per_layer.values())
                    / max(1, len(a.rank_per_layer)),
                    "explained_var": a.avg_explained_variance,
                    "params": a.total_params(),
                    "trigger": a.extraction_trigger,
                }
                for a in self.adapter_library
            ],
        }

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save_library(self, path: str) -> None:
        """Saves the entire adapter library to disk with ``torch.save``."""
        state = {
            "config": {
                "min_rank": self.config.min_rank,
                "max_rank": self.config.max_rank,
                "explained_variance_threshold": self.config.explained_variance_threshold,
            },
            "adapters": [a.to_state_dict() for a in self.adapter_library],
            "extraction_stats": {
                "total_extractions": self._extraction_count,
                "total_extraction_time_s": self._total_extraction_time,
                "overhead_pct": self.get_overhead_percentage(),
            },
        }
        torch.save(state, path)
        print(
            f"[ParasiticQLoRA] Library saved: {path} ({len(self.adapter_library)} adapters)"
        )

    @staticmethod
    def load_library(path: str) -> list[ExpertAdapter]:
        """Loads adapters from a saved library file.

        Every field written by ``ExpertAdapter.to_state_dict`` is restored;
        fields missing from older files fall back to the dataclass defaults.
        """
        # SECURITY: weights_only=False unpickles arbitrary objects. Only load
        # library files that come from a trusted source.
        state = torch.load(path, map_location="cpu", weights_only=False)
        adapters: list[ExpertAdapter] = []
        for adapter_state in state["adapters"]:
            adapter = ExpertAdapter(
                adapter_id=adapter_state["adapter_id"],
                step=adapter_state["step"],
                loss=adapter_state.get("loss", 0.0),
                sft_loss=adapter_state.get("sft_loss", 0.0),
                val_score=adapter_state.get("val_score", 0.0),
                external_val_score=adapter_state.get("external_val_score"),
                optimizer_type=adapter_state.get("optimizer_type", ""),
                spectral_rank=adapter_state.get("spectral_rank", 0.0),
                extraction_trigger=adapter_state.get("extraction_trigger", "unknown"),
                wall_clock_time=adapter_state.get("wall_clock_time", 0.0),
                fces_population_fitness=adapter_state.get("fces_population_fitness"),
                fces_controller_id=adapter_state.get("fces_controller_id"),
                domain_tags=list(adapter_state.get("domain_tags") or []),
                avg_explained_variance=adapter_state.get("avg_explained_variance", 0.0),
            )
            for name, layer_state in adapter_state.get("layers", {}).items():
                adapter.layers[name] = LoRAMatrices(
                    layer_name=name,
                    lora_B=layer_state["lora_B"],
                    lora_A=layer_state["lora_A"],
                    rank=layer_state["rank"],
                    explained_variance=layer_state["explained_variance"],
                    singular_values=layer_state["singular_values"],
                    original_shape=tuple(layer_state["original_shape"]),
                )
                adapter.rank_per_layer[name] = layer_state["rank"]
            adapters.append(adapter)
        return adapters
|
||
|
||
|
||
def apply_adapter(
    model: nn.Module,
    adapter: ExpertAdapter,
    scale: float = 1.0,
) -> None:
    """Add an adapter's low-rank delta to the matching model parameters, in place.

    For every parameter whose name appears in ``adapter.layers``:

        W_new = W_base + scale * B @ A

    Parameters without a matching adapter layer are left untouched.
    """
    adapter_layers = adapter.layers
    with torch.no_grad():
        for param_name, weight in model.named_parameters():
            if param_name in adapter_layers:
                factors = adapter_layers[param_name]
                # Bring both factors to the parameter's device and dtype.
                factor_b = factors.lora_B.to(weight.device, weight.dtype)
                factor_a = factors.lora_A.to(weight.device, weight.dtype)
                update = factor_b @ factor_a
                # The product is 2-D; restore the parameter's own shape
                # when it differs (e.g. convolution kernels).
                if update.shape != weight.shape:
                    update = update.reshape(weight.shape)
                weight.data.add_(update * scale)
|
||
|
||
|
||
def _quantize_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||
"""Simple absmax INT8 quantization (dequantized back to float for storage)."""
|
||
abs_max = tensor.abs().max()
|
||
if abs_max < 1e-10:
|
||
return tensor
|
||
scale = 127.0 / abs_max
|
||
quantized = (tensor * scale).round().clamp(-127, 127)
|
||
return quantized / scale
|