Files
FCES-native/python/parasitic_qlora.py

493 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Parasitic QLoRA Adapter Extraction Engine.
Extracts low-rank LoRA adapters from weight deltas during full-parameter
fine-tuning at near-zero additional compute cost. Piggybacks on the existing
SVD computation in FCES's SpectralSensor.
Architecture:
During training, W₀ (base weights) is snapshotted at init.
At extraction points: ΔW = W_t - W₀
Truncated SVD: ΔW ≈ U_r·Σ_r·V_r^T = B·A (LoRA decomposition)
Compute overhead: < 0.5% of training FLOPs when using torch.svd_lowrank.
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import Any
import torch
import torch.nn as nn
@dataclass(frozen=True)
class QLoRAConfig:
    """Configuration for parasitic QLoRA extraction.

    Raises:
        ValueError: if ``extraction_interval``, ``min_rank`` or ``max_rank``
            is smaller than 1. ``extraction_interval`` is used as a modulus,
            so 0 would crash on the first non-zero step. A rank below 1 would
            produce empty LoRA factors.
    """

    # Lower bound on the rank chosen per layer. It is clamped to the layer's
    # own max rank, so min_rank > max_rank simply yields max_rank.
    min_rank: int = 8
    # Upper bound on the rank per layer; also caps the ``q`` given to the SVD.
    max_rank: int = 64
    # Cumulative energy fraction at which rank selection stops.
    explained_variance_threshold: float = 0.95
    # Extract unconditionally every N steps.
    extraction_interval: int = 50
    # Enable the loss-plateau / validation-improvement triggers.
    interesting_point_detection: bool = True
    # Number of most recent losses inspected for plateau detection.
    loss_plateau_window: int = 10
    # Plateau when (max - min) / mean over the window is below this.
    loss_plateau_threshold: float = 0.005
    # Trigger when the latest val score beats the previous one by more than this.
    val_improvement_threshold: float = 0.01
    # Only parameters with at least this many dimensions are tracked.
    min_weight_dims: int = 2
    # Apply absmax INT8 fake-quantization to the LoRA factors before storage.
    quantize_to_int8: bool = False

    def __post_init__(self) -> None:
        # Fail fast at construction instead of deep inside the training loop.
        if self.extraction_interval < 1:
            raise ValueError(
                f"extraction_interval must be >= 1, got {self.extraction_interval}"
            )
        if self.min_rank < 1:
            raise ValueError(f"min_rank must be >= 1, got {self.min_rank}")
        if self.max_rank < 1:
            raise ValueError(f"max_rank must be >= 1, got {self.max_rank}")
@dataclass
class LoRAMatrices:
    """Rank-``r`` factorization ``ΔW ≈ lora_B @ lora_A`` of one parameter's delta.

    Fields, in constructor order:
        layer_name: Parameter name as reported by ``model.named_parameters()``.
        lora_B: Left factor, shape ``d × r``.
        lora_A: Right factor, shape ``r × k``.
        rank: The chosen ``r``.
        explained_variance: Cumulative explained-energy ratio at ``r``, as
            computed by the extractor.
        singular_values: The leading singular values that were kept.
        original_shape: Parameter shape before it was flattened to 2-D.
    """

    layer_name: str
    lora_B: torch.Tensor
    lora_A: torch.Tensor
    rank: int
    explained_variance: float
    singular_values: torch.Tensor
    original_shape: tuple[int, ...]
@dataclass
class ExpertAdapter:
    """A complete LoRA adapter extracted from a training checkpoint.

    Holds one ``LoRAMatrices`` per tracked layer, plus the training metrics
    that were current when the adapter was extracted.
    """

    adapter_id: str
    step: int
    # Selected rank for each layer name present in ``layers``.
    rank_per_layer: dict[str, int] = field(default_factory=dict)
    # Mean of the per-layer explained-variance ratios (set by the extractor).
    avg_explained_variance: float = 0.0
    # LoRA matrices per layer
    layers: dict[str, LoRAMatrices] = field(default_factory=dict)
    # Training metrics at extraction point
    loss: float = 0.0
    sft_loss: float = 0.0
    val_score: float = 0.0
    external_val_score: float | None = None
    optimizer_type: str = ""
    spectral_rank: float = 0.0
    extraction_trigger: str = "periodic"
    wall_clock_time: float = 0.0
    # FCES-specific metadata
    fces_population_fitness: float | None = None
    fces_controller_id: int | None = None
    # Domain tags
    domain_tags: list[str] = field(default_factory=list)

    def total_params(self) -> int:
        """Return the number of elements across all LoRA ``B`` and ``A`` matrices."""
        return sum(
            lm.lora_B.numel() + lm.lora_A.numel() for lm in self.layers.values()
        )

    def compression_ratio(self, original_params: int) -> float:
        """Return ``original_params / total_params()``.

        Returns ``inf`` for an adapter with no layers.
        """
        adapter_params = self.total_params()
        if adapter_params == 0:
            return float("inf")
        return original_params / adapter_params

    def to_state_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict suitable for ``torch.save``.

        Tensors are moved to CPU. ``original_shape`` is stored as a plain
        tuple, and ``domain_tags`` is copied, so mutating the returned dict
        never mutates this adapter.
        """
        state: dict[str, Any] = {
            "adapter_id": self.adapter_id,
            "step": self.step,
            "loss": self.loss,
            "sft_loss": self.sft_loss,
            "val_score": self.val_score,
            "external_val_score": self.external_val_score,
            "optimizer_type": self.optimizer_type,
            "spectral_rank": self.spectral_rank,
            "extraction_trigger": self.extraction_trigger,
            "wall_clock_time": self.wall_clock_time,
            "fces_population_fitness": self.fces_population_fitness,
            "fces_controller_id": self.fces_controller_id,
            "domain_tags": list(self.domain_tags),
            "avg_explained_variance": self.avg_explained_variance,
        }
        state["layers"] = {
            name: {
                "lora_B": lm.lora_B.cpu(),
                "lora_A": lm.lora_A.cpu(),
                "rank": lm.rank,
                "explained_variance": lm.explained_variance,
                "singular_values": lm.singular_values.cpu(),
                # Plain tuple rather than torch.Size, for portable unpickling.
                "original_shape": tuple(lm.original_shape),
            }
            for name, lm in self.layers.items()
        }
        return state
class ParasiticQLoRAExtractor:
    """Extracts LoRA adapters parasitically during full-parameter fine-tuning.

    For every tracked parameter, the weight delta ΔW = W_t - W₀ is factored
    into a low-rank LoRA pair ``B @ A`` with a truncated SVD
    (``torch.svd_lowrank``). The SVD is computed by this class on each
    extraction; it does not reuse factors from any other component.

    Workflow:
        1. Call ``snapshot_base`` once before training.
        2. Call ``should_extract`` once per step; it records loss/val history.
        3. Call ``extract_adapters`` whenever ``should_extract`` returns True.

    Extracted adapters accumulate in ``adapter_library`` and can be persisted
    with ``save_library``.
    """

    def __init__(self, config: QLoRAConfig | None = None) -> None:
        self.config = config if config is not None else QLoRAConfig()
        # W₀ per tracked parameter name, kept on CPU.
        self.base_weights: dict[str, torch.Tensor] = {}
        self.adapter_library: list[ExpertAdapter] = []
        # Every loss passed to should_extract, in call order.
        self.loss_history: list[float] = []
        # Only val scores > 0 are recorded (0.0 means "not provided").
        self.val_history: list[float] = []
        self._extraction_count: int = 0
        self._total_extraction_time: float = 0.0
        self._total_training_time: float = 0.0

    def snapshot_base(self, model: nn.Module) -> None:
        """Snapshot base weights W₀ at training start. Called once.

        Stores a detached CPU copy of every parameter with at least
        ``config.min_weight_dims`` dimensions. Any previous snapshot is
        discarded.
        """
        self.base_weights.clear()
        for name, param in model.named_parameters():
            if param.dim() >= self.config.min_weight_dims:
                # One direct copy to CPU; avoids a transient clone on the device.
                self.base_weights[name] = param.detach().to("cpu", copy=True)
        print(
            f"[ParasiticQLoRA] Base snapshot: {len(self.base_weights)} "
            f"weight matrices tracked"
        )

    def _is_loss_plateau(self) -> bool:
        """Return True when the recent losses vary by less than the plateau threshold.

        The window is the last ``loss_plateau_window`` recorded losses. The
        test is ``(max - min) / mean < loss_plateau_threshold``. A non-positive
        mean or an empty window never counts as a plateau.
        """
        window = self.config.loss_plateau_window
        if len(self.loss_history) < window:
            return False
        recent = self.loss_history[-window:]
        if not recent:
            return False
        avg_loss = sum(recent) / len(recent)
        if avg_loss <= 0:
            return False
        loss_range = max(recent) - min(recent)
        return loss_range / avg_loss < self.config.loss_plateau_threshold

    def _has_val_improvement(self) -> bool:
        """Return True if the latest val score beats the previous one by the threshold."""
        if len(self.val_history) < 2:
            return False
        improvement = self.val_history[-1] - self.val_history[-2]
        return improvement > self.config.val_improvement_threshold

    def _classify_trigger(self, step: int) -> str:
        """Label why an extraction at ``step`` happened.

        Priority:
            1. "initial" at step 0.
            2. "val_improvement".
            3. "plateau".
            4. "periodic".

        The interesting-point labels are only used when
        ``config.interesting_point_detection`` is on, matching
        ``should_extract``.
        """
        if step == 0:
            return "initial"
        if self.config.interesting_point_detection:
            if self._has_val_improvement():
                return "val_improvement"
            if self._is_loss_plateau():
                return "plateau"
        return "periodic"

    def should_extract(
        self,
        step: int,
        loss: float,
        val_score: float = 0.0,
    ) -> bool:
        """Determines whether to extract an adapter at this step.

        Triggers on:
            1. Step 0 (pre-training baseline).
            2. Periodic interval (every ``extraction_interval`` steps).
            3. Loss plateau detection.
            4. Validation score improvement.

        Triggers 3 and 4 apply only when ``interesting_point_detection`` is on.

        Side effects: appends ``loss`` to ``loss_history`` and, when
        ``val_score > 0``, appends ``val_score`` to ``val_history``. Call it
        exactly once per step.
        """
        self.loss_history.append(loss)
        if val_score > 0:
            self.val_history.append(val_score)
        if step == 0:
            return True
        if step % self.config.extraction_interval == 0:
            return True
        if not self.config.interesting_point_detection:
            return False
        # NOTE(review): while the loss stays flat (or val keeps improving) this
        # returns True on every step, so a long plateau extracts every step.
        # Consider a cooldown if that overhead matters.
        return self._is_loss_plateau() or self._has_val_improvement()

    @staticmethod
    def _truncated_svd(
        matrix: torch.Tensor, q: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(U, S, V)`` with ``q`` components, where ``matrix ≈ U diag(S) Vᵀ``.

        Tries the randomized ``torch.svd_lowrank`` first. If it raises (torch
        linear-algebra errors subclass RuntimeError), falls back to an exact
        ``torch.linalg.svd`` truncated to ``q``.
        """
        try:
            return torch.svd_lowrank(matrix, q=q)
        except RuntimeError:
            u_full, s_full, vh_full = torch.linalg.svd(matrix, full_matrices=False)
            return u_full[:, :q], s_full[:q], vh_full[:q, :].T

    def _decompose_delta(
        self,
        name: str,
        weight: torch.Tensor,
        base_w: torch.Tensor,
    ) -> LoRAMatrices | None:
        """Factor ``weight - base_w`` into LoRA matrices.

        Returns None if the delta is numerically zero.
        """
        # Subtract in at least float32. Subtracting in bf16/fp16 would round
        # away small updates before the SVD ever sees them.
        work_dtype = torch.promote_types(weight.dtype, torch.float32)
        delta_w = weight.to(work_dtype) - base_w.to(
            device=weight.device, dtype=work_dtype
        )

        # Skip layers that have not (measurably) moved since the snapshot.
        delta_norm = delta_w.norm().item()
        if delta_norm < 1e-7:
            return None

        original_shape = delta_w.shape
        # SVD needs a matrix. Flatten trailing dims (e.g. conv kernels); treat
        # 0-D/1-D parameters (possible with min_weight_dims < 2) as a column.
        if delta_w.dim() >= 2:
            delta_2d = delta_w.reshape(delta_w.shape[0], -1)
        else:
            delta_2d = delta_w.reshape(-1, 1)
        delta_2d = delta_2d.float()

        d, k = delta_2d.shape
        max_rank = min(self.config.max_rank, d, k)
        if max_rank < 1:
            return None
        U, S, V = self._truncated_svd(delta_2d, max_rank)

        captured = torch.cumsum(S.pow(2), dim=0)
        if captured.numel() == 0 or captured[-1].item() < 1e-12:
            return None
        # Normalize by the full energy ||ΔW||_F² rather than by the truncated
        # spectrum's energy. The latter always reaches 1.0 at q and would
        # overstate how much of ΔW the adapter reproduces.
        total_energy = delta_norm**2
        explained_ratios = captured / total_energy

        # Smallest r in [lower, max_rank] reaching the threshold, else max_rank.
        lower = max(1, min(self.config.min_rank, max_rank))
        hits = torch.nonzero(
            explained_ratios[lower - 1 : max_rank]
            >= self.config.explained_variance_threshold
        )
        rank = lower + int(hits[0, 0].item()) if hits.numel() > 0 else max_rank
        rank = min(rank, S.numel())
        # Clamp guards against float rounding pushing the ratio just above 1.
        explained_var = min(explained_ratios[rank - 1].item(), 1.0)

        # LoRA decomposition: B = U_r · √Σ_r, A = √Σ_r · V_r^T
        sqrt_s = S[:rank].sqrt()
        lora_B = U[:, :rank] * sqrt_s.unsqueeze(0)  # d × r
        lora_A = V[:, :rank].T * sqrt_s.unsqueeze(1)  # r × k

        if self.config.quantize_to_int8:
            lora_B = _quantize_int8(lora_B)
            lora_A = _quantize_int8(lora_A)

        return LoRAMatrices(
            layer_name=name,
            lora_B=lora_B.cpu().half(),
            lora_A=lora_A.cpu().half(),
            rank=rank,
            explained_variance=explained_var,
            singular_values=S[:rank].cpu(),
            original_shape=original_shape,
        )

    def extract_adapters(
        self,
        model: nn.Module,
        step: int,
        metrics: dict[str, Any],
    ) -> ExpertAdapter:
        """Core parasitic extraction: compute ΔW and decompose it via truncated SVD.

        Algorithm per tracked layer:
            1. ΔW = W_t - W₀ (subtracted in at least float32), flattened to 2-D.
            2. U, Σ, V = svd_lowrank(ΔW, q = min(max_rank, d, k)).
            3. Pick the smallest r ≥ max(1, min(min_rank, q)) whose cumulative
               energy Σ_{i≤r} σ_i² / ‖ΔW‖_F² reaches
               ``explained_variance_threshold``; fall back to q.
            4. B = U_r · diag(√Σ_r), A = diag(√Σ_r) · V_r^T.

        Args:
            model: Module whose parameters are compared against the snapshot.
            step: Current training step.
            metrics: Optional keys read: "optimizer", "loss", "sft_loss",
                "val_score", "external_val_score", "spectral_rank",
                "fces_fitness", "fces_controller_id", "wall_clock_time".

        Returns:
            The new adapter. It is also appended to ``adapter_library``.
        """
        t0 = time.perf_counter()
        self._extraction_count += 1
        adapter_id = (
            f"{metrics.get('optimizer', 'unknown')}_step{step}_{self._extraction_count}"
        )
        trigger = self._classify_trigger(step)
        adapter = ExpertAdapter(
            adapter_id=adapter_id,
            step=step,
            loss=metrics.get("loss", 0.0),
            sft_loss=metrics.get("sft_loss", 0.0),
            val_score=metrics.get("val_score", 0.0),
            external_val_score=metrics.get("external_val_score"),
            optimizer_type=metrics.get("optimizer", ""),
            spectral_rank=metrics.get("spectral_rank", 0.0),
            extraction_trigger=trigger,
            fces_population_fitness=metrics.get("fces_fitness"),
            fces_controller_id=metrics.get("fces_controller_id"),
            wall_clock_time=metrics.get("wall_clock_time", 0.0),
        )

        total_explained = 0.0
        for name, param in model.named_parameters():
            base_w = self.base_weights.get(name)
            if base_w is None or param.dim() < self.config.min_weight_dims:
                continue
            lora = self._decompose_delta(name, param.detach(), base_w)
            if lora is None:
                continue
            adapter.layers[name] = lora
            adapter.rank_per_layer[name] = lora.rank
            total_explained += lora.explained_variance

        n_layers = len(adapter.layers)
        if n_layers > 0:
            adapter.avg_explained_variance = total_explained / n_layers
        self.adapter_library.append(adapter)

        extraction_time = time.perf_counter() - t0
        self._total_extraction_time += extraction_time
        print(
            f"[ParasiticQLoRA] Extracted adapter '{adapter_id}' | "
            f"{n_layers} layers | avg_rank={sum(adapter.rank_per_layer.values()) / max(1, n_layers):.0f} | "
            f"explained_var={adapter.avg_explained_variance:.3f} | "
            f"params={adapter.total_params():,} | "
            f"trigger={trigger} | time={extraction_time:.3f}s"
        )
        return adapter

    def get_overhead_percentage(self) -> float:
        """Return extraction time as a percentage of the recorded training time.

        Returns 0.0 if no training time has been recorded.
        """
        if self._total_training_time <= 0:
            return 0.0
        return (self._total_extraction_time / self._total_training_time) * 100.0

    def update_training_time(self, elapsed: float) -> None:
        """Record the total elapsed training time for overhead reporting.

        ``elapsed`` replaces the previous value; it is not accumulated.
        Presumably it is in seconds, to match the perf_counter-based
        extraction timing.
        """
        self._total_training_time = elapsed

    def get_best_adapter(self) -> ExpertAdapter | None:
        """Return the adapter with the lowest loss, or None if the library is empty.

        NOTE(review): adapters extracted without a "loss" metric store 0.0 and
        will therefore win this comparison.
        """
        if not self.adapter_library:
            return None
        return min(self.adapter_library, key=lambda a: a.loss)

    def get_library_summary(self) -> dict[str, Any]:
        """Return a JSON-friendly summary of the adapter library."""
        if not self.adapter_library:
            return {"count": 0}
        return {
            "count": len(self.adapter_library),
            "total_extraction_time_s": self._total_extraction_time,
            "overhead_pct": self.get_overhead_percentage(),
            "adapters": [
                {
                    "id": a.adapter_id,
                    "step": a.step,
                    "loss": a.loss,
                    "n_layers": len(a.layers),
                    "avg_rank": sum(a.rank_per_layer.values())
                    / max(1, len(a.rank_per_layer)),
                    "explained_var": a.avg_explained_variance,
                    "params": a.total_params(),
                    "trigger": a.extraction_trigger,
                }
                for a in self.adapter_library
            ],
        }

    def save_library(self, path: str) -> None:
        """Save the entire adapter library to disk with ``torch.save``.

        The "config" entry holds every ``QLoRAConfig`` field (including
        ``min_rank``, ``max_rank`` and ``explained_variance_threshold``). Fields
        such as ``quantize_to_int8`` are needed to interpret the stored factors.
        """
        from dataclasses import asdict

        state = {
            "config": asdict(self.config),
            "adapters": [a.to_state_dict() for a in self.adapter_library],
            "extraction_stats": {
                "total_extractions": self._extraction_count,
                "total_extraction_time_s": self._total_extraction_time,
                "overhead_pct": self.get_overhead_percentage(),
            },
        }
        torch.save(state, path)
        print(
            f"[ParasiticQLoRA] Library saved: {path} ({len(self.adapter_library)} adapters)"
        )

    @staticmethod
    def load_library(path: str) -> list[ExpertAdapter]:
        """Load adapters from a file written by ``save_library``.

        Every metadata field written by ``ExpertAdapter.to_state_dict`` is
        restored. Optional fields fall back to their defaults when absent
        (older files).
        """
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects (i.e. run code). Only load library files from trusted sources.
        state = torch.load(path, map_location="cpu", weights_only=False)
        adapters: list[ExpertAdapter] = []
        for adapter_state in state["adapters"]:
            adapter = ExpertAdapter(
                adapter_id=adapter_state["adapter_id"],
                step=adapter_state["step"],
                loss=adapter_state["loss"],
                sft_loss=adapter_state["sft_loss"],
                val_score=adapter_state.get("val_score", 0.0),
                external_val_score=adapter_state.get("external_val_score"),
                optimizer_type=adapter_state.get("optimizer_type", ""),
                spectral_rank=adapter_state.get("spectral_rank", 0.0),
                extraction_trigger=adapter_state.get("extraction_trigger", "unknown"),
                wall_clock_time=adapter_state.get("wall_clock_time", 0.0),
                fces_population_fitness=adapter_state.get("fces_population_fitness"),
                fces_controller_id=adapter_state.get("fces_controller_id"),
                domain_tags=list(adapter_state.get("domain_tags") or []),
                avg_explained_variance=adapter_state.get("avg_explained_variance", 0.0),
            )
            for name, layer_state in adapter_state.get("layers", {}).items():
                adapter.layers[name] = LoRAMatrices(
                    layer_name=name,
                    lora_B=layer_state["lora_B"],
                    lora_A=layer_state["lora_A"],
                    rank=layer_state["rank"],
                    explained_variance=layer_state["explained_variance"],
                    singular_values=layer_state["singular_values"],
                    original_shape=tuple(layer_state["original_shape"]),
                )
                adapter.rank_per_layer[name] = layer_state["rank"]
            adapters.append(adapter)
        return adapters
def apply_adapter(
    model: nn.Module,
    adapter: ExpertAdapter,
    scale: float = 1.0,
) -> None:
    """Add an adapter's low-rank deltas to ``model``'s parameters in place.

    For every parameter whose name appears in ``adapter.layers``::

        W_new = W_base + scale * B @ A

    ``B @ A`` is computed in at least float32 and cast to the parameter's
    dtype only for the final add. The extractor stores the factors as fp16,
    and multiplying them in a low-precision dtype (bf16/fp16) would add
    further rounding. If the product's shape differs from the parameter's
    (e.g. a conv weight flattened to 2-D at extraction), it is reshaped to
    the parameter's shape.

    Args:
        model: Module whose parameters are modified in place.
        adapter: Adapter providing ``layers[name].lora_B`` / ``lora_A``.
        scale: Multiplier applied to every delta.

    Raises:
        RuntimeError: from ``reshape`` if a layer's delta has a different
            number of elements than the matching parameter.

    Adapter layers with no matching parameter name are ignored.
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            lm = adapter.layers.get(name)
            if lm is None:
                continue
            work_dtype = torch.promote_types(param.dtype, torch.float32)
            lora_B = lm.lora_B.to(device=param.device, dtype=work_dtype)
            lora_A = lm.lora_A.to(device=param.device, dtype=work_dtype)
            # Reconstruct: ΔW = B @ A
            delta = lora_B @ lora_A
            if delta.shape != param.shape:
                delta = delta.reshape(param.shape)
            param.data.add_((delta * scale).to(param.dtype))
def _quantize_int8(tensor: torch.Tensor) -> torch.Tensor:
"""Simple absmax INT8 quantization (dequantized back to float for storage)."""
abs_max = tensor.abs().max()
if abs_max < 1e-10:
return tensor
scale = 127.0 / abs_max
quantized = (tensor * scale).round().clamp(-127, 127)
return quantized / scale