# -*- coding: utf-8 -*-
"""Multi-Worker Adapter Library Aggregator.

Merges ExpertAdapter libraries produced by multiple independent workers
(Kaggle/RunPod) into a single, deduplicated, domain-balanced corpus.

Design:
  - Workers write adapter libraries to a shared path (NFS, S3, or a .pt file).
  - The aggregator loads all libraries, deduplicates by adapter_id, re-ranks
    by domain coverage, and emits a consolidated merged_library.pt.
  - Telemetry is pushed to MariaDB for cluster-wide monitoring (that step is
    not performed by the code in this module).
"""

from __future__ import annotations

import json
import os
import sys
import time
from typing import Dict, List, Tuple

import torch

_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(_root, "python"))

from parasitic_qlora import ExpertAdapter, ParasiticQLoRAExtractor  # noqa: E402

# ---------------------------------------------------------------------------
# Domain balance scoring
# ---------------------------------------------------------------------------

DOMAIN_KEYS: List[str] = ["statute_recall", "logic_reasoning", "style_gutachtenstil"]


def _adapter_domain_score(adapter: ExpertAdapter) -> Dict[str, float]:
    """Return a per-domain score for one adapter.

    Each known domain present in the adapter's tags scores ``1 / n`` where
    ``n`` is the number of distinct tags on the adapter (at least 1); every
    other known domain scores ``0.0``.
    """
    unique_tags = set(adapter.domain_tags)
    share = 1.0 / max(1, len(unique_tags))
    scores: Dict[str, float] = {}
    for domain in DOMAIN_KEYS:
        scores[domain] = share if domain in unique_tags else 0.0
    return scores


def _library_coverage(library: List[ExpertAdapter]) -> Dict[str, float]:
    """Return, per known domain, the fraction of adapters tagged with it (0..1).

    An empty library yields ``0.0`` for every domain.
    """
    if not library:
        return {domain: 0.0 for domain in DOMAIN_KEYS}
    total = len(library)
    return {
        domain: sum(1.0 for entry in library if domain in entry.domain_tags) / total
        for domain in DOMAIN_KEYS
    }


# ---------------------------------------------------------------------------
# Core aggregation logic
# ---------------------------------------------------------------------------


class AdapterLibraryAggregator:
    """Merges adapter libraries from multiple workers into a balanced corpus.

    Args:
        max_adapters_per_domain: Hard cap on adapters per domain in merged lib.
        min_explained_variance: Discard adapters below this quality threshold.
    """

    def __init__(
        self,
        max_adapters_per_domain: int = 20,
        min_explained_variance: float = 0.90,
    ) -> None:
        # Accumulated state: accepted adapters plus the ids already taken.
        self.merged: List[ExpertAdapter] = []
        self._seen_ids: set[str] = set()
        # Selection knobs.
        self.min_explained_variance = min_explained_variance
        self.max_adapters_per_domain = max_adapters_per_domain

    def add_library(self, library: List[ExpertAdapter]) -> Tuple[int, int]:
        """Fold one worker library into the aggregator.

        An adapter is rejected when its ``adapter_id`` was already accepted
        earlier, or when its ``avg_explained_variance`` is below
        ``min_explained_variance``. The first accepted adapter for an id wins.

        Returns:
            (accepted, rejected) counts.
        """
        accepted = rejected = 0
        for candidate in library:
            duplicate = candidate.adapter_id in self._seen_ids
            if duplicate or candidate.avg_explained_variance < self.min_explained_variance:
                rejected += 1
            else:
                # Only accepted adapters reserve their id.
                self._seen_ids.add(candidate.adapter_id)
                self.merged.append(candidate)
                accepted += 1
        return accepted, rejected

    def load_and_merge(self, library_paths: List[str]) -> Dict[str, int]:
        """Load every library file that exists and merge it.

        Paths that do not exist are reported with a warning and skipped; they
        do not count towards ``files``.

        Returns:
            Summary dict with the keys ``total_accepted``, ``total_rejected``
            and ``files``.
        """
        summary: Dict[str, int] = dict(total_accepted=0, total_rejected=0, files=0)
        for lib_path in library_paths:
            if os.path.exists(lib_path):
                loaded = ParasiticQLoRAExtractor.load_library(lib_path)
                n_ok, n_bad = self.add_library(loaded)
                summary["total_accepted"] += n_ok
                summary["total_rejected"] += n_bad
                summary["files"] += 1
                print(
                    f" Loaded {lib_path}: {len(loaded)} adapters "
                    f"(+{n_ok} accepted, {n_bad} rejected)"
                )
            else:
                print(f" [WARN] Library not found: {lib_path}")
        return summary

    def rebalance(self) -> List[ExpertAdapter]:
        """Select a domain-balanced subset honouring max_adapters_per_domain.

        Strategy:
            1. Put each adapter in the bucket of the first DOMAIN_KEYS entry
               found in its tags; adapters matching none are "untagged".
            2. Sort each bucket by avg_explained_variance DESC and truncate
               it to the cap.
            3. Interleave the buckets round-robin, in DOMAIN_KEYS order.
            4. Append the untagged adapters last (they are not capped).

        Returns:
            Balanced merged adapter list (also stored in self.merged).
        """
        by_domain: Dict[str, List[ExpertAdapter]] = {key: [] for key in DOMAIN_KEYS}
        leftovers: List[ExpertAdapter] = []
        for candidate in self.merged:
            for key in DOMAIN_KEYS:
                if key in candidate.domain_tags:
                    by_domain[key].append(candidate)
                    break
            else:
                # No known domain matched this adapter.
                leftovers.append(candidate)

        # Best-first per domain, cut down to the per-domain cap.
        cap = self.max_adapters_per_domain
        capped = [
            sorted(
                by_domain[key],
                key=lambda item: item.avg_explained_variance,
                reverse=True,
            )[:cap]
            for key in DOMAIN_KEYS
        ]

        # Round-robin: pass `rank` takes the rank-th best adapter of every
        # domain that still has one, visiting domains in DOMAIN_KEYS order.
        balanced: List[ExpertAdapter] = []
        deepest = max((len(group) for group in capped), default=0)
        for rank in range(deepest):
            balanced.extend(group[rank] for group in capped if rank < len(group))

        # Untagged adapters go last (low priority).
        balanced.extend(leftovers)
        self.merged = balanced
        return balanced

    def save(self, output_path: str) -> None:
        """Write the merged library to ``output_path`` with ``torch.save``.

        The payload has the top-level keys ``config``, ``adapters`` and
        ``extraction_stats``; per the original author this mirrors the format
        used by ParasiticQLoRAExtractor. Rank and timing fields are written
        as zero placeholders.
        """
        count = len(self.merged)
        serialised = [entry.to_state_dict() for entry in self.merged]
        payload = {
            "config": {
                "min_rank": 0,
                "max_rank": 0,
                "explained_variance_threshold": self.min_explained_variance,
            },
            "adapters": serialised,
            "extraction_stats": {
                "total_extractions": count,
                "total_extraction_time_s": 0.0,
                "overhead_pct": 0.0,
            },
        }
        torch.save(payload, output_path)
        print(
            f"[Aggregator] Saved merged library: {output_path} ({count} adapters)"
        )

    def coverage_report(self) -> Dict[str, object]:
        """Return a summary of the merged library.

        Keys: ``total_adapters``, ``domain_coverage`` (per-domain fraction),
        ``balance_score`` (lowest coverage divided by highest coverage, with
        the divisor floored at 1e-6) and ``avg_explained_variance`` (mean over
        the merged adapters, 0 when the library is empty).
        """
        size = len(self.merged)
        coverage = _library_coverage(self.merged)
        fractions = coverage.values()
        balance = min(fractions) / max(max(fractions), 1e-6)
        mean_quality = sum(
            entry.avg_explained_variance for entry in self.merged
        ) / max(1, size)
        return {
            "total_adapters": size,
            "domain_coverage": coverage,
            "balance_score": balance,
            "avg_explained_variance": mean_quality,
        }


# ---------------------------------------------------------------------------
# CLI entry point for standalone aggregation runs
# ---------------------------------------------------------------------------


def aggregate_from_paths(
    library_paths: List[str],
    output_path: str,
    max_per_domain: int = 20,
    min_ev: float = 0.90,
) -> Dict[str, object]:
    """Top-level helper: load, merge, rebalance, save, report.

    Returns the coverage report extended with the merge statistics
    (``total_accepted``, ``total_rejected``, ``files``).
    """
    aggregator = AdapterLibraryAggregator(
        max_adapters_per_domain=max_per_domain,
        min_explained_variance=min_ev,
    )
    load_stats = aggregator.load_and_merge(library_paths)
    aggregator.rebalance()
    aggregator.save(output_path)
    # Report keys come first, merge statistics are appended after them.
    return {**aggregator.coverage_report(), **load_stats}


def _build_cli_parser():
    """Create the argument parser used by :func:`main`."""
    import argparse

    cli = argparse.ArgumentParser(
        description="Merge adapter libraries from multiple workers"
    )
    cli.add_argument(
        "libraries", nargs="+", help="Paths to .pt adapter library files"
    )
    cli.add_argument(
        "--output", required=True, help="Output path for merged library"
    )
    cli.add_argument("--max-per-domain", type=int, default=20)
    cli.add_argument("--min-ev", type=float, default=0.90)
    return cli


def main() -> None:
    """Command-line entry point: merge the given libraries and print a report."""
    options = _build_cli_parser().parse_args()

    print(f"Aggregating {len(options.libraries)} libraries...")
    started = time.perf_counter()
    report = aggregate_from_paths(
        library_paths=options.libraries,
        output_path=options.output,
        max_per_domain=options.max_per_domain,
        min_ev=options.min_ev,
    )
    elapsed = time.perf_counter() - started

    print(f"\nMerged Library Report ({elapsed:.1f}s):")
    print(json.dumps(report, indent=2))
    print(f"Saved to: {options.output}")


if __name__ == "__main__":
    main()