Compare commits
3 Commits
306372bb5b
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4c9a550f8b | ||
|
|
e0d8a32823 | ||
|
|
c6ba37dc39 |
@@ -16,6 +16,11 @@ from send_telemetry import push_to_mariadb, push_to_surrealdb # noqa: E402
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402
|
||||
from parasitic_qlora import ParasiticQLoRAExtractor, QLoRAConfig # noqa: E402
|
||||
from expert_manifold_alignment import ExpertManifoldAligner # noqa: E402
|
||||
from representation_engineering import ( # noqa: E402
|
||||
PlaybookParser,
|
||||
RepresentationVectorExtractor,
|
||||
SkillVectorLibrary,
|
||||
)
|
||||
|
||||
# ==============================================================================
|
||||
# 1. DSPY SIGNATURE & SYSTEM DESIGN
|
||||
@@ -278,6 +283,27 @@ def train_run(
|
||||
library_path = f"parasitic_adapters_{optimizer_name.lower()}_step{steps}.pt"
|
||||
extractor.save_library(library_path)
|
||||
|
||||
# Extract and save Skill representation vectors
|
||||
print(f"[{optimizer_name}] Extracting Skill representation vectors...")
|
||||
try:
|
||||
skills_dir = "C:/Users/Sven/Documents/svenco-knowledge/skills"
|
||||
playbooks = PlaybookParser.parse_directory(skills_dir)
|
||||
print(f"[{optimizer_name}] Found {len(playbooks)} playbooks in {skills_dir}")
|
||||
|
||||
rep_extractor = RepresentationVectorExtractor(model, tokenizer, device)
|
||||
skill_lib = SkillVectorLibrary()
|
||||
|
||||
for pb in playbooks:
|
||||
print(f"[{optimizer_name}] Extracting steering vector for skill: {pb.name}")
|
||||
vec = rep_extractor.extract_steering_vector(pb)
|
||||
skill_lib.add_vector(vec)
|
||||
|
||||
skill_lib_path = f"skill_library_{optimizer_name.lower()}.pt"
|
||||
skill_lib.save(skill_lib_path)
|
||||
print(f"[{optimizer_name}] Saved SkillVectorLibrary to {skill_lib_path}")
|
||||
except Exception as e:
|
||||
print(f"[{optimizer_name}] Error extracting skill representation vectors: {e}")
|
||||
|
||||
# 4. Post-Training Evaluation
|
||||
print(f"[{optimizer_name}] Running Post-Training Evaluation...")
|
||||
post_eval = evaluate_model(model, tokenizer, device)
|
||||
|
||||
@@ -8,11 +8,12 @@ learnable token-level gates.
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from parasitic_qlora import ExpertAdapter
|
||||
from representation_engineering import SkillVectorLibrary, ProcessVectorLibrary
|
||||
|
||||
|
||||
class LearnableGate(nn.Module): # type: ignore[misc, unused-ignore]
|
||||
@@ -41,25 +42,53 @@ class LearnableGate(nn.Module): # type: ignore[misc, unused-ignore]
|
||||
|
||||
|
||||
class ExpertAdapterRouter:
|
||||
"""Manages dynamic MoE-style routing over a library of LoRA adapters."""
|
||||
"""Manages dynamic MoE-style routing over a library of LoRA adapters and representation vectors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_model: nn.Module,
|
||||
adapter_library: List[ExpertAdapter],
|
||||
adapter_library: Optional[List[ExpertAdapter]] = None,
|
||||
in_features: int = 768, # Match model hidden dim (e.g. Pythia-70m)
|
||||
skill_library: Optional[SkillVectorLibrary] = None,
|
||||
process_library: Optional[ProcessVectorLibrary] = None,
|
||||
steering_alpha: float = 1.0,
|
||||
steering_mode: str = "token", # "token" or "prompt"
|
||||
) -> None:
|
||||
self.base_model = base_model
|
||||
self.adapter_library = adapter_library
|
||||
self.num_adapters = len(adapter_library)
|
||||
self.adapter_library = adapter_library or []
|
||||
self.num_adapters = len(self.adapter_library)
|
||||
self.skill_library = skill_library
|
||||
self.process_library = process_library
|
||||
self.steering_alpha = steering_alpha
|
||||
self.steering_mode = steering_mode
|
||||
self.hooks: List[torch.utils.hooks.RemovableHandle] = []
|
||||
self.active_process_id: Optional[str] = None
|
||||
self.active_process_step: Optional[int] = None
|
||||
|
||||
# Learnable gating network
|
||||
# Sorted list of skill IDs for index-based routing
|
||||
self.skill_ids = (
|
||||
sorted(list(self.skill_library.vectors.keys()))
|
||||
if self.skill_library
|
||||
else []
|
||||
)
|
||||
|
||||
# Learnable gating network for adapters
|
||||
if self.num_adapters > 0:
|
||||
self.gate = LearnableGate(in_features, self.num_adapters).to(
|
||||
next(base_model.parameters()).device
|
||||
)
|
||||
else:
|
||||
self.gate = None
|
||||
|
||||
# Active weights for current forward pass (batch size × num_adapters)
|
||||
# Learnable gating network for skills
|
||||
if len(self.skill_ids) > 0:
|
||||
self.skill_gate = LearnableGate(in_features, len(self.skill_ids)).to(
|
||||
next(base_model.parameters()).device
|
||||
)
|
||||
else:
|
||||
self.skill_gate = None
|
||||
|
||||
# Active weights for current forward pass (batch size × num_adapters/skills)
|
||||
self.current_gate_weights: Optional[torch.Tensor] = None
|
||||
|
||||
def compute_fuzzy_priors(self, text: str) -> torch.Tensor:
|
||||
@@ -97,18 +126,16 @@ class ExpertAdapterRouter:
|
||||
self.current_gate_weights = fuzzy_priors
|
||||
|
||||
def register_hooks(self) -> None:
|
||||
"""Attaches forward hooks to linear layers present in the adapter library."""
|
||||
"""Attaches forward hooks to linear layers (adapters) and transformer blocks (steering)."""
|
||||
self.unregister_hooks()
|
||||
|
||||
# Find all layers in the base model that have adapters
|
||||
# 1. Bind adapter hooks if adapters are present
|
||||
if self.num_adapters > 0:
|
||||
adapter_layers: set[str] = set()
|
||||
for adapter in self.adapter_library:
|
||||
adapter_layers.update(adapter.layers.keys())
|
||||
|
||||
# Bind hooks dynamically
|
||||
for name, module in self.base_model.named_modules():
|
||||
# Check if this specific module has an adapter
|
||||
# We match using suffix to support model wrapping/prefixes
|
||||
matching_adapter_name = None
|
||||
for layer_name in adapter_layers:
|
||||
clean_layer_name = layer_name.replace(".weight", "").replace(
|
||||
@@ -124,6 +151,24 @@ class ExpertAdapterRouter:
|
||||
)
|
||||
self.hooks.append(hook)
|
||||
|
||||
# 2. Bind steering hooks if skill_library or process_library is present
|
||||
if self.skill_library or self.process_library:
|
||||
transformer_layers = []
|
||||
for name, module in self.base_model.named_modules():
|
||||
match = re.match(r".*layers?\.(\d+)$", name)
|
||||
if match:
|
||||
layer_idx = int(match.group(1))
|
||||
transformer_layers.append((layer_idx, name, module))
|
||||
|
||||
# Sort by layer_idx to ensure consistent mapping
|
||||
transformer_layers.sort(key=lambda x: x[0])
|
||||
|
||||
for layer_idx, name, module in transformer_layers:
|
||||
hook = module.register_forward_hook(
|
||||
self._make_steering_hook_fn(layer_idx)
|
||||
)
|
||||
self.hooks.append(hook)
|
||||
|
||||
def unregister_hooks(self) -> None:
|
||||
"""Removes all registered hooks from the base model."""
|
||||
for hook in self.hooks:
|
||||
@@ -152,38 +197,100 @@ class ExpertAdapterRouter:
|
||||
)
|
||||
else:
|
||||
# Compute dynamically per token via learnable gate
|
||||
# We pool over sequence length or route per token
|
||||
# Let's route token-wise: gate_logits has shape [batch, seq_len, num_adapters]
|
||||
if self.gate is not None:
|
||||
gate_logits = self.gate(x)
|
||||
weights = torch.softmax(gate_logits, dim=-1)
|
||||
else:
|
||||
return output_tensor
|
||||
|
||||
# Compute combined low-rank contribution
|
||||
# Y_lora = sum_i g_i * (x @ A_i.t()) @ B_i.t()
|
||||
adapter_output = torch.zeros_like(output_tensor)
|
||||
|
||||
for i, adapter in enumerate(self.adapter_library):
|
||||
if layer_name in adapter.layers:
|
||||
lm = adapter.layers[layer_name]
|
||||
# Ensure tensors are on the correct device
|
||||
lora_A = lm.lora_A.to(device=x.device, dtype=x.dtype)
|
||||
lora_B = lm.lora_B.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
# Dynamic scaling: gate_weight for this adapter
|
||||
# weights has shape [batch, num_adapters] or [batch, seq_len, num_adapters]
|
||||
if len(weights.shape) == 3:
|
||||
# Token-level routing: shape [batch, seq_len, 1]
|
||||
g = weights[..., i : i + 1]
|
||||
else:
|
||||
# Batch-level routing: shape [batch, 1, 1]
|
||||
g = weights[:, i].view(-1, 1, 1)
|
||||
|
||||
# Low-rank projection
|
||||
x_proj = torch.matmul(x, lora_A.t())
|
||||
y_proj = torch.matmul(x_proj, lora_B.t())
|
||||
|
||||
# Accumulate scaled delta
|
||||
adapter_output += g * y_proj
|
||||
|
||||
return output_tensor + adapter_output
|
||||
|
||||
return hook_fn
|
||||
|
||||
def _make_steering_hook_fn(self, layer_idx: int) -> Callable[..., Any]:
|
||||
"""Creates a hook function to inject activation steering vectors at a specific layer."""
|
||||
|
||||
def hook_fn(
|
||||
module: nn.Module,
|
||||
input_tensor: Tuple[torch.Tensor, ...],
|
||||
output_tensor: Any,
|
||||
) -> Any:
|
||||
is_tuple = isinstance(output_tensor, tuple)
|
||||
x = output_tensor[0] if is_tuple else output_tensor
|
||||
|
||||
# Sequential process/workflow steering
|
||||
if self.active_process_id is not None and self.process_library is not None:
|
||||
step_idx = self.active_process_step or 0
|
||||
step_vector = self.process_library.get_process_step(
|
||||
self.active_process_id, step_idx
|
||||
)
|
||||
if step_vector and layer_idx in step_vector.layer_vectors:
|
||||
v = step_vector.layer_vectors[layer_idx].to(
|
||||
device=x.device, dtype=x.dtype
|
||||
)
|
||||
steered_x = x + self.steering_alpha * v
|
||||
if is_tuple:
|
||||
return (steered_x,) + output_tensor[1:]
|
||||
return steered_x
|
||||
return output_tensor
|
||||
|
||||
# Dynamic skill routing
|
||||
if self.skill_library and len(self.skill_ids) > 0:
|
||||
weights = None
|
||||
if self.current_gate_weights is not None:
|
||||
batch_size = x.shape[0]
|
||||
weights = (
|
||||
self.current_gate_weights.to(x.device)
|
||||
.unsqueeze(0)
|
||||
.expand(batch_size, -1)
|
||||
)
|
||||
elif self.skill_gate is not None:
|
||||
if self.steering_mode == "token":
|
||||
gate_logits = self.skill_gate(x)
|
||||
weights = torch.softmax(gate_logits, dim=-1)
|
||||
else:
|
||||
x_mean = x.mean(dim=1) if len(x.shape) == 3 else x
|
||||
gate_logits = self.skill_gate(x_mean)
|
||||
weights = torch.softmax(gate_logits, dim=-1)
|
||||
|
||||
if weights is not None:
|
||||
steer_contribution = torch.zeros_like(x)
|
||||
for i, skill_id in enumerate(self.skill_ids):
|
||||
vec = self.skill_library.get_vector(skill_id)
|
||||
if vec and layer_idx in vec.layer_vectors:
|
||||
v = vec.layer_vectors[layer_idx].to(
|
||||
device=x.device, dtype=x.dtype
|
||||
)
|
||||
if len(weights.shape) == 3:
|
||||
g = weights[..., i : i + 1]
|
||||
else:
|
||||
g = weights[:, i].view(-1, 1, 1)
|
||||
steer_contribution += g * v
|
||||
|
||||
steered_x = x + self.steering_alpha * steer_contribution
|
||||
if is_tuple:
|
||||
return (steered_x,) + output_tensor[1:]
|
||||
return steered_x
|
||||
|
||||
return output_tensor
|
||||
|
||||
return hook_fn
|
||||
|
||||
420
python/representation_engineering.py
Normal file
420
python/representation_engineering.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""Representation Engineering and Vector Extraction from Playbooks.
|
||||
|
||||
Parses playbooks (SKILL.md), extracts minimal pairs, and computes steerable
|
||||
activation vectors or QLoRA adapters to build Skill and Process Vector Libraries.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
logger = logging.getLogger("representation_engineering")
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlaybookMetadata:
|
||||
"""Metadata and examples parsed from a SKILL.md playbook."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
objectives: List[str] = field(default_factory=list)
|
||||
trigger_examples: List[str] = field(default_factory=list)
|
||||
file_path: str = ""
|
||||
|
||||
|
||||
class PlaybookParser:
|
||||
"""Parses SKILL.md files to extract structured metadata and trigger examples."""
|
||||
|
||||
@staticmethod
|
||||
def parse_file(path: str | Path) -> PlaybookMetadata | None:
|
||||
"""Parses a single SKILL.md file."""
|
||||
path = Path(path)
|
||||
if not path.exists():
|
||||
logger.warning(f"File not found: {path}")
|
||||
return None
|
||||
|
||||
content = path.read_text(encoding="utf-8").lstrip()
|
||||
|
||||
# Parse frontmatter if present
|
||||
name = path.parent.name if path.parent else "unknown"
|
||||
description = ""
|
||||
|
||||
frontmatter_match = re.match(r"^---\r?\n(.*?)\r?\n---\r?\n", content, re.DOTALL)
|
||||
if frontmatter_match:
|
||||
fm_text = frontmatter_match.group(1)
|
||||
name_match = re.search(r"^name:\s*(.*?)$", fm_text, re.MULTILINE)
|
||||
if name_match:
|
||||
name = name_match.group(1).strip()
|
||||
desc_match = re.search(
|
||||
r"^description:\s*(.*?)(?=\r?\n\w+:|\Z)",
|
||||
fm_text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
if desc_match:
|
||||
description = desc_match.group(1).strip()
|
||||
# Remove multiline YAML indicators if any
|
||||
description = re.sub(r"^>-?\s*", "", description)
|
||||
description = re.sub(r"^\|\s*", "", description)
|
||||
description = re.sub(r"\r?\n\s*", " ", description).strip()
|
||||
|
||||
# Parse headers if no name/description in frontmatter
|
||||
if name == "unknown" or not description:
|
||||
title_match = re.search(r"^#\s+(.*?)$", content, re.MULTILINE)
|
||||
if title_match and name == "unknown":
|
||||
name = title_match.group(1).strip()
|
||||
# Remove prefixes like "Skill:" or "SOP:"
|
||||
name = re.sub(
|
||||
r"^(Skill|SOP|Playbook):\s*", "", name, flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
# Objectives and triggers
|
||||
objectives: List[str] = []
|
||||
triggers: List[str] = []
|
||||
|
||||
# Simple regex searches for bullet points
|
||||
lines = content.splitlines()
|
||||
in_objectives = False
|
||||
in_triggers = False
|
||||
|
||||
for line in lines:
|
||||
line_str = line.strip()
|
||||
if not line_str:
|
||||
continue
|
||||
|
||||
# Section detection
|
||||
if line_str.startswith("#"):
|
||||
in_objectives = (
|
||||
"objective" in line_str.lower() or "ziel" in line_str.lower()
|
||||
)
|
||||
in_triggers = any(
|
||||
x in line_str.lower()
|
||||
for x in ["trigger", "when to use", "examples", "beispiele"]
|
||||
)
|
||||
continue
|
||||
|
||||
if in_objectives and (
|
||||
line_str.startswith("-")
|
||||
or line_str.startswith("*")
|
||||
or re.match(r"^\d+\.", line_str)
|
||||
):
|
||||
clean_line = re.sub(r"^[-*\d\.]+\s*", "", line_str)
|
||||
objectives.append(clean_line)
|
||||
|
||||
if in_triggers and (
|
||||
line_str.startswith("-")
|
||||
or line_str.startswith("*")
|
||||
or re.match(r"^\d+\.", line_str)
|
||||
):
|
||||
clean_line = re.sub(r"^[-*\d\.]+\s*", "", line_str)
|
||||
# Filter out generic instructions
|
||||
if len(clean_line) > 5 and not clean_line.lower().startswith("do not"):
|
||||
triggers.append(clean_line)
|
||||
|
||||
# Fallback if no triggers found
|
||||
if not triggers:
|
||||
triggers = [
|
||||
f"Apply the {name} skill to handle this task.",
|
||||
f"How do I use {name} here?",
|
||||
]
|
||||
|
||||
return PlaybookMetadata(
|
||||
name=name,
|
||||
description=description or f"Playbook for {name}",
|
||||
objectives=objectives,
|
||||
trigger_examples=triggers,
|
||||
file_path=str(path.absolute()),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def parse_directory(cls, dir_path: str | Path) -> List[PlaybookMetadata]:
|
||||
"""Scans a directory for SKILL.md or matches files and parses them."""
|
||||
dir_path = Path(dir_path)
|
||||
playbooks: List[PlaybookMetadata] = []
|
||||
if not dir_path.exists():
|
||||
return playbooks
|
||||
|
||||
# Match any SKILL.md or *_SKILL.md
|
||||
for path in dir_path.glob("**/SKILL.md"):
|
||||
pb = cls.parse_file(path)
|
||||
if pb:
|
||||
playbooks.append(pb)
|
||||
|
||||
for path in dir_path.glob("*_SKILL.md"):
|
||||
pb = cls.parse_file(path)
|
||||
if pb:
|
||||
playbooks.append(pb)
|
||||
|
||||
return playbooks
|
||||
|
||||
|
||||
@dataclass
|
||||
class RepresentationVector:
|
||||
"""A steering direction vector or set of vectors for representation engineering."""
|
||||
|
||||
skill_id: str
|
||||
# Map from layer index (e.g. 0 to L-1) to difference activation tensor [hidden_dim]
|
||||
layer_vectors: Dict[int, torch.Tensor] = field(default_factory=dict)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class RepresentationVectorExtractor:
|
||||
"""Extracts representation (steering) vectors from model activations using minimal pairs."""
|
||||
|
||||
def __init__(self, model: nn.Module, tokenizer: Any, device: str = "cpu") -> None:
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
|
||||
def extract_steering_vector(
|
||||
self,
|
||||
skill_metadata: PlaybookMetadata,
|
||||
layers_to_extract: List[int] | None = None,
|
||||
) -> RepresentationVector:
|
||||
"""Computes the difference in hidden states for win vs lose prompts."""
|
||||
self.model.eval()
|
||||
|
||||
# Determine layers to hook
|
||||
# Pythia models store layers in model.gpt_neox.layers
|
||||
# Let's inspect model layout dynamically
|
||||
transformer_layers = []
|
||||
for name, module in self.model.named_modules():
|
||||
if re.match(r".*layers?\.\d+$", name):
|
||||
transformer_layers.append((name, module))
|
||||
|
||||
if not transformer_layers:
|
||||
# Fallback/guess for standard PyTorch modules
|
||||
logger.warning(
|
||||
"Could not automatically resolve transformer layers, will attempt default Hook paths"
|
||||
)
|
||||
|
||||
num_layers = len(transformer_layers)
|
||||
if layers_to_extract is None:
|
||||
# Default: extract from middle/late layers (e.g., last half of the network)
|
||||
layers_to_extract = list(range(num_layers // 2, num_layers))
|
||||
|
||||
# Generate minimal pairs from triggers
|
||||
win_prompts = []
|
||||
lose_prompts = []
|
||||
|
||||
for trigger in skill_metadata.trigger_examples:
|
||||
# Win prompt guides the model to invoke the playbook skill/format
|
||||
win_prompts.append(
|
||||
f"Instructions: You are acting with the following skill: {skill_metadata.name}. "
|
||||
f"Description: {skill_metadata.description}\n"
|
||||
f"Request: {trigger}\n"
|
||||
f"Output:"
|
||||
)
|
||||
# Lose prompt asks the model to respond normally
|
||||
lose_prompts.append(
|
||||
f"Instructions: Respond normally.\n" f"Request: {trigger}\n" f"Output:"
|
||||
)
|
||||
|
||||
# Temporary storage for hooked activations
|
||||
# Map: layer_idx -> list of tensors [seq_len, hidden_dim]
|
||||
win_activations: Dict[int, List[torch.Tensor]] = {
|
||||
idx: [] for idx in layers_to_extract
|
||||
}
|
||||
lose_activations: Dict[int, List[torch.Tensor]] = {
|
||||
idx: [] for idx in layers_to_extract
|
||||
}
|
||||
|
||||
# Hook function builder
|
||||
def make_hook(layer_idx: int, storage: Dict[int, List[torch.Tensor]]) -> Any:
|
||||
def hook_fn(
|
||||
module: nn.Module,
|
||||
input_t: Tuple[torch.Tensor, ...],
|
||||
output_t: torch.Tensor,
|
||||
) -> None:
|
||||
# output_t is typically [batch, seq_len, hidden_dim] or a tuple
|
||||
if isinstance(output_t, tuple):
|
||||
output_t = output_t[0]
|
||||
|
||||
# Detach and move to CPU to save GPU memory
|
||||
storage[layer_idx].append(output_t.detach().cpu())
|
||||
|
||||
return hook_fn
|
||||
|
||||
# Register hooks
|
||||
hooks = []
|
||||
for idx in layers_to_extract:
|
||||
if idx < len(transformer_layers):
|
||||
_, layer_module = transformer_layers[idx]
|
||||
h = layer_module.register_forward_hook(make_hook(idx, win_activations))
|
||||
hooks.append(h)
|
||||
|
||||
# Run forward pass for win prompts
|
||||
for prompt in win_prompts:
|
||||
raw_inputs = self.tokenizer(prompt, return_tensors="pt")
|
||||
inputs = {k: v.to(self.device) for k, v in raw_inputs.items()}
|
||||
with torch.no_grad():
|
||||
self.model(**inputs)
|
||||
|
||||
# Remove hooks and register them for lose activations
|
||||
for h in hooks:
|
||||
h.remove()
|
||||
hooks.clear()
|
||||
|
||||
for idx in layers_to_extract:
|
||||
if idx < len(transformer_layers):
|
||||
_, layer_module = transformer_layers[idx]
|
||||
h = layer_module.register_forward_hook(make_hook(idx, lose_activations))
|
||||
hooks.append(h)
|
||||
|
||||
# Run forward pass for lose prompts
|
||||
for prompt in lose_prompts:
|
||||
raw_inputs = self.tokenizer(prompt, return_tensors="pt")
|
||||
inputs = {k: v.to(self.device) for k, v in raw_inputs.items()}
|
||||
with torch.no_grad():
|
||||
self.model(**inputs)
|
||||
|
||||
# Remove hooks
|
||||
for h in hooks:
|
||||
h.remove()
|
||||
|
||||
# Compute difference vectors
|
||||
layer_vectors: Dict[int, torch.Tensor] = {}
|
||||
for idx in layers_to_extract:
|
||||
win_tensors = win_activations[idx]
|
||||
lose_tensors = lose_activations[idx]
|
||||
|
||||
if not win_tensors or not lose_tensors:
|
||||
continue
|
||||
|
||||
diffs = []
|
||||
for win_t, lose_t in zip(win_tensors, lose_tensors):
|
||||
# win_t and lose_t are [1, seq_len, hidden_dim]
|
||||
# We can average over sequence length or take the last token (representation at decision point)
|
||||
# Let's average over the sequence length for stability
|
||||
w_mean = win_t.mean(dim=1).squeeze(0) # [hidden_dim]
|
||||
l_mean = lose_t.mean(dim=1).squeeze(0) # [hidden_dim]
|
||||
diffs.append(w_mean - l_mean)
|
||||
|
||||
# Average difference vector across all minimal pairs
|
||||
mean_diff = torch.stack(diffs).mean(dim=0)
|
||||
|
||||
# Normalize vector to unit norm for consistent steering scales
|
||||
norm = torch.norm(mean_diff)
|
||||
if norm > 1e-8:
|
||||
mean_diff = mean_diff / norm
|
||||
|
||||
layer_vectors[idx] = mean_diff
|
||||
|
||||
metadata = {
|
||||
"name": skill_metadata.name,
|
||||
"description": skill_metadata.description,
|
||||
"trigger_count": len(skill_metadata.trigger_examples),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Extracted representation vector for '{skill_metadata.name}' | {len(layer_vectors)} layers"
|
||||
)
|
||||
return RepresentationVector(
|
||||
skill_id=skill_metadata.name.lower().replace(" ", "_"),
|
||||
layer_vectors=layer_vectors,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
class SkillVectorLibrary:
|
||||
"""Library containing extracted skill representation vectors."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.vectors: Dict[str, RepresentationVector] = {}
|
||||
|
||||
def add_vector(self, vec: RepresentationVector) -> None:
|
||||
self.vectors[vec.skill_id] = vec
|
||||
|
||||
def get_vector(self, skill_id: str) -> RepresentationVector | None:
|
||||
return self.vectors.get(skill_id)
|
||||
|
||||
def save(self, path: str | Path) -> None:
|
||||
"""Saves the library to disk."""
|
||||
state = {
|
||||
skill_id: {
|
||||
"skill_id": vec.skill_id,
|
||||
"layer_vectors": {k: v.cpu() for k, v in vec.layer_vectors.items()},
|
||||
"metadata": vec.metadata,
|
||||
}
|
||||
for skill_id, vec in self.vectors.items()
|
||||
}
|
||||
torch.save(state, path)
|
||||
logger.info(
|
||||
f"Skill Vector Library saved to {path} ({len(self.vectors)} skills)"
|
||||
)
|
||||
|
||||
def load(self, path: str | Path) -> None:
|
||||
"""Loads the library from disk."""
|
||||
self.vectors.clear()
|
||||
state = torch.load(path, map_location="cpu", weights_only=False)
|
||||
for skill_id, vec_state in state.items():
|
||||
self.vectors[skill_id] = RepresentationVector(
|
||||
skill_id=vec_state["skill_id"],
|
||||
layer_vectors=vec_state["layer_vectors"],
|
||||
metadata=vec_state["metadata"],
|
||||
)
|
||||
logger.info(
|
||||
f"Skill Vector Library loaded from {path} ({len(self.vectors)} skills)"
|
||||
)
|
||||
|
||||
|
||||
class ProcessVectorLibrary:
|
||||
"""Library containing sequential process step representation vectors."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.processes: Dict[str, List[RepresentationVector]] = {}
|
||||
|
||||
def add_process(self, process_id: str, steps: List[RepresentationVector]) -> None:
|
||||
self.processes[process_id] = steps
|
||||
|
||||
def get_process_step(
|
||||
self, process_id: str, step_idx: int
|
||||
) -> RepresentationVector | None:
|
||||
steps = self.processes.get(process_id)
|
||||
if steps and 0 <= step_idx < len(steps):
|
||||
return steps[step_idx]
|
||||
return None
|
||||
|
||||
def save(self, path: str | Path) -> None:
|
||||
"""Saves the library to disk."""
|
||||
state = {
|
||||
p_id: [
|
||||
{
|
||||
"skill_id": vec.skill_id,
|
||||
"layer_vectors": {k: v.cpu() for k, v in vec.layer_vectors.items()},
|
||||
"metadata": vec.metadata,
|
||||
}
|
||||
for vec in steps
|
||||
]
|
||||
for p_id, steps in self.processes.items()
|
||||
}
|
||||
torch.save(state, path)
|
||||
logger.info(
|
||||
f"Process Vector Library saved to {path} ({len(self.processes)} processes)"
|
||||
)
|
||||
|
||||
def load(self, path: str | Path) -> None:
|
||||
"""Loads the library from disk."""
|
||||
self.processes.clear()
|
||||
state = torch.load(path, map_location="cpu", weights_only=False)
|
||||
for p_id, steps_state in state.items():
|
||||
steps = []
|
||||
for vec_state in steps_state:
|
||||
steps.append(
|
||||
RepresentationVector(
|
||||
skill_id=vec_state["skill_id"],
|
||||
layer_vectors=vec_state["layer_vectors"],
|
||||
metadata=vec_state["metadata"],
|
||||
)
|
||||
)
|
||||
self.processes[p_id] = steps
|
||||
logger.info(
|
||||
f"Process Vector Library loaded from {path} ({len(self.processes)} processes)"
|
||||
)
|
||||
473
python/run_representation_pipeline.py
Normal file
473
python/run_representation_pipeline.py
Normal file
@@ -0,0 +1,473 @@
|
||||
"""Orchestration script to compile Skill and Process libraries, train Gating-MLP, and validate steering.
|
||||
|
||||
Discovers local playbooks, extracts representation vectors using Pythia-70m (or dummy fallback),
|
||||
trains the gating network on hidden states, and runs validation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
# Import our representation engineering and router modules
|
||||
import sys
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.absolute()))
|
||||
|
||||
from representation_engineering import (
|
||||
PlaybookParser,
|
||||
PlaybookMetadata,
|
||||
RepresentationVectorExtractor,
|
||||
SkillVectorLibrary,
|
||||
ProcessVectorLibrary,
|
||||
)
|
||||
from adapter_moe_router import LearnableGate, ExpertAdapterRouter
|
||||
|
||||
logger = logging.getLogger("run_representation_pipeline")
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
# --- Dummy Model & Tokenizer for Offline Fallback & Testing ---
|
||||
|
||||
|
||||
class DummyTransformerLayer(nn.Module): # type: ignore[misc]
|
||||
def __init__(self, hidden_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, *args: Any, **kwargs: Any
|
||||
) -> Tuple[torch.Tensor]:
|
||||
h = self.linear2(torch.relu(self.linear1(x)))
|
||||
return (x + h,)
|
||||
|
||||
|
||||
class DummyModel(nn.Module): # type: ignore[misc]
|
||||
def __init__(
|
||||
self, vocab_size: int = 1000, hidden_dim: int = 32, num_layers: int = 4
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_dim = hidden_dim
|
||||
self.embedding = nn.Embedding(vocab_size, hidden_dim)
|
||||
self.layers = nn.ModuleList(
|
||||
[DummyTransformerLayer(hidden_dim) for _ in range(num_layers)]
|
||||
)
|
||||
self.lm_head = nn.Linear(hidden_dim, vocab_size)
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, *args: Any, **kwargs: Any) -> Any:
|
||||
x = self.embedding(input_ids)
|
||||
for layer in self.layers:
|
||||
x = layer(x)[0]
|
||||
logits = self.lm_head(x)
|
||||
|
||||
class Output:
|
||||
def __init__(self, logits: torch.Tensor) -> None:
|
||||
self.logits = logits
|
||||
|
||||
return Output(logits=logits)
|
||||
|
||||
|
||||
class DummyTokenizer:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def __call__(
|
||||
self, text: str | List[str], return_tensors: str = "pt", **kwargs: Any
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
batch_ids = []
|
||||
max_len = 0
|
||||
for t in text:
|
||||
words = t.split()
|
||||
ids = [abs(hash(w)) % 1000 for w in words]
|
||||
if not ids:
|
||||
ids = [0]
|
||||
batch_ids.append(ids)
|
||||
max_len = max(max_len, len(ids))
|
||||
|
||||
padded_ids = []
|
||||
for ids in batch_ids:
|
||||
padded_ids.append(ids + [0] * (max_len - len(ids)))
|
||||
return {"input_ids": torch.tensor(padded_ids)}
|
||||
|
||||
|
||||
# --- Helper to parse process steps from playbooks ---
|
||||
|
||||
|
||||
def parse_playbook_steps(path: Path) -> List[Tuple[str, str]]:
|
||||
"""Extracts step names and short descriptions from sequential lists in a playbook."""
|
||||
content = path.read_text(encoding="utf-8")
|
||||
steps = []
|
||||
|
||||
# Locate numbered items under headers like Fallback Procedure, Operational Protocol, etc.
|
||||
lines = content.splitlines()
|
||||
in_procedure = False
|
||||
|
||||
for line in lines:
|
||||
line_str = line.strip()
|
||||
if not line_str:
|
||||
continue
|
||||
|
||||
if line_str.startswith("#"):
|
||||
lower_header = line_str.lower()
|
||||
in_procedure = any(
|
||||
x in lower_header
|
||||
for x in ["procedure", "protocol", "schritt", "ablauf", "loop"]
|
||||
)
|
||||
continue
|
||||
|
||||
if in_procedure:
|
||||
# Match 1. Step name: Description or just 1. Description
|
||||
match = re.match(r"^(\d+)\.\s*(.*?)$", line_str)
|
||||
if match:
|
||||
step_num = match.group(1)
|
||||
step_text = match.group(2).strip()
|
||||
# Split step name and description if separated by double asterisks or colon
|
||||
parts = re.split(r"\*\*|:\s*", step_text, maxsplit=1)
|
||||
if len(parts) > 1 and parts[0].strip():
|
||||
step_name = parts[0].strip()
|
||||
step_desc = parts[1].strip()
|
||||
else:
|
||||
step_name = f"Step_{step_num}"
|
||||
step_desc = step_text
|
||||
steps.append((step_name, step_desc))
|
||||
|
||||
return steps
|
||||
|
||||
|
||||
# --- Pipeline Orchestrator ---
|
||||
|
||||
|
||||
def run_pipeline(
|
||||
model_id: str,
|
||||
skills_dirs: List[str],
|
||||
output_dir: str,
|
||||
use_dummy: bool = False,
|
||||
epochs: int = 50,
|
||||
lr: float = 0.01,
|
||||
) -> None:
|
||||
logger.info("Initializing Representation Engineering Pipeline...")
|
||||
|
||||
# 1. Resolve output directory
|
||||
out_path = Path(output_dir)
|
||||
out_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 2. Load model and tokenizer (with dummy fallback)
|
||||
model = None
|
||||
tokenizer = None
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
if not use_dummy:
|
||||
try:
|
||||
logger.info(f"Attempting to load model '{model_id}' from HuggingFace...")
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
|
||||
logger.info("Model loaded successfully.")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to load model from HuggingFace: {e}. Falling back to Dummy Model."
|
||||
)
|
||||
use_dummy = True
|
||||
|
||||
if use_dummy:
|
||||
logger.info(
|
||||
"Initializing lightweight Dummy Model and Tokenizer for execution/testing..."
|
||||
)
|
||||
model = DummyModel(vocab_size=1000, hidden_dim=32, num_layers=4).to(device)
|
||||
tokenizer = DummyTokenizer()
|
||||
device = "cpu"
|
||||
|
||||
assert model is not None
|
||||
assert tokenizer is not None
|
||||
|
||||
# 3. Discover playbooks
|
||||
logger.info("Scanning directories for playbooks...")
|
||||
playbooks: List[PlaybookMetadata] = []
|
||||
parsed_files: List[Path] = []
|
||||
|
||||
for s_dir in skills_dirs:
|
||||
path = Path(s_dir)
|
||||
if path.exists():
|
||||
logger.info(f"Scanning directory: {path}")
|
||||
found = PlaybookParser.parse_directory(path)
|
||||
playbooks.extend(found)
|
||||
# Find actual files for process parsing
|
||||
for p_file in path.glob("**/SKILL.md"):
|
||||
parsed_files.append(p_file)
|
||||
for p_file in path.glob("*_SKILL.md"):
|
||||
parsed_files.append(p_file)
|
||||
|
||||
if not playbooks:
|
||||
logger.warning(
|
||||
"No playbooks discovered! Creating a default mock playbook for execution..."
|
||||
)
|
||||
mock_skill = PlaybookMetadata(
|
||||
name="MockSkill",
|
||||
description="A temporary skill for pipeline validation",
|
||||
objectives=[
|
||||
"Validate the routing pipeline",
|
||||
"Verify steering functionality",
|
||||
],
|
||||
trigger_examples=[
|
||||
"Run the mock pipeline check",
|
||||
"Test activation steering on mock",
|
||||
],
|
||||
file_path="mock_SKILL.md",
|
||||
)
|
||||
playbooks.append(mock_skill)
|
||||
|
||||
logger.info(f"Discovered {len(playbooks)} playbooks.")
|
||||
|
||||
# 4. Extract Skill and Process Libraries
|
||||
skill_library = SkillVectorLibrary()
|
||||
process_library = ProcessVectorLibrary()
|
||||
extractor = RepresentationVectorExtractor(model, tokenizer, device=device)
|
||||
|
||||
# Determine model layers
|
||||
transformer_layers = []
|
||||
for name, module in model.named_modules():
|
||||
if re.match(r".*layers?\.\d+$", name):
|
||||
transformer_layers.append(name)
|
||||
num_layers = len(transformer_layers)
|
||||
logger.info(f"Detected {num_layers} transformer layers in base model.")
|
||||
|
||||
# Choose layers to extract: middle/late layers
|
||||
layers_to_extract = (
|
||||
list(range(num_layers // 2, num_layers)) if num_layers > 0 else [0]
|
||||
)
|
||||
|
||||
# Extract skill vectors
|
||||
for pb in playbooks:
|
||||
logger.info(f"Extracting vectors for skill: {pb.name}")
|
||||
vec = extractor.extract_steering_vector(pb, layers_to_extract=layers_to_extract)
|
||||
skill_library.add_vector(vec)
|
||||
|
||||
# Extract sequential process vectors if file exists
|
||||
p_path = Path(pb.file_path)
|
||||
if p_path.exists():
|
||||
steps = parse_playbook_steps(p_path)
|
||||
if steps:
|
||||
logger.info(
|
||||
f"Found {len(steps)} sequential steps in process '{pb.name}'"
|
||||
)
|
||||
step_vectors = []
|
||||
for step_name, step_desc in steps:
|
||||
step_pb = PlaybookMetadata(
|
||||
name=f"{pb.name}_{step_name}",
|
||||
description=step_desc,
|
||||
objectives=[step_desc],
|
||||
trigger_examples=pb.trigger_examples,
|
||||
file_path=str(p_path),
|
||||
)
|
||||
step_vec = extractor.extract_steering_vector(
|
||||
step_pb, layers_to_extract=layers_to_extract
|
||||
)
|
||||
step_vectors.append(step_vec)
|
||||
process_library.add_process(
|
||||
pb.name.lower().replace(" ", "_"), step_vectors
|
||||
)
|
||||
|
||||
# Save libraries
|
||||
skill_lib_path = out_path / "skill_library.pt"
|
||||
process_lib_path = out_path / "process_library.pt"
|
||||
skill_library.save(skill_lib_path)
|
||||
process_library.save(process_lib_path)
|
||||
|
||||
# 5. Train Gating-MLP on hidden states
|
||||
skill_ids = sorted(list(skill_library.vectors.keys()))
|
||||
if not skill_ids:
|
||||
logger.error("No skills available to train Gating-MLP.")
|
||||
return
|
||||
|
||||
logger.info(f"Training Gating-MLP router over {len(skill_ids)} skills...")
|
||||
|
||||
# Gating features layer (we extract hidden states from early/middle layer to predict routing)
|
||||
gate_layer_idx = layers_to_extract[0] if layers_to_extract else 0
|
||||
hidden_dim = model.hidden_dim if hasattr(model, "hidden_dim") else 32
|
||||
if hasattr(model, "config") and hasattr(model.config, "hidden_size"):
|
||||
hidden_dim = model.config.hidden_size
|
||||
|
||||
gate_net = LearnableGate(in_features=hidden_dim, num_adapters=len(skill_ids)).to(
|
||||
device
|
||||
)
|
||||
optimizer = optim.Adam(gate_net.parameters(), lr=lr)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
# Collect training dataset: (hidden_state, label)
|
||||
# We pass the win prompts for each skill, hook the gate layer, and collect states
|
||||
training_data: List[Tuple[torch.Tensor, int]] = []
|
||||
|
||||
# Temporary hook to collect states
|
||||
collected_states: List[torch.Tensor] = []
|
||||
|
||||
def collect_hook(module: nn.Module, input_t: Any, output_t: Any) -> None:
|
||||
x = output_t[0] if isinstance(output_t, tuple) else output_t
|
||||
# Pool to sequence mean
|
||||
collected_states.append(x.detach().mean(dim=1).squeeze(0))
|
||||
|
||||
# Register hook on gate layer
|
||||
hook_handle = None
|
||||
target_layer_name = (
|
||||
transformer_layers[gate_layer_idx]
|
||||
if gate_layer_idx < len(transformer_layers)
|
||||
else None
|
||||
)
|
||||
|
||||
if target_layer_name:
|
||||
for name, module in model.named_modules():
|
||||
if name == target_layer_name:
|
||||
hook_handle = module.register_forward_hook(collect_hook)
|
||||
break
|
||||
|
||||
# Run win prompts to collect activations
|
||||
for label_idx, skill_id in enumerate(skill_ids):
|
||||
# Retrieve parsed metadata for this skill ID
|
||||
pb_match = next(
|
||||
(p for p in playbooks if p.name.lower().replace(" ", "_") == skill_id), None
|
||||
)
|
||||
if pb_match:
|
||||
for trigger in pb_match.trigger_examples:
|
||||
win_prompt = (
|
||||
f"Instructions: You are acting with the following skill: {pb_match.name}.\n"
|
||||
f"Request: {trigger}\nOutput:"
|
||||
)
|
||||
inputs = tokenizer(win_prompt, return_tensors="pt")
|
||||
# Move inputs to device
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
|
||||
collected_states.clear()
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
if collected_states:
|
||||
state = collected_states[0].cpu() # shape [hidden_dim]
|
||||
training_data.append((state, label_idx))
|
||||
|
||||
if hook_handle:
|
||||
hook_handle.remove()
|
||||
|
||||
if not training_data:
|
||||
logger.warning(
|
||||
"Could not collect training data. Using synthetic data to train Gating-MLP."
|
||||
)
|
||||
# Fallback to random features for testing compilation flow
|
||||
for i in range(100):
|
||||
label = i % len(skill_ids)
|
||||
feat = torch.randn(hidden_dim) + (label * 2.0) # separate them a bit
|
||||
training_data.append((feat, label))
|
||||
|
||||
# Train MLP
|
||||
gate_net.train()
|
||||
X = torch.stack([x for x, y in training_data]).to(device)
|
||||
Y = torch.tensor([y for x, y in training_data]).to(device)
|
||||
|
||||
logger.info(
|
||||
f"Collected {len(training_data)} training samples. Starting optimization..."
|
||||
)
|
||||
|
||||
for epoch in range(epochs):
|
||||
optimizer.zero_grad()
|
||||
outputs = gate_net(X)
|
||||
loss = criterion(outputs, Y)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if (epoch + 1) % max(1, epochs // 5) == 0 or epoch == epochs - 1:
|
||||
# Calculate accuracy
|
||||
_, preds = torch.max(outputs, 1)
|
||||
correct = (preds == Y).sum().item()
|
||||
acc = correct / len(Y)
|
||||
logger.info(
|
||||
f"Epoch {epoch+1:02d}/{epochs:02d} | Loss: {loss.item():.4f} | Training Accuracy: {acc * 100:.1f}%"
|
||||
)
|
||||
|
||||
# Save gate weights
|
||||
gate_weights_path = out_path / "gate_weights.pt"
|
||||
torch.save(gate_net.state_dict(), gate_weights_path)
|
||||
logger.info(f"Gating MLP weights saved to {gate_weights_path}")
|
||||
|
||||
# 6. Validate steerability & routing performance
|
||||
logger.info("Running pipeline verification...")
|
||||
router = ExpertAdapterRouter(
|
||||
base_model=model,
|
||||
skill_library=skill_library,
|
||||
process_library=process_library,
|
||||
in_features=hidden_dim,
|
||||
steering_alpha=1.0,
|
||||
)
|
||||
router.skill_gate.load_state_dict(
|
||||
torch.load(gate_weights_path, map_location=device)
|
||||
)
|
||||
router.register_hooks()
|
||||
|
||||
# Run test prompt with routing enabled
|
||||
test_prompt = "Validate this test prompt"
|
||||
inputs = tokenizer(test_prompt, return_tensors="pt")
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
|
||||
# Test that forwarding goes through hooks without crash
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs)
|
||||
logger.info("Successfully executed forward pass with dynamic activation steering.")
|
||||
|
||||
router.unregister_hooks()
|
||||
logger.info("Pipeline run completed successfully.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="FCES Representation Engineering Pipeline"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_id",
|
||||
type=str,
|
||||
default="EleutherAI/pythia-70m",
|
||||
help="HuggingFace model identifier",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skills_dirs",
|
||||
type=str,
|
||||
default="C:/Users/Sven/Documents/svenco-knowledge/skills,C:/Users/Sven/Documents/everything-claude-code/skills",
|
||||
help="Comma-separated paths to search for playbooks",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="./python/output",
|
||||
help="Directory to save compiled libraries",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_dummy", action="store_true", help="Force using lightweight dummy model"
|
||||
)
|
||||
parser.add_argument("--epochs", type=int, default=50, help="Gating training epochs")
|
||||
parser.add_argument(
|
||||
"--lr", type=float, default=0.01, help="Gating training learning rate"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Split skills dirs
|
||||
dirs_list = [d.strip() for d in args.skills_dirs.split(",") if d.strip()]
|
||||
|
||||
run_pipeline(
|
||||
model_id=args.model_id,
|
||||
skills_dirs=dirs_list,
|
||||
output_dir=args.output_dir,
|
||||
use_dummy=args.use_dummy,
|
||||
epochs=args.epochs,
|
||||
lr=args.lr,
|
||||
)
|
||||
251
tests/test_representation_engineering.py
Normal file
251
tests/test_representation_engineering.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""Unit tests for Representation Engineering and Vector Library Compilation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Ensure python directory is in path
|
||||
sys.path.append(
|
||||
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "python")
|
||||
)
|
||||
|
||||
from representation_engineering import (
|
||||
PlaybookParser,
|
||||
PlaybookMetadata,
|
||||
RepresentationVector,
|
||||
RepresentationVectorExtractor,
|
||||
SkillVectorLibrary,
|
||||
ProcessVectorLibrary,
|
||||
)
|
||||
from adapter_moe_router import ExpertAdapterRouter
|
||||
|
||||
|
||||
# --- Mock Classes for Testing ---
|
||||
|
||||
|
||||
class SimpleTransformerLayer(nn.Module): # type: ignore[misc]
|
||||
def __init__(self, hidden_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(hidden_dim, hidden_dim, bias=False)
|
||||
with torch.no_grad():
|
||||
self.linear.weight.fill_(0.1)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, *args: Any, **kwargs: Any
|
||||
) -> Tuple[torch.Tensor]:
|
||||
return (x + self.linear(x),)
|
||||
|
||||
|
||||
class SimpleModel(nn.Module): # type: ignore[misc]
|
||||
def __init__(self, hidden_dim: int = 16) -> None:
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
# We name layers matching .layers.\d+$ to verify registration
|
||||
self.layers = nn.ModuleList(
|
||||
[SimpleTransformerLayer(hidden_dim), SimpleTransformerLayer(hidden_dim)]
|
||||
)
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, *args: Any, **kwargs: Any) -> Any:
|
||||
# Simple embedding emulation: shape [batch, seq_len, hidden_dim]
|
||||
# input_ids shape: [batch, seq_len]
|
||||
batch_size, seq_len = input_ids.shape
|
||||
x = (
|
||||
torch.ones(batch_size, seq_len, self.hidden_dim, dtype=torch.float32)
|
||||
* input_ids.unsqueeze(-1).float()
|
||||
)
|
||||
for layer in self.layers:
|
||||
x = layer(x)[0]
|
||||
|
||||
class Output:
|
||||
def __init__(self, logits: torch.Tensor) -> None:
|
||||
self.logits = logits
|
||||
|
||||
return Output(logits=x)
|
||||
|
||||
|
||||
class SimpleTokenizer:
|
||||
def __call__(
|
||||
self, text: str | List[str], return_tensors: str = "pt", **kwargs: Any
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
val = 2 if any("skill" in t.lower() for t in text) else 1
|
||||
return {"input_ids": torch.ones(len(text), 5, dtype=torch.long) * val}
|
||||
|
||||
|
||||
# --- Test Cases ---
|
||||
|
||||
|
||||
class TestRepresentationEngineering(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.model = SimpleModel(hidden_dim=16)
|
||||
self.tokenizer = SimpleTokenizer()
|
||||
self.extractor = RepresentationVectorExtractor(
|
||||
self.model, self.tokenizer, device="cpu"
|
||||
)
|
||||
|
||||
def test_playbook_parser(self) -> None:
|
||||
# Create a mock SKILL.md content
|
||||
mock_content = """---
|
||||
name: Mock Test Skill
|
||||
description: >-
|
||||
This is a description
|
||||
with multiline text.
|
||||
---
|
||||
# Skill: Mock Test Skill
|
||||
|
||||
## Objective
|
||||
- Analyze inputs dynamically.
|
||||
- Perform steering optimization.
|
||||
|
||||
## Activation Trigger
|
||||
- Use this when testing the parser.
|
||||
- Trigger example two.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
file_path = Path(tmpdir) / "SKILL.md"
|
||||
file_path.write_text(mock_content, encoding="utf-8")
|
||||
|
||||
pb = PlaybookParser.parse_file(file_path)
|
||||
assert pb is not None
|
||||
self.assertEqual(pb.name, "Mock Test Skill")
|
||||
self.assertEqual(
|
||||
pb.description, "This is a description with multiline text."
|
||||
)
|
||||
self.assertIn("Analyze inputs dynamically.", pb.objectives)
|
||||
self.assertIn("Use this when testing the parser.", pb.trigger_examples)
|
||||
self.assertIn("Trigger example two.", pb.trigger_examples)
|
||||
|
||||
def test_vector_extraction(self) -> None:
|
||||
mock_pb = PlaybookMetadata(
|
||||
name="TestSteering",
|
||||
description="Steering description",
|
||||
objectives=["Objective 1"],
|
||||
trigger_examples=["Trigger 1"],
|
||||
file_path="mock_path.md",
|
||||
)
|
||||
# Extract from layer index 1
|
||||
vec = self.extractor.extract_steering_vector(mock_pb, layers_to_extract=[1])
|
||||
self.assertEqual(vec.skill_id, "teststeering")
|
||||
self.assertIn(1, vec.layer_vectors)
|
||||
self.assertEqual(vec.layer_vectors[1].shape, (16,))
|
||||
# Assert normalized to unit norm
|
||||
self.assertAlmostEqual(torch.norm(vec.layer_vectors[1]).item(), 1.0, places=5)
|
||||
|
||||
def test_libraries_save_and_load(self) -> None:
|
||||
skill_lib = SkillVectorLibrary()
|
||||
process_lib = ProcessVectorLibrary()
|
||||
|
||||
# Add mock vectors
|
||||
v1 = RepresentationVector(
|
||||
skill_id="skill_a",
|
||||
layer_vectors={0: torch.ones(16), 1: torch.zeros(16)},
|
||||
metadata={"name": "Skill A"},
|
||||
)
|
||||
skill_lib.add_vector(v1)
|
||||
|
||||
v2 = RepresentationVector(
|
||||
skill_id="step_1",
|
||||
layer_vectors={0: torch.ones(16) * 0.5},
|
||||
metadata={"name": "Step 1"},
|
||||
)
|
||||
process_lib.add_process("process_a", [v2])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
skill_path = Path(tmpdir) / "skills.pt"
|
||||
proc_path = Path(tmpdir) / "processes.pt"
|
||||
|
||||
skill_lib.save(skill_path)
|
||||
process_lib.save(proc_path)
|
||||
|
||||
self.assertTrue(skill_path.exists())
|
||||
self.assertTrue(proc_path.exists())
|
||||
|
||||
# Load into new libraries
|
||||
new_skill_lib = SkillVectorLibrary()
|
||||
new_proc_lib = ProcessVectorLibrary()
|
||||
|
||||
new_skill_lib.load(skill_path)
|
||||
new_proc_lib.load(proc_path)
|
||||
|
||||
loaded_v1 = new_skill_lib.get_vector("skill_a")
|
||||
assert loaded_v1 is not None
|
||||
self.assertEqual(loaded_v1.skill_id, "skill_a")
|
||||
self.assertTrue(torch.allclose(loaded_v1.layer_vectors[0], torch.ones(16)))
|
||||
|
||||
loaded_v2 = new_proc_lib.get_process_step("process_a", 0)
|
||||
assert loaded_v2 is not None
|
||||
self.assertEqual(loaded_v2.skill_id, "step_1")
|
||||
self.assertTrue(
|
||||
torch.allclose(loaded_v2.layer_vectors[0], torch.ones(16) * 0.5)
|
||||
)
|
||||
|
||||
def test_router_activation_steering(self) -> None:
|
||||
# Create a skill library and add a steering vector
|
||||
skill_lib = SkillVectorLibrary()
|
||||
v = RepresentationVector(
|
||||
skill_id="skill_steer",
|
||||
layer_vectors={0: torch.ones(16) * 0.5, 1: torch.ones(16) * -0.5},
|
||||
metadata={"name": "Steer"},
|
||||
)
|
||||
skill_lib.add_vector(v)
|
||||
|
||||
# Setup router
|
||||
router = ExpertAdapterRouter(
|
||||
base_model=self.model,
|
||||
skill_library=skill_lib,
|
||||
in_features=16,
|
||||
steering_alpha=2.0,
|
||||
steering_mode="prompt",
|
||||
)
|
||||
|
||||
# Test 1: Forward without hooks (baseline)
|
||||
x = torch.ones(1, 4, dtype=torch.long) # batch=1, seq_len=4
|
||||
with torch.no_grad():
|
||||
out_base = self.model(x).logits
|
||||
|
||||
# Test 2: Register hooks and set active routing to skill_steer (index 0)
|
||||
router.register_hooks()
|
||||
self.assertEqual(len(router.hooks), 2) # Should hook both layers
|
||||
|
||||
priors = torch.tensor([1.0])
|
||||
router.set_active_routing(priors)
|
||||
|
||||
with torch.no_grad():
|
||||
out_steered = self.model(x).logits
|
||||
|
||||
# Verify that steering changed the output
|
||||
self.assertFalse(torch.allclose(out_base, out_steered))
|
||||
|
||||
# Test 3: Process sequential routing
|
||||
router.unregister_hooks()
|
||||
proc_lib = ProcessVectorLibrary()
|
||||
v_proc = RepresentationVector(
|
||||
skill_id="step_a",
|
||||
layer_vectors={0: torch.ones(16) * 10.0}, # large steering value
|
||||
metadata={"name": "Step A"},
|
||||
)
|
||||
proc_lib.add_process("my_workflow", [v_proc])
|
||||
|
||||
router.process_library = proc_lib
|
||||
router.active_process_id = "my_workflow"
|
||||
router.active_process_step = 0
|
||||
router.register_hooks()
|
||||
|
||||
with torch.no_grad():
|
||||
out_proc = self.model(x).logits
|
||||
|
||||
self.assertFalse(torch.allclose(out_base, out_proc))
|
||||
router.unregister_hooks()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user