feat: expert manifold alignment, MoE router, FCES controller metadata bindings
This commit is contained in:
@@ -15,6 +15,7 @@ import torch.nn.functional as F # noqa: E402
|
||||
from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402
|
||||
from parasitic_qlora import ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402
|
||||
from expert_manifold_alignment import ExpertManifoldAligner # noqa: E402
|
||||
|
||||
# ==============================================================================
|
||||
# 1. DSPY SIGNATURE & SYSTEM DESIGN
|
||||
@@ -175,6 +176,9 @@ def train_run(
|
||||
)
|
||||
extractor.snapshot_base(model)
|
||||
|
||||
# Initialize Expert Manifold Aligner
|
||||
aligner = ExpertManifoldAligner(model)
|
||||
|
||||
# 1. Pre-Training Evaluation
|
||||
print(f"[{optimizer_name}] Running Pre-Training Evaluation...")
|
||||
pre_eval = evaluate_model(model, tokenizer, device)
|
||||
@@ -224,15 +228,30 @@ def train_run(
|
||||
if optimizer_name == "FCES":
|
||||
optimizer.update_fitness(float(loss.item()))
|
||||
|
||||
# Track per-step weight delta for manifold alignment
|
||||
aligner.track_step(model)
|
||||
|
||||
# Call parasitic extractor
|
||||
if extractor.should_extract(step, float(loss.item())):
|
||||
metrics = {
|
||||
metrics: Dict[str, Any] = {
|
||||
"loss": float(loss.item()),
|
||||
"sft_loss": float(sft_loss.item()),
|
||||
"optimizer": optimizer_name,
|
||||
"spectral_rank": getattr(optimizer, "last_spectral_rank_", 0.0),
|
||||
}
|
||||
extractor.extract_adapters(model, step, metrics)
|
||||
if optimizer_name == "FCES":
|
||||
metrics["fces_fitness"] = optimizer.get_active_controller_fitness()
|
||||
metrics["fces_controller_id"] = optimizer.get_active_controller_id()
|
||||
adapter = extractor.extract_adapters(model, step, metrics)
|
||||
aligner.tag_adapter(adapter)
|
||||
profile = aligner.profile_adapter(adapter)
|
||||
print(
|
||||
f"[{optimizer_name}] Adapter '{adapter.adapter_id}' | "
|
||||
f"tags={adapter.domain_tags} | "
|
||||
f"statute={profile['statute_recall']:.2f} "
|
||||
f"logic={profile['logic_reasoning']:.2f} "
|
||||
f"style={profile['style_gutachtenstil']:.2f}"
|
||||
)
|
||||
|
||||
# Tracking metrics
|
||||
elapsed = time.perf_counter() - start_time
|
||||
|
||||
Reference in New Issue
Block a user