feat: complete phase 5 - Klausuren validation with MoE routing, dynamic dimensional matching, and full mypy type safety
This commit is contained in:
255
python/telemetry_aggregator.py
Normal file
255
python/telemetry_aggregator.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Multi-Worker Adapter Library Aggregator.
|
||||
|
||||
Merges ExpertAdapter libraries produced by multiple independent workers
|
||||
(Kaggle/RunPod) into a single, deduplicated, domain-balanced corpus.
|
||||
|
||||
Design:
|
||||
- Workers write adapter libraries to a shared path (NFS, S3, or a .pt file).
|
||||
- Aggregator loads all libraries, deduplicates by adapter_id, re-ranks by
|
||||
domain coverage, and emits a consolidated merged_library.pt.
|
||||
- Telemetry is pushed to MariaDB for cluster-wide monitoring.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, os.path.join(_root, "python"))
|
||||
|
||||
from parasitic_qlora import ExpertAdapter, ParasiticQLoRAExtractor # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Domain balance scoring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DOMAIN_KEYS: List[str] = ["statute_recall", "logic_reasoning", "style_gutachtenstil"]
|
||||
|
||||
|
||||
def _adapter_domain_score(adapter: ExpertAdapter) -> Dict[str, float]:
|
||||
"""Returns normalised domain score dict for an adapter."""
|
||||
tags = set(adapter.domain_tags)
|
||||
total = max(1, len(tags))
|
||||
return {d: (1.0 / total if d in tags else 0.0) for d in DOMAIN_KEYS}
|
||||
|
||||
|
||||
def _library_coverage(library: List[ExpertAdapter]) -> Dict[str, float]:
|
||||
"""Fraction of adapters covering each domain (0..1)."""
|
||||
if not library:
|
||||
return {d: 0.0 for d in DOMAIN_KEYS}
|
||||
coverage: Dict[str, float] = {d: 0.0 for d in DOMAIN_KEYS}
|
||||
for adapter in library:
|
||||
for d in DOMAIN_KEYS:
|
||||
if d in adapter.domain_tags:
|
||||
coverage[d] += 1.0
|
||||
return {d: v / len(library) for d, v in coverage.items()}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core aggregation logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AdapterLibraryAggregator:
|
||||
"""Merges adapter libraries from multiple workers into a balanced corpus.
|
||||
|
||||
Args:
|
||||
max_adapters_per_domain: Hard cap on adapters per domain in merged lib.
|
||||
min_explained_variance: Discard adapters below this quality threshold.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_adapters_per_domain: int = 20,
|
||||
min_explained_variance: float = 0.90,
|
||||
) -> None:
|
||||
self.max_adapters_per_domain = max_adapters_per_domain
|
||||
self.min_explained_variance = min_explained_variance
|
||||
self._seen_ids: set[str] = set()
|
||||
self.merged: List[ExpertAdapter] = []
|
||||
|
||||
def add_library(self, library: List[ExpertAdapter]) -> Tuple[int, int]:
|
||||
"""Merges a worker library into the aggregator.
|
||||
|
||||
Returns:
|
||||
(accepted, rejected) counts.
|
||||
"""
|
||||
accepted = 0
|
||||
rejected = 0
|
||||
for adapter in library:
|
||||
if adapter.adapter_id in self._seen_ids:
|
||||
rejected += 1
|
||||
continue
|
||||
if adapter.avg_explained_variance < self.min_explained_variance:
|
||||
rejected += 1
|
||||
continue
|
||||
self._seen_ids.add(adapter.adapter_id)
|
||||
self.merged.append(adapter)
|
||||
accepted += 1
|
||||
return accepted, rejected
|
||||
|
||||
def load_and_merge(self, library_paths: List[str]) -> Dict[str, int]:
|
||||
"""Loads all library files and merges them.
|
||||
|
||||
Returns:
|
||||
Summary dict with per-file stats.
|
||||
"""
|
||||
stats: Dict[str, int] = {"total_accepted": 0, "total_rejected": 0, "files": 0}
|
||||
for path in library_paths:
|
||||
if not os.path.exists(path):
|
||||
print(f" [WARN] Library not found: {path}")
|
||||
continue
|
||||
lib = ParasiticQLoRAExtractor.load_library(path)
|
||||
accepted, rejected = self.add_library(lib)
|
||||
stats["total_accepted"] += accepted
|
||||
stats["total_rejected"] += rejected
|
||||
stats["files"] += 1
|
||||
print(
|
||||
f" Loaded {path}: {len(lib)} adapters "
|
||||
f"(+{accepted} accepted, {rejected} rejected)"
|
||||
)
|
||||
return stats
|
||||
|
||||
def rebalance(self) -> List[ExpertAdapter]:
|
||||
"""Selects a domain-balanced subset honouring max_adapters_per_domain.
|
||||
|
||||
Strategy:
|
||||
1. Sort adapters by avg_explained_variance DESC within each domain.
|
||||
2. Round-robin pick from domains until cap is hit.
|
||||
|
||||
Returns:
|
||||
Balanced merged adapter list (also updates self.merged).
|
||||
"""
|
||||
buckets: Dict[str, List[ExpertAdapter]] = {d: [] for d in DOMAIN_KEYS}
|
||||
untagged: List[ExpertAdapter] = []
|
||||
|
||||
for adapter in self.merged:
|
||||
placed = False
|
||||
for d in DOMAIN_KEYS:
|
||||
if d in adapter.domain_tags:
|
||||
buckets[d].append(adapter)
|
||||
placed = True
|
||||
break
|
||||
if not placed:
|
||||
untagged.append(adapter)
|
||||
|
||||
# Sort each bucket by quality
|
||||
for d in DOMAIN_KEYS:
|
||||
buckets[d].sort(key=lambda a: a.avg_explained_variance, reverse=True)
|
||||
buckets[d] = buckets[d][: self.max_adapters_per_domain]
|
||||
|
||||
# Round-robin assemble
|
||||
balanced: List[ExpertAdapter] = []
|
||||
domain_iters = [iter(buckets[d]) for d in DOMAIN_KEYS]
|
||||
active = list(range(len(DOMAIN_KEYS)))
|
||||
while active:
|
||||
next_active = []
|
||||
for idx in active:
|
||||
try:
|
||||
balanced.append(next(domain_iters[idx]))
|
||||
next_active.append(idx)
|
||||
except StopIteration:
|
||||
pass
|
||||
active = next_active
|
||||
|
||||
# Append untagged last (low priority)
|
||||
balanced.extend(untagged)
|
||||
self.merged = balanced
|
||||
return balanced
|
||||
|
||||
def save(self, output_path: str) -> None:
|
||||
"""Saves the merged library using the same format as ParasiticQLoRAExtractor."""
|
||||
state = {
|
||||
"config": {
|
||||
"min_rank": 0,
|
||||
"max_rank": 0,
|
||||
"explained_variance_threshold": self.min_explained_variance,
|
||||
},
|
||||
"adapters": [a.to_state_dict() for a in self.merged],
|
||||
"extraction_stats": {
|
||||
"total_extractions": len(self.merged),
|
||||
"total_extraction_time_s": 0.0,
|
||||
"overhead_pct": 0.0,
|
||||
},
|
||||
}
|
||||
torch.save(state, output_path)
|
||||
print(
|
||||
f"[Aggregator] Saved merged library: {output_path} ({len(self.merged)} adapters)"
|
||||
)
|
||||
|
||||
def coverage_report(self) -> Dict[str, object]:
|
||||
"""Returns a human-readable coverage report."""
|
||||
coverage = _library_coverage(self.merged)
|
||||
return {
|
||||
"total_adapters": len(self.merged),
|
||||
"domain_coverage": coverage,
|
||||
"balance_score": min(coverage.values()) / max(max(coverage.values()), 1e-6),
|
||||
"avg_explained_variance": (
|
||||
sum(a.avg_explained_variance for a in self.merged)
|
||||
/ max(1, len(self.merged))
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI entry point for standalone aggregation runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def aggregate_from_paths(
|
||||
library_paths: List[str],
|
||||
output_path: str,
|
||||
max_per_domain: int = 20,
|
||||
min_ev: float = 0.90,
|
||||
) -> Dict[str, object]:
|
||||
"""Top-level helper: load, merge, rebalance, save, report."""
|
||||
agg = AdapterLibraryAggregator(
|
||||
max_adapters_per_domain=max_per_domain,
|
||||
min_explained_variance=min_ev,
|
||||
)
|
||||
merge_stats = agg.load_and_merge(library_paths)
|
||||
agg.rebalance()
|
||||
agg.save(output_path)
|
||||
report = agg.coverage_report()
|
||||
report.update(merge_stats)
|
||||
return report
|
||||
|
||||
|
||||
def main() -> None:
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Merge adapter libraries from multiple workers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"libraries", nargs="+", help="Paths to .pt adapter library files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", required=True, help="Output path for merged library"
|
||||
)
|
||||
parser.add_argument("--max-per-domain", type=int, default=20)
|
||||
parser.add_argument("--min-ev", type=float, default=0.90)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Aggregating {len(args.libraries)} libraries...")
|
||||
t0 = time.perf_counter()
|
||||
report = aggregate_from_paths(
|
||||
args.libraries, args.output, args.max_per_domain, args.min_ev
|
||||
)
|
||||
elapsed = time.perf_counter() - t0
|
||||
|
||||
print(f"\nMerged Library Report ({elapsed:.1f}s):")
|
||||
print(json.dumps(report, indent=2))
|
||||
print(f"Saved to: {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user