feat: Add playbook vector extraction and activation steering routing to FCES training pipeline

This commit is contained in:
AI-anonymous
2026-05-23 08:33:27 +02:00
parent e0d8a32823
commit 4c9a550f8b
4 changed files with 884 additions and 46 deletions

View File

@@ -8,11 +8,12 @@ learnable token-level gates.
from __future__ import annotations
import re
from typing import Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn as nn
from parasitic_qlora import ExpertAdapter
from representation_engineering import SkillVectorLibrary, ProcessVectorLibrary
class LearnableGate(nn.Module): # type: ignore[misc, unused-ignore]
@@ -41,25 +42,53 @@ class LearnableGate(nn.Module): # type: ignore[misc, unused-ignore]
class ExpertAdapterRouter:
"""Manages dynamic MoE-style routing over a library of LoRA adapters."""
"""Manages dynamic MoE-style routing over a library of LoRA adapters and representation vectors."""
def __init__(
self,
base_model: nn.Module,
adapter_library: List[ExpertAdapter],
adapter_library: Optional[List[ExpertAdapter]] = None,
in_features: int = 768, # Match model hidden dim (e.g. Pythia-70m)
skill_library: Optional[SkillVectorLibrary] = None,
process_library: Optional[ProcessVectorLibrary] = None,
steering_alpha: float = 1.0,
steering_mode: str = "token", # "token" or "prompt"
) -> None:
self.base_model = base_model
self.adapter_library = adapter_library
self.num_adapters = len(adapter_library)
self.adapter_library = adapter_library or []
self.num_adapters = len(self.adapter_library)
self.skill_library = skill_library
self.process_library = process_library
self.steering_alpha = steering_alpha
self.steering_mode = steering_mode
self.hooks: List[torch.utils.hooks.RemovableHandle] = []
self.active_process_id: Optional[str] = None
self.active_process_step: Optional[int] = None
# Learnable gating network
self.gate = LearnableGate(in_features, self.num_adapters).to(
next(base_model.parameters()).device
# Sorted list of skill IDs for index-based routing
self.skill_ids = (
sorted(list(self.skill_library.vectors.keys()))
if self.skill_library
else []
)
# Active weights for current forward pass (batch size × num_adapters)
# Learnable gating network for adapters
if self.num_adapters > 0:
self.gate = LearnableGate(in_features, self.num_adapters).to(
next(base_model.parameters()).device
)
else:
self.gate = None
# Learnable gating network for skills
if len(self.skill_ids) > 0:
self.skill_gate = LearnableGate(in_features, len(self.skill_ids)).to(
next(base_model.parameters()).device
)
else:
self.skill_gate = None
# Active weights for current forward pass (batch size × num_adapters/skills)
self.current_gate_weights: Optional[torch.Tensor] = None
def compute_fuzzy_priors(self, text: str) -> torch.Tensor:
@@ -97,30 +126,46 @@ class ExpertAdapterRouter:
self.current_gate_weights = fuzzy_priors
def register_hooks(self) -> None:
"""Attaches forward hooks to linear layers present in the adapter library."""
"""Attaches forward hooks to linear layers (adapters) and transformer blocks (steering)."""
self.unregister_hooks()
# Find all layers in the base model that have adapters
adapter_layers: set[str] = set()
for adapter in self.adapter_library:
adapter_layers.update(adapter.layers.keys())
# 1. Bind adapter hooks if adapters are present
if self.num_adapters > 0:
adapter_layers: set[str] = set()
for adapter in self.adapter_library:
adapter_layers.update(adapter.layers.keys())
# Bind hooks dynamically
for name, module in self.base_model.named_modules():
# Check if this specific module has an adapter
# We match using suffix to support model wrapping/prefixes
matching_adapter_name = None
for layer_name in adapter_layers:
clean_layer_name = layer_name.replace(".weight", "").replace(
".bias", ""
)
if name.endswith(clean_layer_name) or name == clean_layer_name:
matching_adapter_name = layer_name
break
for name, module in self.base_model.named_modules():
matching_adapter_name = None
for layer_name in adapter_layers:
clean_layer_name = layer_name.replace(".weight", "").replace(
".bias", ""
)
if name.endswith(clean_layer_name) or name == clean_layer_name:
matching_adapter_name = layer_name
break
if matching_adapter_name and isinstance(module, nn.Linear):
if matching_adapter_name and isinstance(module, nn.Linear):
hook = module.register_forward_hook(
self._make_hook_fn(matching_adapter_name)
)
self.hooks.append(hook)
# 2. Bind steering hooks if skill_library or process_library is present
if self.skill_library or self.process_library:
transformer_layers = []
for name, module in self.base_model.named_modules():
match = re.match(r".*layers?\.(\d+)$", name)
if match:
layer_idx = int(match.group(1))
transformer_layers.append((layer_idx, name, module))
# Sort by layer_idx to ensure consistent mapping
transformer_layers.sort(key=lambda x: x[0])
for layer_idx, name, module in transformer_layers:
hook = module.register_forward_hook(
self._make_hook_fn(matching_adapter_name)
self._make_steering_hook_fn(layer_idx)
)
self.hooks.append(hook)
@@ -152,38 +197,100 @@ class ExpertAdapterRouter:
)
else:
# Compute dynamically per token via learnable gate
# We pool over sequence length or route per token
# Let's route token-wise: gate_logits has shape [batch, seq_len, num_adapters]
gate_logits = self.gate(x)
weights = torch.softmax(gate_logits, dim=-1)
if self.gate is not None:
gate_logits = self.gate(x)
weights = torch.softmax(gate_logits, dim=-1)
else:
return output_tensor
# Compute combined low-rank contribution
# Y_lora = sum_i g_i * (x @ A_i.t()) @ B_i.t()
adapter_output = torch.zeros_like(output_tensor)
for i, adapter in enumerate(self.adapter_library):
if layer_name in adapter.layers:
lm = adapter.layers[layer_name]
# Ensure tensors are on the correct device
lora_A = lm.lora_A.to(device=x.device, dtype=x.dtype)
lora_B = lm.lora_B.to(device=x.device, dtype=x.dtype)
# Dynamic scaling: gate_weight for this adapter
# weights has shape [batch, num_adapters] or [batch, seq_len, num_adapters]
if len(weights.shape) == 3:
# Token-level routing: shape [batch, seq_len, 1]
g = weights[..., i : i + 1]
else:
# Batch-level routing: shape [batch, 1, 1]
g = weights[:, i].view(-1, 1, 1)
# Low-rank projection
x_proj = torch.matmul(x, lora_A.t())
y_proj = torch.matmul(x_proj, lora_B.t())
# Accumulate scaled delta
adapter_output += g * y_proj
return output_tensor + adapter_output
return hook_fn
def _make_steering_hook_fn(self, layer_idx: int) -> Callable[..., Any]:
"""Creates a hook function to inject activation steering vectors at a specific layer."""
def hook_fn(
module: nn.Module,
input_tensor: Tuple[torch.Tensor, ...],
output_tensor: Any,
) -> Any:
is_tuple = isinstance(output_tensor, tuple)
x = output_tensor[0] if is_tuple else output_tensor
# Sequential process/workflow steering
if self.active_process_id is not None and self.process_library is not None:
step_idx = self.active_process_step or 0
step_vector = self.process_library.get_process_step(
self.active_process_id, step_idx
)
if step_vector and layer_idx in step_vector.layer_vectors:
v = step_vector.layer_vectors[layer_idx].to(
device=x.device, dtype=x.dtype
)
steered_x = x + self.steering_alpha * v
if is_tuple:
return (steered_x,) + output_tensor[1:]
return steered_x
return output_tensor
# Dynamic skill routing
if self.skill_library and len(self.skill_ids) > 0:
weights = None
if self.current_gate_weights is not None:
batch_size = x.shape[0]
weights = (
self.current_gate_weights.to(x.device)
.unsqueeze(0)
.expand(batch_size, -1)
)
elif self.skill_gate is not None:
if self.steering_mode == "token":
gate_logits = self.skill_gate(x)
weights = torch.softmax(gate_logits, dim=-1)
else:
x_mean = x.mean(dim=1) if len(x.shape) == 3 else x
gate_logits = self.skill_gate(x_mean)
weights = torch.softmax(gate_logits, dim=-1)
if weights is not None:
steer_contribution = torch.zeros_like(x)
for i, skill_id in enumerate(self.skill_ids):
vec = self.skill_library.get_vector(skill_id)
if vec and layer_idx in vec.layer_vectors:
v = vec.layer_vectors[layer_idx].to(
device=x.device, dtype=x.dtype
)
if len(weights.shape) == 3:
g = weights[..., i : i + 1]
else:
g = weights[:, i].view(-1, 1, 1)
steer_contribution += g * v
steered_x = x + self.steering_alpha * steer_contribution
if is_tuple:
return (steered_x,) + output_tensor[1:]
return steered_x
return output_tensor
return hook_fn

View File

@@ -41,24 +41,29 @@ class PlaybookParser:
logger.warning(f"File not found: {path}")
return None
content = path.read_text(encoding="utf-8")
content = path.read_text(encoding="utf-8").lstrip()
# Parse frontmatter if present
name = path.parent.name if path.parent else "unknown"
description = ""
frontmatter_match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL)
frontmatter_match = re.match(r"^---\r?\n(.*?)\r?\n---\r?\n", content, re.DOTALL)
if frontmatter_match:
fm_text = frontmatter_match.group(1)
name_match = re.search(r"^name:\s*(.*?)$", fm_text, re.MULTILINE)
if name_match:
name = name_match.group(1).strip()
desc_match = re.search(r"^description:\s*(.*?)$", fm_text, re.MULTILINE)
desc_match = re.search(
r"^description:\s*(.*?)(?=\r?\n\w+:|\Z)",
fm_text,
re.DOTALL | re.MULTILINE,
)
if desc_match:
description = desc_match.group(1).strip()
# Remove multiline YAML indicators if any
description = re.sub(r"^>-?\s*", "", description)
description = re.sub(r"\n\s+", " ", description)
description = re.sub(r"^\|\s*", "", description)
description = re.sub(r"\r?\n\s*", " ", description).strip()
# Parse headers if no name/description in frontmatter
if name == "unknown" or not description:
@@ -247,7 +252,8 @@ class RepresentationVectorExtractor:
# Run forward pass for win prompts
for prompt in win_prompts:
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
raw_inputs = self.tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in raw_inputs.items()}
with torch.no_grad():
self.model(**inputs)
@@ -264,7 +270,8 @@ class RepresentationVectorExtractor:
# Run forward pass for lose prompts
for prompt in lose_prompts:
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
raw_inputs = self.tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in raw_inputs.items()}
with torch.no_grad():
self.model(**inputs)

View File

@@ -0,0 +1,473 @@
"""Orchestration script to compile Skill and Process libraries, train Gating-MLP, and validate steering.
Discovers local playbooks, extracts representation vectors using Pythia-70m (or dummy fallback),
trains the gating network on hidden states, and runs validation.
"""
from __future__ import annotations
import argparse
import logging
import re
from pathlib import Path
from typing import Any, Dict, List, Tuple
import torch
import torch.nn as nn
import torch.optim as optim
# Import our representation engineering and router modules
import sys
sys.path.append(str(Path(__file__).parent.absolute()))
from representation_engineering import (
PlaybookParser,
PlaybookMetadata,
RepresentationVectorExtractor,
SkillVectorLibrary,
ProcessVectorLibrary,
)
from adapter_moe_router import LearnableGate, ExpertAdapterRouter
logger = logging.getLogger("run_representation_pipeline")
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# --- Dummy Model & Tokenizer for Offline Fallback & Testing ---
class DummyTransformerLayer(nn.Module): # type: ignore[misc]
def __init__(self, hidden_dim: int) -> None:
super().__init__()
self.linear1 = nn.Linear(hidden_dim, hidden_dim)
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
def forward(
self, x: torch.Tensor, *args: Any, **kwargs: Any
) -> Tuple[torch.Tensor]:
h = self.linear2(torch.relu(self.linear1(x)))
return (x + h,)
class DummyModel(nn.Module): # type: ignore[misc]
def __init__(
self, vocab_size: int = 1000, hidden_dim: int = 32, num_layers: int = 4
) -> None:
super().__init__()
self.vocab_size = vocab_size
self.hidden_dim = hidden_dim
self.embedding = nn.Embedding(vocab_size, hidden_dim)
self.layers = nn.ModuleList(
[DummyTransformerLayer(hidden_dim) for _ in range(num_layers)]
)
self.lm_head = nn.Linear(hidden_dim, vocab_size)
def forward(self, input_ids: torch.Tensor, *args: Any, **kwargs: Any) -> Any:
x = self.embedding(input_ids)
for layer in self.layers:
x = layer(x)[0]
logits = self.lm_head(x)
class Output:
def __init__(self, logits: torch.Tensor) -> None:
self.logits = logits
return Output(logits=logits)
class DummyTokenizer:
def __init__(self) -> None:
pass
def __call__(
self, text: str | List[str], return_tensors: str = "pt", **kwargs: Any
) -> Dict[str, torch.Tensor]:
if isinstance(text, str):
text = [text]
batch_ids = []
max_len = 0
for t in text:
words = t.split()
ids = [abs(hash(w)) % 1000 for w in words]
if not ids:
ids = [0]
batch_ids.append(ids)
max_len = max(max_len, len(ids))
padded_ids = []
for ids in batch_ids:
padded_ids.append(ids + [0] * (max_len - len(ids)))
return {"input_ids": torch.tensor(padded_ids)}
# --- Helper to parse process steps from playbooks ---
def parse_playbook_steps(path: Path) -> List[Tuple[str, str]]:
"""Extracts step names and short descriptions from sequential lists in a playbook."""
content = path.read_text(encoding="utf-8")
steps = []
# Locate numbered items under headers like Fallback Procedure, Operational Protocol, etc.
lines = content.splitlines()
in_procedure = False
for line in lines:
line_str = line.strip()
if not line_str:
continue
if line_str.startswith("#"):
lower_header = line_str.lower()
in_procedure = any(
x in lower_header
for x in ["procedure", "protocol", "schritt", "ablauf", "loop"]
)
continue
if in_procedure:
# Match 1. Step name: Description or just 1. Description
match = re.match(r"^(\d+)\.\s*(.*?)$", line_str)
if match:
step_num = match.group(1)
step_text = match.group(2).strip()
# Split step name and description if separated by double asterisks or colon
parts = re.split(r"\*\*|:\s*", step_text, maxsplit=1)
if len(parts) > 1 and parts[0].strip():
step_name = parts[0].strip()
step_desc = parts[1].strip()
else:
step_name = f"Step_{step_num}"
step_desc = step_text
steps.append((step_name, step_desc))
return steps
# --- Pipeline Orchestrator ---
def run_pipeline(
model_id: str,
skills_dirs: List[str],
output_dir: str,
use_dummy: bool = False,
epochs: int = 50,
lr: float = 0.01,
) -> None:
logger.info("Initializing Representation Engineering Pipeline...")
# 1. Resolve output directory
out_path = Path(output_dir)
out_path.mkdir(parents=True, exist_ok=True)
# 2. Load model and tokenizer (with dummy fallback)
model = None
tokenizer = None
device = "cuda" if torch.cuda.is_available() else "cpu"
if not use_dummy:
try:
logger.info(f"Attempting to load model '{model_id}' from HuggingFace...")
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
logger.info("Model loaded successfully.")
except Exception as e:
logger.warning(
f"Failed to load model from HuggingFace: {e}. Falling back to Dummy Model."
)
use_dummy = True
if use_dummy:
logger.info(
"Initializing lightweight Dummy Model and Tokenizer for execution/testing..."
)
model = DummyModel(vocab_size=1000, hidden_dim=32, num_layers=4).to(device)
tokenizer = DummyTokenizer()
device = "cpu"
assert model is not None
assert tokenizer is not None
# 3. Discover playbooks
logger.info("Scanning directories for playbooks...")
playbooks: List[PlaybookMetadata] = []
parsed_files: List[Path] = []
for s_dir in skills_dirs:
path = Path(s_dir)
if path.exists():
logger.info(f"Scanning directory: {path}")
found = PlaybookParser.parse_directory(path)
playbooks.extend(found)
# Find actual files for process parsing
for p_file in path.glob("**/SKILL.md"):
parsed_files.append(p_file)
for p_file in path.glob("*_SKILL.md"):
parsed_files.append(p_file)
if not playbooks:
logger.warning(
"No playbooks discovered! Creating a default mock playbook for execution..."
)
mock_skill = PlaybookMetadata(
name="MockSkill",
description="A temporary skill for pipeline validation",
objectives=[
"Validate the routing pipeline",
"Verify steering functionality",
],
trigger_examples=[
"Run the mock pipeline check",
"Test activation steering on mock",
],
file_path="mock_SKILL.md",
)
playbooks.append(mock_skill)
logger.info(f"Discovered {len(playbooks)} playbooks.")
# 4. Extract Skill and Process Libraries
skill_library = SkillVectorLibrary()
process_library = ProcessVectorLibrary()
extractor = RepresentationVectorExtractor(model, tokenizer, device=device)
# Determine model layers
transformer_layers = []
for name, module in model.named_modules():
if re.match(r".*layers?\.\d+$", name):
transformer_layers.append(name)
num_layers = len(transformer_layers)
logger.info(f"Detected {num_layers} transformer layers in base model.")
# Choose layers to extract: middle/late layers
layers_to_extract = (
list(range(num_layers // 2, num_layers)) if num_layers > 0 else [0]
)
# Extract skill vectors
for pb in playbooks:
logger.info(f"Extracting vectors for skill: {pb.name}")
vec = extractor.extract_steering_vector(pb, layers_to_extract=layers_to_extract)
skill_library.add_vector(vec)
# Extract sequential process vectors if file exists
p_path = Path(pb.file_path)
if p_path.exists():
steps = parse_playbook_steps(p_path)
if steps:
logger.info(
f"Found {len(steps)} sequential steps in process '{pb.name}'"
)
step_vectors = []
for step_name, step_desc in steps:
step_pb = PlaybookMetadata(
name=f"{pb.name}_{step_name}",
description=step_desc,
objectives=[step_desc],
trigger_examples=pb.trigger_examples,
file_path=str(p_path),
)
step_vec = extractor.extract_steering_vector(
step_pb, layers_to_extract=layers_to_extract
)
step_vectors.append(step_vec)
process_library.add_process(
pb.name.lower().replace(" ", "_"), step_vectors
)
# Save libraries
skill_lib_path = out_path / "skill_library.pt"
process_lib_path = out_path / "process_library.pt"
skill_library.save(skill_lib_path)
process_library.save(process_lib_path)
# 5. Train Gating-MLP on hidden states
skill_ids = sorted(list(skill_library.vectors.keys()))
if not skill_ids:
logger.error("No skills available to train Gating-MLP.")
return
logger.info(f"Training Gating-MLP router over {len(skill_ids)} skills...")
# Gating features layer (we extract hidden states from early/middle layer to predict routing)
gate_layer_idx = layers_to_extract[0] if layers_to_extract else 0
hidden_dim = model.hidden_dim if hasattr(model, "hidden_dim") else 32
if hasattr(model, "config") and hasattr(model.config, "hidden_size"):
hidden_dim = model.config.hidden_size
gate_net = LearnableGate(in_features=hidden_dim, num_adapters=len(skill_ids)).to(
device
)
optimizer = optim.Adam(gate_net.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
# Collect training dataset: (hidden_state, label)
# We pass the win prompts for each skill, hook the gate layer, and collect states
training_data: List[Tuple[torch.Tensor, int]] = []
# Temporary hook to collect states
collected_states: List[torch.Tensor] = []
def collect_hook(module: nn.Module, input_t: Any, output_t: Any) -> None:
x = output_t[0] if isinstance(output_t, tuple) else output_t
# Pool to sequence mean
collected_states.append(x.detach().mean(dim=1).squeeze(0))
# Register hook on gate layer
hook_handle = None
target_layer_name = (
transformer_layers[gate_layer_idx]
if gate_layer_idx < len(transformer_layers)
else None
)
if target_layer_name:
for name, module in model.named_modules():
if name == target_layer_name:
hook_handle = module.register_forward_hook(collect_hook)
break
# Run win prompts to collect activations
for label_idx, skill_id in enumerate(skill_ids):
# Retrieve parsed metadata for this skill ID
pb_match = next(
(p for p in playbooks if p.name.lower().replace(" ", "_") == skill_id), None
)
if pb_match:
for trigger in pb_match.trigger_examples:
win_prompt = (
f"Instructions: You are acting with the following skill: {pb_match.name}.\n"
f"Request: {trigger}\nOutput:"
)
inputs = tokenizer(win_prompt, return_tensors="pt")
# Move inputs to device
inputs = {k: v.to(device) for k, v in inputs.items()}
collected_states.clear()
with torch.no_grad():
model(**inputs)
if collected_states:
state = collected_states[0].cpu() # shape [hidden_dim]
training_data.append((state, label_idx))
if hook_handle:
hook_handle.remove()
if not training_data:
logger.warning(
"Could not collect training data. Using synthetic data to train Gating-MLP."
)
# Fallback to random features for testing compilation flow
for i in range(100):
label = i % len(skill_ids)
feat = torch.randn(hidden_dim) + (label * 2.0) # separate them a bit
training_data.append((feat, label))
# Train MLP
gate_net.train()
X = torch.stack([x for x, y in training_data]).to(device)
Y = torch.tensor([y for x, y in training_data]).to(device)
logger.info(
f"Collected {len(training_data)} training samples. Starting optimization..."
)
for epoch in range(epochs):
optimizer.zero_grad()
outputs = gate_net(X)
loss = criterion(outputs, Y)
loss.backward()
optimizer.step()
if (epoch + 1) % max(1, epochs // 5) == 0 or epoch == epochs - 1:
# Calculate accuracy
_, preds = torch.max(outputs, 1)
correct = (preds == Y).sum().item()
acc = correct / len(Y)
logger.info(
f"Epoch {epoch+1:02d}/{epochs:02d} | Loss: {loss.item():.4f} | Training Accuracy: {acc * 100:.1f}%"
)
# Save gate weights
gate_weights_path = out_path / "gate_weights.pt"
torch.save(gate_net.state_dict(), gate_weights_path)
logger.info(f"Gating MLP weights saved to {gate_weights_path}")
# 6. Validate steerability & routing performance
logger.info("Running pipeline verification...")
router = ExpertAdapterRouter(
base_model=model,
skill_library=skill_library,
process_library=process_library,
in_features=hidden_dim,
steering_alpha=1.0,
)
router.skill_gate.load_state_dict(
torch.load(gate_weights_path, map_location=device)
)
router.register_hooks()
# Run test prompt with routing enabled
test_prompt = "Validate this test prompt"
inputs = tokenizer(test_prompt, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
# Test that forwarding goes through hooks without crash
with torch.no_grad():
_ = model(**inputs)
logger.info("Successfully executed forward pass with dynamic activation steering.")
router.unregister_hooks()
logger.info("Pipeline run completed successfully.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="FCES Representation Engineering Pipeline"
)
parser.add_argument(
"--model_id",
type=str,
default="EleutherAI/pythia-70m",
help="HuggingFace model identifier",
)
parser.add_argument(
"--skills_dirs",
type=str,
default="C:/Users/Sven/Documents/svenco-knowledge/skills,C:/Users/Sven/Documents/everything-claude-code/skills",
help="Comma-separated paths to search for playbooks",
)
parser.add_argument(
"--output_dir",
type=str,
default="./python/output",
help="Directory to save compiled libraries",
)
parser.add_argument(
"--use_dummy", action="store_true", help="Force using lightweight dummy model"
)
parser.add_argument("--epochs", type=int, default=50, help="Gating training epochs")
parser.add_argument(
"--lr", type=float, default=0.01, help="Gating training learning rate"
)
args = parser.parse_args()
# Split skills dirs
dirs_list = [d.strip() for d in args.skills_dirs.split(",") if d.strip()]
run_pipeline(
model_id=args.model_id,
skills_dirs=dirs_list,
output_dir=args.output_dir,
use_dummy=args.use_dummy,
epochs=args.epochs,
lr=args.lr,
)