256 lines
8.6 KiB
Python
256 lines
8.6 KiB
Python
# -*- coding: utf-8 -*-
"""Multi-Worker Adapter Library Aggregator.

Merges ExpertAdapter libraries produced by multiple independent workers
(Kaggle/RunPod) into a single, deduplicated, domain-balanced corpus.

Design:
- Workers write adapter libraries to a shared path (NFS, S3, or a .pt file).
- Aggregator loads all libraries, deduplicates by adapter_id, re-ranks by
  domain coverage, and emits a consolidated merged_library.pt.
- Telemetry is pushed to MariaDB for cluster-wide monitoring.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from typing import Dict, List, Tuple
|
|
|
|
import torch
|
|
|
|
_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
sys.path.insert(0, os.path.join(_root, "python"))
|
|
|
|
from parasitic_qlora import ExpertAdapter, ParasiticQLoRAExtractor # noqa: E402
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Domain balance scoring
# ---------------------------------------------------------------------------
|
|
|
|
# Domain tags recognised by this module. The list order matters: code that
# iterates DOMAIN_KEYS visits (and matches) domains in exactly this order.
DOMAIN_KEYS: List[str] = [
    "statute_recall",
    "logic_reasoning",
    "style_gutachtenstil",
]
|
|
|
|
|
|
def _adapter_domain_score(adapter: ExpertAdapter) -> Dict[str, float]:
    """Return a per-domain weight dict for ``adapter``.

    Each key of ``DOMAIN_KEYS`` found among the adapter's distinct
    ``domain_tags`` is assigned ``1 / (number of distinct tags)``; every
    other key is assigned ``0.0``.

    Args:
        adapter: Object exposing an iterable ``domain_tags`` attribute.

    Returns:
        Mapping from each entry of ``DOMAIN_KEYS`` to its weight.
    """
    unique_tags = set(adapter.domain_tags)
    # max(1, ...) avoids a division by zero for adapters without any tags.
    share = 1.0 / max(1, len(unique_tags))
    scores: Dict[str, float] = {}
    for domain in DOMAIN_KEYS:
        scores[domain] = share if domain in unique_tags else 0.0
    return scores
|
|
|
|
|
|
def _library_coverage(library: List[ExpertAdapter]) -> Dict[str, float]:
    """Compute, per domain, the share of adapters tagged with that domain.

    Args:
        library: Adapters to inspect; each must expose ``domain_tags``.

    Returns:
        Mapping from each entry of ``DOMAIN_KEYS`` to a value in ``0..1``.
        An empty library yields ``0.0`` for every domain.
    """
    if not library:
        return dict.fromkeys(DOMAIN_KEYS, 0.0)

    size = len(library)
    # An adapter carrying several domain tags counts once for each of them,
    # so the returned fractions do not have to sum to 1.
    return {
        domain: sum(1.0 for adapter in library if domain in adapter.domain_tags) / size
        for domain in DOMAIN_KEYS
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Core aggregation logic
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class AdapterLibraryAggregator:
    """Merges adapter libraries from multiple workers into a balanced corpus.

    Adapters are deduplicated by ``adapter_id`` (the first occurrence wins)
    and filtered by ``avg_explained_variance`` while they are added.
    ``rebalance`` then caps and interleaves them per domain.

    Args:
        max_adapters_per_domain: Hard cap on adapters per domain in merged lib.
        min_explained_variance: Discard adapters below this quality threshold.
    """

    def __init__(
        self,
        max_adapters_per_domain: int = 20,
        min_explained_variance: float = 0.90,
    ) -> None:
        self.max_adapters_per_domain = max_adapters_per_domain
        self.min_explained_variance = min_explained_variance
        # IDs of every adapter accepted so far; used for deduplication.
        self._seen_ids: set[str] = set()
        # Accepted adapters, in acceptance order until rebalance() reorders.
        self.merged: List[ExpertAdapter] = []

    def add_library(self, library: List[ExpertAdapter]) -> Tuple[int, int]:
        """Merges a worker library into the aggregator.

        An adapter is rejected when its ``adapter_id`` was already accepted,
        or when its ``avg_explained_variance`` is not at least
        ``min_explained_variance`` (this includes NaN). Adapters rejected
        for quality are not remembered, so a later adapter with the same ID
        and sufficient quality can still be accepted.

        Args:
            library: Adapters produced by one worker.

        Returns:
            (accepted, rejected) counts.
        """
        accepted = 0
        rejected = 0
        for adapter in library:
            if adapter.adapter_id in self._seen_ids:
                rejected += 1
                continue
            # Deliberately "not >=" instead of "<": NaN compares False with
            # everything, so "<" would let a NaN-quality adapter through and
            # later make the quality sort in rebalance() unreliable.
            if not adapter.avg_explained_variance >= self.min_explained_variance:
                rejected += 1
                continue
            self._seen_ids.add(adapter.adapter_id)
            self.merged.append(adapter)
            accepted += 1
        return accepted, rejected

    def load_and_merge(self, library_paths: List[str]) -> Dict[str, int]:
        """Loads all library files and merges them.

        Paths that do not exist are reported with a warning and skipped.

        Args:
            library_paths: Library files to load via
                ``ParasiticQLoRAExtractor.load_library``.

        Returns:
            Summary dict with the keys ``total_accepted``, ``total_rejected``
            and ``files`` (number of files actually loaded).
        """
        stats: Dict[str, int] = {"total_accepted": 0, "total_rejected": 0, "files": 0}
        for path in library_paths:
            if not os.path.exists(path):
                print(f"  [WARN] Library not found: {path}")
                continue
            lib = ParasiticQLoRAExtractor.load_library(path)
            accepted, rejected = self.add_library(lib)
            stats["total_accepted"] += accepted
            stats["total_rejected"] += rejected
            stats["files"] += 1
            print(
                f"  Loaded {path}: {len(lib)} adapters "
                f"(+{accepted} accepted, {rejected} rejected)"
            )
        return stats

    def rebalance(self) -> List[ExpertAdapter]:
        """Selects a domain-balanced subset honouring max_adapters_per_domain.

        Strategy:
            1. Assign each adapter to the first entry of ``DOMAIN_KEYS``
               present in its ``domain_tags``.
            2. Sort each domain by avg_explained_variance DESC and keep at
               most ``max_adapters_per_domain`` adapters.
            3. Round-robin pick from the domains until all are exhausted.
            4. Append adapters matching no known domain, uncapped, last.

        Returns:
            Balanced merged adapter list (also updates self.merged).
        """
        buckets: Dict[str, List[ExpertAdapter]] = {d: [] for d in DOMAIN_KEYS}
        untagged: List[ExpertAdapter] = []

        for adapter in self.merged:
            # Only the first matching domain is used, so an adapter is never
            # placed in more than one bucket.
            domain = next((d for d in DOMAIN_KEYS if d in adapter.domain_tags), None)
            if domain is None:
                untagged.append(adapter)
            else:
                buckets[domain].append(adapter)

        # Keep only the best adapters of each domain.
        for d in DOMAIN_KEYS:
            buckets[d].sort(key=lambda a: a.avg_explained_variance, reverse=True)
            buckets[d] = buckets[d][: self.max_adapters_per_domain]

        # Round-robin assemble: take the rank-th best adapter of every domain
        # that still has one, for increasing rank.
        balanced: List[ExpertAdapter] = []
        longest = max((len(bucket) for bucket in buckets.values()), default=0)
        for rank in range(longest):
            for d in DOMAIN_KEYS:
                if rank < len(buckets[d]):
                    balanced.append(buckets[d][rank])

        # Append untagged last (low priority)
        balanced.extend(untagged)
        self.merged = balanced
        return balanced

    def save(self, output_path: str) -> None:
        """Saves the merged library to ``output_path`` with ``torch.save``.

        The state layout (``config`` / ``adapters`` / ``extraction_stats``)
        is meant to match what ParasiticQLoRAExtractor uses; that format is
        defined outside this class, so confirm it there. The rank fields and
        timing stats are written as zero placeholders.

        Args:
            output_path: Destination file for the serialized state dict.
        """
        state = {
            "config": {
                "min_rank": 0,
                "max_rank": 0,
                "explained_variance_threshold": self.min_explained_variance,
            },
            "adapters": [a.to_state_dict() for a in self.merged],
            "extraction_stats": {
                "total_extractions": len(self.merged),
                "total_extraction_time_s": 0.0,
                "overhead_pct": 0.0,
            },
        }
        torch.save(state, output_path)
        print(
            f"[Aggregator] Saved merged library: {output_path} ({len(self.merged)} adapters)"
        )

    def coverage_report(self) -> Dict[str, object]:
        """Returns a human-readable coverage report.

        Returns:
            Dict with ``total_adapters``, ``domain_coverage`` (per-domain
            fraction from ``_library_coverage``), ``balance_score`` (smallest
            coverage divided by largest coverage) and
            ``avg_explained_variance`` (mean over the merged adapters, 0 when
            there are none).
        """
        coverage = _library_coverage(self.merged)
        return {
            "total_adapters": len(self.merged),
            "domain_coverage": coverage,
            # The 1e-6 floor avoids a division by zero when no domain is
            # covered at all.
            "balance_score": min(coverage.values()) / max(max(coverage.values()), 1e-6),
            "avg_explained_variance": (
                sum(a.avg_explained_variance for a in self.merged)
                / max(1, len(self.merged))
            ),
        }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# CLI entry point for standalone aggregation runs
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def aggregate_from_paths(
    library_paths: List[str],
    output_path: str,
    max_per_domain: int = 20,
    min_ev: float = 0.90,
) -> Dict[str, object]:
    """Run the whole pipeline: load, merge, rebalance, save and report.

    Args:
        library_paths: Adapter library files to merge.
        output_path: Where the merged library is written.
        max_per_domain: Per-domain cap passed to the aggregator.
        min_ev: Minimum explained variance passed to the aggregator.

    Returns:
        The aggregator's coverage report, extended with the merge
        statistics returned by ``load_and_merge``.
    """
    aggregator = AdapterLibraryAggregator(
        min_explained_variance=min_ev,
        max_adapters_per_domain=max_per_domain,
    )
    load_summary = aggregator.load_and_merge(library_paths)
    aggregator.rebalance()
    aggregator.save(output_path)

    # The report is generated after rebalancing, so it describes the saved
    # library; the merge statistics are then folded into the same dict.
    summary = aggregator.coverage_report()
    for key, value in load_summary.items():
        summary[key] = value
    return summary
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: parse arguments, aggregate, print a report."""
    import argparse

    cli = argparse.ArgumentParser(description="Merge adapter libraries from multiple workers")
    cli.add_argument("libraries", nargs="+", help="Paths to .pt adapter library files")
    cli.add_argument("--output", required=True, help="Output path for merged library")
    cli.add_argument("--max-per-domain", type=int, default=20)
    cli.add_argument("--min-ev", type=float, default=0.90)
    opts = cli.parse_args()

    library_count = len(opts.libraries)
    print(f"Aggregating {library_count} libraries...")

    # Time only the aggregation itself, not argument parsing or printing.
    started = time.perf_counter()
    report = aggregate_from_paths(
        opts.libraries,
        opts.output,
        opts.max_per_domain,
        opts.min_ev,
    )
    elapsed = time.perf_counter() - started

    print(f"\nMerged Library Report ({elapsed:.1f}s):")
    print(json.dumps(report, indent=2))
    print(f"Saved to: {opts.output}")


# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()
|