feat: expert manifold alignment, MoE router, FCES controller metadata bindings

This commit is contained in:
AI-anonymous
2026-05-20 16:07:36 +02:00
parent 7e2e86d98c
commit 663e2fb71d
9 changed files with 727 additions and 19 deletions

View File

@@ -15,6 +15,7 @@ import torch.nn.functional as F # noqa: E402
from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402
from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402
from parasitic_qlora import ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402
from expert_manifold_alignment import ExpertManifoldAligner # noqa: E402
# ==============================================================================
# 1. DSPY SIGNATURE & SYSTEM DESIGN
@@ -175,6 +176,9 @@ def train_run(
)
extractor.snapshot_base(model)
# Initialize Expert Manifold Aligner
aligner = ExpertManifoldAligner(model)
# 1. Pre-Training Evaluation
print(f"[{optimizer_name}] Running Pre-Training Evaluation...")
pre_eval = evaluate_model(model, tokenizer, device)
@@ -224,15 +228,30 @@ def train_run(
if optimizer_name == "FCES":
optimizer.update_fitness(float(loss.item()))
# Track per-step weight delta for manifold alignment
aligner.track_step(model)
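# When the parasitic extractor's should_extract(step, loss) check passes, extract an adapter, then tag and profile it with the aligner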
if extractor.should_extract(step, float(loss.item())):
metrics = {
metrics: Dict[str, Any] = {
"loss": float(loss.item()),
"sft_loss": float(sft_loss.item()),
"optimizer": optimizer_name,
"spectral_rank": getattr(optimizer, "last_spectral_rank_", 0.0),
}
extractor.extract_adapters(model, step, metrics)
if optimizer_name == "FCES":
metrics["fces_fitness"] = optimizer.get_active_controller_fitness()
metrics["fces_controller_id"] = optimizer.get_active_controller_id()
adapter = extractor.extract_adapters(model, step, metrics)
aligner.tag_adapter(adapter)
profile = aligner.profile_adapter(adapter)
print(
f"[{optimizer_name}] Adapter '{adapter.adapter_id}' | "
f"tags={adapter.domain_tags} | "
f"statute={profile['statute_recall']:.2f} "
f"logic={profile['logic_reasoning']:.2f} "
f"style={profile['style_gutachtenstil']:.2f}"
)
# Tracking metrics
elapsed = time.perf_counter() - start_time