"""Parasitic QLoRA Adapter Extraction Engine.

Extracts low-rank LoRA adapters from weight deltas during full-parameter
fine-tuning at low additional compute cost. Intended to run alongside FCES
training (whose SpectralSensor already performs SVDs); this module computes
its own truncated SVD of each weight delta.

Architecture:
    During training, W₀ (base weights) is snapshotted at init.
    At extraction points:
        ΔW = W_t - W₀
        Truncated SVD: ΔW ≈ U_r·Σ_r·V_r^T = B·A (LoRA decomposition)

Design target: compute overhead < 0.5% of training FLOPs when using
torch.svd_lowrank. FLOP overhead is not measured here; wall-clock overhead
can be tracked via ParasiticQLoRAExtractor.get_overhead_percentage().
"""

from __future__ import annotations

import time
from dataclasses import dataclass, field
from typing import Any

import torch
import torch.nn as nn


@dataclass(frozen=True)
class QLoRAConfig:
    """Configuration for parasitic QLoRA extraction.

    Raises:
        ValueError: from ``__post_init__`` when a field is outside the range
            the extractor can work with, so misconfiguration fails at
            construction time instead of partway through training.
    """

    # Lower / upper bound on the per-layer LoRA rank. The upper bound is
    # additionally clipped by each layer's dimensions at extraction time.
    min_rank: int = 8
    max_rank: int = 64
    # Cumulative explained-variance fraction used to pick each layer's rank.
    explained_variance_threshold: float = 0.95
    # Extract unconditionally every N steps.
    extraction_interval: int = 50
    # Enable the loss-plateau and validation-improvement triggers.
    interesting_point_detection: bool = True
    # Number of most recent losses examined for plateau detection.
    loss_plateau_window: int = 10
    # Plateau when (max - min) / mean of the window's losses is below this.
    loss_plateau_threshold: float = 0.005
    # Trigger when a new val score exceeds the previous one by more than this.
    val_improvement_threshold: float = 0.01
    # Parameters with fewer dimensions than this are not tracked.
    min_weight_dims: int = 2
    # Round LoRA factors through absmax INT8 (simulated; stored as fp16).
    quantize_to_int8: bool = False

    def __post_init__(self) -> None:
        # Reject values that would crash or silently misbehave later:
        # extraction_interval is used as a modulus, and every tracked weight
        # delta is unpacked into a 2-D (d, k) shape.
        if self.min_rank < 1:
            raise ValueError(f"min_rank must be >= 1, got {self.min_rank}")
        if self.max_rank < 1:
            raise ValueError(f"max_rank must be >= 1, got {self.max_rank}")
        if not 0.0 < self.explained_variance_threshold <= 1.0:
            raise ValueError(
                "explained_variance_threshold must be in (0, 1], got "
                f"{self.explained_variance_threshold}"
            )
        if self.extraction_interval < 1:
            raise ValueError(
                f"extraction_interval must be >= 1, got {self.extraction_interval}"
            )
        if self.loss_plateau_window < 1:
            raise ValueError(
                f"loss_plateau_window must be >= 1, got {self.loss_plateau_window}"
            )
        if self.min_weight_dims < 2:
            raise ValueError(
                f"min_weight_dims must be >= 2, got {self.min_weight_dims}"
            )


@dataclass
class LoRAMatrices:
    """Low-rank decomposition of a single layer's weight delta.

    ``lora_B @ lora_A`` approximates the layer's ΔW viewed as a 2-D ``d × k``
    matrix; ``original_shape`` records the parameter's shape before any
    flattening so the product can be reshaped back.
    """

    layer_name: str
    lora_B: torch.Tensor  # d × r
    lora_A: torch.Tensor  # r × k
    # r: number of retained singular components.
    rank: int
    # Explained-variance fraction reported for the retained components.
    explained_variance: float
    # The r retained singular values.
    singular_values: torch.Tensor
    original_shape: tuple[int, ...]
@dataclass class ExpertAdapter: """A complete LoRA adapter extracted from a training checkpoint.""" adapter_id: str step: int rank_per_layer: dict[str, int] = field(default_factory=dict) avg_explained_variance: float = 0.0 # LoRA matrices per layer layers: dict[str, LoRAMatrices] = field(default_factory=dict) # Training metrics at extraction point loss: float = 0.0 sft_loss: float = 0.0 val_score: float = 0.0 external_val_score: float | None = None optimizer_type: str = "" spectral_rank: float = 0.0 extraction_trigger: str = "periodic" wall_clock_time: float = 0.0 # FCES-specific metadata fces_population_fitness: float | None = None fces_controller_id: int | None = None # Domain tags domain_tags: list[str] = field(default_factory=list) def total_params(self) -> int: """Total number of parameters across all LoRA layers.""" total = 0 for lm in self.layers.values(): total += lm.lora_B.numel() + lm.lora_A.numel() return total def compression_ratio(self, original_params: int) -> float: """Compression ratio vs full model parameters.""" adapter_params = self.total_params() if adapter_params == 0: return float("inf") return original_params / adapter_params def to_state_dict(self) -> dict[str, Any]: """Serialize to a state dict for saving.""" state: dict[str, Any] = { "adapter_id": self.adapter_id, "step": self.step, "loss": self.loss, "sft_loss": self.sft_loss, "val_score": self.val_score, "external_val_score": self.external_val_score, "optimizer_type": self.optimizer_type, "spectral_rank": self.spectral_rank, "extraction_trigger": self.extraction_trigger, "wall_clock_time": self.wall_clock_time, "fces_population_fitness": self.fces_population_fitness, "fces_controller_id": self.fces_controller_id, "domain_tags": self.domain_tags, "avg_explained_variance": self.avg_explained_variance, } layers_state: dict[str, dict[str, Any]] = {} for name, lm in self.layers.items(): layers_state[name] = { "lora_B": lm.lora_B.cpu(), "lora_A": lm.lora_A.cpu(), "rank": lm.rank, 
"explained_variance": lm.explained_variance, "singular_values": lm.singular_values.cpu(), "original_shape": lm.original_shape, } state["layers"] = layers_state return state class ParasiticQLoRAExtractor: """Extracts LoRA adapters parasitically during full-parameter fine-tuning. The key insight: during training we already compute forward/backward passes and (in FCES) SVD for spectral sensing. The weight delta ΔW = W_t - W₀ can be decomposed into a low-rank LoRA adapter at near-zero marginal cost using truncated SVD (torch.svd_lowrank). """ def __init__(self, config: QLoRAConfig | None = None) -> None: self.config = config or QLoRAConfig() self.base_weights: dict[str, torch.Tensor] = {} self.adapter_library: list[ExpertAdapter] = [] self.loss_history: list[float] = [] self.val_history: list[float] = [] self._extraction_count: int = 0 self._total_extraction_time: float = 0.0 self._total_training_time: float = 0.0 def snapshot_base(self, model: nn.Module) -> None: """Snapshot base weights W₀ at training start. Called once.""" self.base_weights.clear() for name, param in model.named_parameters(): if param.dim() >= self.config.min_weight_dims: self.base_weights[name] = param.data.clone().cpu() print( f"[ParasiticQLoRA] Base snapshot: {len(self.base_weights)} " f"weight matrices tracked" ) def should_extract( self, step: int, loss: float, val_score: float = 0.0, ) -> bool: """Determines whether to extract an adapter at this step. Triggers on: 1. Periodic interval (every N steps) 2. Loss plateau detection 3. Validation score improvement """ self.loss_history.append(loss) if val_score > 0: self.val_history.append(val_score) # Always extract at step 0 (pre-training baseline) if step == 0: return True # 1. Periodic extraction if step % self.config.extraction_interval == 0: return True if not self.config.interesting_point_detection: return False # 2. 
Loss plateau detection window = self.config.loss_plateau_window if len(self.loss_history) >= window: recent = self.loss_history[-window:] loss_range = max(recent) - min(recent) avg_loss = sum(recent) / len(recent) if ( avg_loss > 0 and loss_range / avg_loss < self.config.loss_plateau_threshold ): return True # 3. Validation improvement if len(self.val_history) >= 2: improvement = self.val_history[-1] - self.val_history[-2] if improvement > self.config.val_improvement_threshold: return True return False def extract_adapters( self, model: nn.Module, step: int, metrics: dict[str, Any], ) -> ExpertAdapter: """Core parasitic extraction. Computes ΔW and decomposes via truncated SVD. Algorithm per layer: 1. ΔW = W_t - W₀ 2. U, Σ, V = svd_lowrank(ΔW, q=max_rank) 3. Select rank r via explained variance threshold 4. B = U_r · diag(√Σ_r), A = diag(√Σ_r) · V_r^T """ t0 = time.perf_counter() self._extraction_count += 1 adapter_id = ( f"{metrics.get('optimizer', 'unknown')}_step{step}_{self._extraction_count}" ) adapter = ExpertAdapter( adapter_id=adapter_id, step=step, loss=metrics.get("loss", 0.0), sft_loss=metrics.get("sft_loss", 0.0), val_score=metrics.get("val_score", 0.0), external_val_score=metrics.get("external_val_score"), optimizer_type=metrics.get("optimizer", ""), spectral_rank=metrics.get("spectral_rank", 0.0), fces_population_fitness=metrics.get("fces_fitness"), fces_controller_id=metrics.get("fces_controller_id"), wall_clock_time=metrics.get("wall_clock_time", 0.0), ) # Determine extraction trigger trigger = "periodic" if step == 0: trigger = "initial" elif len(self.loss_history) >= self.config.loss_plateau_window: recent = self.loss_history[-self.config.loss_plateau_window :] loss_range = max(recent) - min(recent) avg_loss = sum(recent) / len(recent) if ( avg_loss > 0 and loss_range / avg_loss < self.config.loss_plateau_threshold ): trigger = "plateau" if len(self.val_history) >= 2: improvement = self.val_history[-1] - self.val_history[-2] if improvement > 
self.config.val_improvement_threshold: trigger = "val_improvement" adapter.extraction_trigger = trigger total_explained = 0.0 n_layers = 0 for name, param in model.named_parameters(): if name not in self.base_weights: continue if param.dim() < self.config.min_weight_dims: continue base_w = self.base_weights[name].to(param.device) delta_w = param.data - base_w # Skip near-zero deltas (layer hasn't changed) delta_norm = delta_w.norm().item() if delta_norm < 1e-7: continue # Reshape to 2D for SVD if needed (e.g., conv layers) original_shape = delta_w.shape if delta_w.dim() > 2: delta_w = delta_w.reshape(delta_w.shape[0], -1) d, k = delta_w.shape max_rank = min(self.config.max_rank, d, k) # Truncated SVD — O(d·k·r) instead of O(d·k·min(d,k)) try: U, S, V = torch.svd_lowrank(delta_w.float(), q=max_rank) except Exception: # Fallback to full SVD if lowrank fails U_full, S_full, Vh_full = torch.linalg.svd( delta_w.float(), full_matrices=False ) U = U_full[:, :max_rank] S = S_full[:max_rank] V = Vh_full[:max_rank, :].T # Dynamic rank selection via explained variance total_energy = (S**2).sum().item() if total_energy < 1e-12: continue cumulative_energy = torch.cumsum(S**2, dim=0) explained_ratios = cumulative_energy / total_energy # Find rank where we exceed the threshold rank = self.config.min_rank for r in range(self.config.min_rank, max_rank + 1): if ( r <= len(explained_ratios) and explained_ratios[r - 1].item() >= self.config.explained_variance_threshold ): rank = r break else: rank = max_rank rank = min(rank, len(S)) explained_var = ( explained_ratios[rank - 1].item() if rank <= len(explained_ratios) else 1.0 ) # LoRA decomposition: B = U_r · √Σ_r, A = √Σ_r · V_r^T sqrt_S = torch.sqrt(S[:rank]) lora_B = U[:, :rank] * sqrt_S.unsqueeze(0) # d × r lora_A = V[:, :rank].T * sqrt_S.unsqueeze(1) # r × k # Optional INT8 quantization if self.config.quantize_to_int8: lora_B = _quantize_int8(lora_B) lora_A = _quantize_int8(lora_A) lora_matrices = LoRAMatrices( layer_name=name, 
lora_B=lora_B.cpu().half(), lora_A=lora_A.cpu().half(), rank=rank, explained_variance=explained_var, singular_values=S[:rank].cpu(), original_shape=original_shape, ) adapter.layers[name] = lora_matrices adapter.rank_per_layer[name] = rank total_explained += explained_var n_layers += 1 if n_layers > 0: adapter.avg_explained_variance = total_explained / n_layers self.adapter_library.append(adapter) extraction_time = time.perf_counter() - t0 self._total_extraction_time += extraction_time print( f"[ParasiticQLoRA] Extracted adapter '{adapter_id}' | " f"{n_layers} layers | avg_rank={sum(adapter.rank_per_layer.values()) / max(1, n_layers):.0f} | " f"explained_var={adapter.avg_explained_variance:.3f} | " f"params={adapter.total_params():,} | " f"trigger={trigger} | time={extraction_time:.3f}s" ) return adapter def get_overhead_percentage(self) -> float: """Returns the extraction overhead as % of total training time.""" if self._total_training_time <= 0: return 0.0 return (self._total_extraction_time / self._total_training_time) * 100.0 def update_training_time(self, elapsed: float) -> None: """Track total training time for overhead calculation.""" self._total_training_time = elapsed def get_best_adapter(self) -> ExpertAdapter | None: """Returns the adapter with the lowest loss.""" if not self.adapter_library: return None return min(self.adapter_library, key=lambda a: a.loss) def get_library_summary(self) -> dict[str, Any]: """Returns a summary of the adapter library.""" if not self.adapter_library: return {"count": 0} return { "count": len(self.adapter_library), "total_extraction_time_s": self._total_extraction_time, "overhead_pct": self.get_overhead_percentage(), "adapters": [ { "id": a.adapter_id, "step": a.step, "loss": a.loss, "n_layers": len(a.layers), "avg_rank": sum(a.rank_per_layer.values()) / max(1, len(a.rank_per_layer)), "explained_var": a.avg_explained_variance, "params": a.total_params(), "trigger": a.extraction_trigger, } for a in self.adapter_library ], } 
def save_library(self, path: str) -> None: """Saves the entire adapter library to disk.""" state = { "config": { "min_rank": self.config.min_rank, "max_rank": self.config.max_rank, "explained_variance_threshold": self.config.explained_variance_threshold, }, "adapters": [a.to_state_dict() for a in self.adapter_library], "extraction_stats": { "total_extractions": self._extraction_count, "total_extraction_time_s": self._total_extraction_time, "overhead_pct": self.get_overhead_percentage(), }, } torch.save(state, path) print( f"[ParasiticQLoRA] Library saved: {path} ({len(self.adapter_library)} adapters)" ) @staticmethod def load_library(path: str) -> list[ExpertAdapter]: """Loads adapters from a saved library file.""" state = torch.load(path, map_location="cpu", weights_only=False) adapters: list[ExpertAdapter] = [] for adapter_state in state["adapters"]: adapter = ExpertAdapter( adapter_id=adapter_state["adapter_id"], step=adapter_state["step"], loss=adapter_state["loss"], sft_loss=adapter_state["sft_loss"], val_score=adapter_state.get("val_score", 0.0), optimizer_type=adapter_state.get("optimizer_type", ""), spectral_rank=adapter_state.get("spectral_rank", 0.0), extraction_trigger=adapter_state.get("extraction_trigger", "unknown"), avg_explained_variance=adapter_state.get("avg_explained_variance", 0.0), ) for name, layer_state in adapter_state.get("layers", {}).items(): adapter.layers[name] = LoRAMatrices( layer_name=name, lora_B=layer_state["lora_B"], lora_A=layer_state["lora_A"], rank=layer_state["rank"], explained_variance=layer_state["explained_variance"], singular_values=layer_state["singular_values"], original_shape=tuple(layer_state["original_shape"]), ) adapter.rank_per_layer[name] = layer_state["rank"] adapters.append(adapter) return adapters def apply_adapter( model: nn.Module, adapter: ExpertAdapter, scale: float = 1.0, ) -> None: """Applies a LoRA adapter to a model (adds the low-rank delta). 
W_new = W_base + scale * B @ A """ with torch.no_grad(): for name, param in model.named_parameters(): if name not in adapter.layers: continue lm = adapter.layers[name] # Reconstruct: ΔW = B @ A delta = lm.lora_B.to(param.device, param.dtype) @ lm.lora_A.to( param.device, param.dtype ) # Reshape if needed if delta.shape != param.shape: delta = delta.reshape(param.shape) param.data.add_(delta * scale) def _quantize_int8(tensor: torch.Tensor) -> torch.Tensor: """Simple absmax INT8 quantization (dequantized back to float for storage).""" abs_max = tensor.abs().max() if abs_max < 1e-10: return tensor scale = 127.0 / abs_max quantized = (tensor * scale).round().clamp(-127, 127) return quantized / scale