# File: FCES-native/python/representation_engineering.py
# (414 lines, 15 KiB, Python)

"""Representation Engineering and Vector Extraction from Playbooks.
Parses playbooks (SKILL.md), extracts minimal pairs, and computes steerable
activation vectors or QLoRA adapters to build Skill and Process Vector Libraries.
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Tuple
import torch
import torch.nn as nn
# Module-level logger used by the parser, extractor and library classes below.
logger = logging.getLogger("representation_engineering")
# NOTE(review): basicConfig runs at import time and configures the *root*
# logger for the importing process (it is a no-op when the root logger already
# has handlers). Applications embedding this module may prefer to configure
# logging themselves.
logging.basicConfig(level=logging.INFO)
@dataclass
class PlaybookMetadata:
    """Structured summary of a single SKILL.md playbook.

    Only ``name`` and ``description`` are required; every other field has an
    empty default so a record can be built from partial information.
    """

    # Human-readable skill name.
    name: str
    # Short description of what the skill does.
    description: str
    # List items collected from the playbook's objective sections.
    objectives: List[str] = field(default_factory=list)
    # Example requests that should trigger the skill.
    trigger_examples: List[str] = field(default_factory=list)
    # Location of the source file; empty when not known.
    file_path: str = ""
class PlaybookParser:
    """Parse SKILL.md playbooks into ``PlaybookMetadata`` records.

    The parser is intentionally lightweight: frontmatter is read with regular
    expressions (no YAML dependency), and objectives / trigger examples are
    taken from Markdown list items that appear under headings whose text
    contains one of the keywords below.
    """

    # Heading keywords, matched case-insensitively as substrings.
    _OBJECTIVE_KEYWORDS = ("objective", "ziel")
    _TRIGGER_KEYWORDS = ("trigger", "when to use", "examples", "beispiele")

    # Leading ``---`` ... ``---`` block; the closing fence may end the file.
    _FRONTMATTER_RE = re.compile(r"^---[ \t]*\n(.*?)\n---[ \t]*(?:\n|\Z)", re.DOTALL)
    # A list marker must be followed by whitespace, so ``---`` rules,
    # ``**bold**`` lead-ins and ``1.5x`` are not mistaken for list items.
    _BULLET_RE = re.compile(r"^(?:[-*+]|\d+[.)])\s+(.*)$")
    _TITLE_RE = re.compile(r"^#[ \t]+(.*?)$", re.MULTILINE)
    _TITLE_PREFIX_RE = re.compile(r"^(Skill|SOP|Playbook):\s*", re.IGNORECASE)

    @staticmethod
    def _frontmatter_value(fm_text: str, key: str) -> str:
        """Return the scalar stored under a top-level frontmatter *key*.

        Handles plain values, single- or double-quoted values, and block
        scalars (``>``, ``>-``, ``|`` ...) whose indented continuation lines
        are joined with single spaces. Returns ``""`` when the key is absent
        or has no value.
        """
        # ``[ \t]*`` (not ``\s*``) keeps an empty value from swallowing the
        # next key's line; continuation lines must be indented and non-blank.
        match = re.search(
            rf"^{re.escape(key)}:[ \t]*(.*)((?:\n[ \t]+\S.*)*)",
            fm_text,
            re.MULTILINE,
        )
        if match is None:
            return ""
        head = match.group(1).strip()
        # Drop YAML block-scalar indicators.
        head = re.sub(r"^>[+-]?\s*", "", head)
        if re.fullmatch(r"\|[+-]?", head):
            head = ""
        pieces = [head, *(ln.strip() for ln in match.group(2).splitlines())]
        value = " ".join(piece for piece in pieces if piece)
        # Unwrap a value that is quoted as a whole.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1].strip()
        return value

    @classmethod
    def _bullet_text(cls, line: str) -> str | None:
        """Return the text of a Markdown list item, or ``None`` if *line* is not one.

        *line* is expected to be stripped already.
        """
        match = cls._BULLET_RE.match(line)
        if match is None:
            return None
        return match.group(1).strip() or None

    @classmethod
    def _collect_sections(cls, body: str) -> Tuple[List[str], List[str]]:
        """Collect ``(objectives, triggers)`` list items from a Markdown body."""
        objectives: List[str] = []
        triggers: List[str] = []
        in_objectives = False
        in_triggers = False
        in_fence = False
        for raw_line in body.splitlines():
            line = raw_line.strip()
            if not line:
                continue
            # Fenced code is skipped entirely: a ``# comment`` inside it is
            # not a heading and its ``-`` lines are not list items.
            if line.startswith(("```", "~~~")):
                in_fence = not in_fence
                continue
            if in_fence:
                continue
            if line.startswith("#"):
                heading = line.lower()
                in_objectives = any(k in heading for k in cls._OBJECTIVE_KEYWORDS)
                in_triggers = any(k in heading for k in cls._TRIGGER_KEYWORDS)
                continue
            if not (in_objectives or in_triggers):
                continue
            item = cls._bullet_text(line)
            if item is None:
                continue
            if in_objectives:
                objectives.append(item)
            # Very short items and "do not ..." instructions are not usable
            # as example requests.
            if (
                in_triggers
                and len(item) > 5
                and not item.lower().startswith("do not")
            ):
                triggers.append(item)
        return objectives, triggers

    @staticmethod
    def parse_file(path: str | Path) -> PlaybookMetadata | None:
        """Parse a single SKILL.md file.

        Args:
            path: Location of the playbook file.

        Returns:
            The parsed metadata, or ``None`` (after logging a warning) when
            *path* is not an existing file.

        The skill name is resolved in this order: frontmatter ``name``, the
        parent directory name, the first ``# Title`` of the document (with a
        leading "Skill:", "SOP:" or "Playbook:" removed), and finally
        ``"unknown"``. When no trigger examples are found, two generic prompts
        built from the name are used instead.
        """
        path = Path(path)
        if not path.is_file():
            logger.warning(f"File not found: {path}")
            return None
        content = path.read_text(encoding="utf-8")

        # ``Path.parent`` is always truthy; its *name* is what can be empty
        # (e.g. for ``Path("SKILL.md")``), so test that instead.
        name = path.parent.name or "unknown"
        description = ""
        body = content

        frontmatter = PlaybookParser._FRONTMATTER_RE.match(content)
        if frontmatter:
            # Scan only the Markdown body below so YAML comments and lists in
            # the frontmatter are not read as headings or list items.
            body = content[frontmatter.end():]
            fm_text = frontmatter.group(1)
            name = PlaybookParser._frontmatter_value(fm_text, "name") or name
            description = PlaybookParser._frontmatter_value(fm_text, "description")

        if name == "unknown":
            title = PlaybookParser._TITLE_RE.search(body)
            if title:
                stripped_title = PlaybookParser._TITLE_PREFIX_RE.sub(
                    "", title.group(1).strip()
                )
                name = stripped_title or "unknown"

        objectives, triggers = PlaybookParser._collect_sections(body)

        # Fallback if no triggers found
        if not triggers:
            triggers = [
                f"Apply the {name} skill to handle this task.",
                f"How do I use {name} here?",
            ]
        return PlaybookMetadata(
            name=name,
            description=description or f"Playbook for {name}",
            objectives=objectives,
            trigger_examples=triggers,
            file_path=str(path.absolute()),
        )

    @classmethod
    def parse_directory(cls, dir_path: str | Path) -> List[PlaybookMetadata]:
        """Parse every playbook found under *dir_path*.

        ``SKILL.md`` files are discovered recursively; ``*_SKILL.md`` files are
        matched in *dir_path* itself only. Each group is sorted so the result
        order does not depend on filesystem enumeration order.

        Returns:
            The parsed playbooks, or an empty list when *dir_path* is not an
            existing directory.
        """
        dir_path = Path(dir_path)
        if not dir_path.is_dir():
            return []
        candidates = sorted(dir_path.glob("**/SKILL.md"))
        candidates += sorted(dir_path.glob("*_SKILL.md"))
        parsed = (cls.parse_file(candidate) for candidate in candidates)
        return [playbook for playbook in parsed if playbook is not None]
@dataclass
class RepresentationVector:
    """Per-layer steering directions associated with one skill identifier."""

    # Identifier used as the lookup key in the vector libraries.
    skill_id: str
    # Layer index -> direction tensor for that layer.
    # Presumably each tensor is 1-D of size hidden_dim — confirm against the producer.
    layer_vectors: Dict[int, torch.Tensor] = field(default_factory=dict)
    # Free-form descriptive information stored alongside the vectors.
    metadata: Dict[str, Any] = field(default_factory=dict)
class RepresentationVectorExtractor:
    """Extract steering vectors by contrasting skill-conditioned and neutral prompts.

    For every trigger example of a playbook, a "win" prompt (skill name and
    description in the instructions) and a "lose" prompt ("Respond normally.")
    are run through the model. Forward hooks capture the outputs of the
    selected layers; the per-layer direction is the unit-normalised mean of
    the differences between sequence-averaged win and lose activations.
    """

    # Module names ending in ``layer.<n>`` / ``layers.<n>`` are treated as
    # transformer blocks (e.g. ``gpt_neox.layers.3``).
    _LAYER_NAME_RE = re.compile(r".*layers?\.\d+$")

    def __init__(self, model: nn.Module, tokenizer: Any, device: str = "cpu") -> None:
        """Store the model, tokenizer and the device tokenised inputs are moved to."""
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def _find_transformer_layers(self) -> List[Tuple[str, nn.Module]]:
        """Return ``(name, module)`` for every submodule that looks like a layer."""
        return [
            (name, module)
            for name, module in self.model.named_modules()
            if self._LAYER_NAME_RE.match(name)
        ]

    @staticmethod
    def _build_minimal_pairs(
        skill_metadata: PlaybookMetadata,
    ) -> Tuple[List[str], List[str]]:
        """Build aligned ``(win_prompts, lose_prompts)`` from the trigger examples."""
        win_prompts: List[str] = []
        lose_prompts: List[str] = []
        for trigger in skill_metadata.trigger_examples:
            # Win prompt guides the model to invoke the playbook skill/format
            win_prompts.append(
                f"Instructions: You are acting with the following skill: {skill_metadata.name}. "
                f"Description: {skill_metadata.description}\n"
                f"Request: {trigger}\n"
                f"Output:"
            )
            # Lose prompt asks the model to respond normally
            lose_prompts.append(
                f"Instructions: Respond normally.\n" f"Request: {trigger}\n" f"Output:"
            )
        return win_prompts, lose_prompts

    def _collect_activations(
        self,
        prompts: List[str],
        layer_modules: Dict[int, nn.Module],
    ) -> Dict[int, List[torch.Tensor]]:
        """Run *prompts* through the model and record each hooked layer's output.

        Returns a map from layer index to one CPU tensor per prompt. Hooks are
        always removed, even when tokenisation or a forward pass raises, so a
        failed extraction cannot leave stale hooks on the model.
        """
        storage: Dict[int, List[torch.Tensor]] = {idx: [] for idx in layer_modules}

        def make_hook(layer_idx: int) -> Any:
            def hook_fn(
                module: nn.Module,
                input_t: Tuple[torch.Tensor, ...],
                output_t: Any,
            ) -> None:
                # Some layers return a tuple; the first element is taken as
                # the hidden states.
                if isinstance(output_t, tuple):
                    output_t = output_t[0]
                # Detach and move to CPU to save GPU memory
                storage[layer_idx].append(output_t.detach().cpu())

            return hook_fn

        handles = []
        try:
            for idx, module in layer_modules.items():
                handles.append(module.register_forward_hook(make_hook(idx)))
            with torch.no_grad():
                for prompt in prompts:
                    inputs = self.tokenizer(prompt, return_tensors="pt").to(
                        self.device
                    )
                    self.model(**inputs)
        finally:
            for handle in handles:
                handle.remove()
        return storage

    @staticmethod
    def _steering_direction(
        win_tensors: List[torch.Tensor],
        lose_tensors: List[torch.Tensor],
    ) -> torch.Tensor:
        """Return the normalised mean win-minus-lose difference for one layer."""
        # Each tensor is assumed to be [1, seq_len, hidden_dim] (one prompt per
        # forward pass); averaging over dim 1 makes prompts of different
        # lengths comparable.
        diffs = [
            win_t.mean(dim=1).squeeze(0) - lose_t.mean(dim=1).squeeze(0)
            for win_t, lose_t in zip(win_tensors, lose_tensors)
        ]
        mean_diff = torch.stack(diffs).mean(dim=0)
        # Normalize vector to unit norm for consistent steering scales
        norm = torch.norm(mean_diff)
        if norm > 1e-8:
            mean_diff = mean_diff / norm
        return mean_diff

    def extract_steering_vector(
        self,
        skill_metadata: PlaybookMetadata,
        layers_to_extract: List[int] | None = None,
    ) -> RepresentationVector:
        """Compute per-layer steering directions for one playbook.

        Args:
            skill_metadata: Playbook whose name, description and trigger
                examples are used to build the prompt pairs.
            layers_to_extract: Indices into the discovered layer list. Defaults
                to the second half of the layers. Indices outside the
                discovered range are skipped with a warning.

        Returns:
            A ``RepresentationVector`` whose ``skill_id`` is the lower-cased
            name with spaces replaced by underscores. ``layer_vectors`` holds
            one entry per layer for which activations were captured, and may be
            empty (no layers discovered, or no trigger examples).

        Side effects:
            Puts the model into eval mode.
        """
        self.model.eval()

        transformer_layers = self._find_transformer_layers()
        num_layers = len(transformer_layers)
        if not transformer_layers:
            logger.warning(
                "Could not automatically resolve transformer layers; "
                "no activations will be extracted"
            )
        if layers_to_extract is None:
            # Default: extract from middle/late layers (last half of the network)
            layers_to_extract = list(range(num_layers // 2, num_layers))

        layer_modules: Dict[int, nn.Module] = {}
        for idx in layers_to_extract:
            if -num_layers <= idx < num_layers:
                layer_modules[idx] = transformer_layers[idx][1]
            else:
                logger.warning(
                    "Skipping layer index %d: only %d layers were discovered",
                    idx,
                    num_layers,
                )

        win_prompts, lose_prompts = self._build_minimal_pairs(skill_metadata)
        win_activations = self._collect_activations(win_prompts, layer_modules)
        lose_activations = self._collect_activations(lose_prompts, layer_modules)

        layer_vectors: Dict[int, torch.Tensor] = {}
        for idx in layer_modules:
            win_tensors = win_activations[idx]
            lose_tensors = lose_activations[idx]
            if not win_tensors or not lose_tensors:
                continue
            layer_vectors[idx] = self._steering_direction(win_tensors, lose_tensors)

        metadata = {
            "name": skill_metadata.name,
            "description": skill_metadata.description,
            "trigger_count": len(skill_metadata.trigger_examples),
        }
        logger.info(
            f"Extracted representation vector for '{skill_metadata.name}' | {len(layer_vectors)} layers"
        )
        return RepresentationVector(
            skill_id=skill_metadata.name.lower().replace(" ", "_"),
            layer_vectors=layer_vectors,
            metadata=metadata,
        )
class SkillVectorLibrary:
    """In-memory collection of skill steering vectors, keyed by ``skill_id``."""

    def __init__(self) -> None:
        self.vectors: Dict[str, RepresentationVector] = {}

    def add_vector(self, vec: RepresentationVector) -> None:
        """Store *vec* under its ``skill_id``, replacing any existing entry."""
        self.vectors[vec.skill_id] = vec

    def get_vector(self, skill_id: str) -> RepresentationVector | None:
        """Return the vector stored for *skill_id*, or ``None`` if absent."""
        return self.vectors.get(skill_id)

    def save(self, path: str | Path) -> None:
        """Serialise the library to *path* with ``torch.save``.

        Tensors are moved to CPU first so the file can be loaded on machines
        without the original device.
        """
        state = {
            skill_id: {
                "skill_id": vec.skill_id,
                "layer_vectors": {k: v.cpu() for k, v in vec.layer_vectors.items()},
                "metadata": vec.metadata,
            }
            for skill_id, vec in self.vectors.items()
        }
        torch.save(state, path)
        logger.info(
            f"Skill Vector Library saved to {path} ({len(self.vectors)} skills)"
        )

    def load(self, path: str | Path) -> None:
        """Replace the library contents with the vectors stored at *path*.

        The file is read and decoded completely before the current contents
        are touched, so a missing or malformed file leaves the library
        unchanged.

        Raises:
            Whatever ``torch.load`` raises for an unreadable file, and
            ``KeyError`` when an entry lacks an expected field.
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only load libraries from
        # trusted sources.
        state = torch.load(path, map_location="cpu", weights_only=False)
        loaded = {
            skill_id: RepresentationVector(
                skill_id=vec_state["skill_id"],
                layer_vectors=vec_state["layer_vectors"],
                metadata=vec_state["metadata"],
            )
            for skill_id, vec_state in state.items()
        }
        # Swap contents in place so external references to ``self.vectors``
        # keep seeing the current library.
        self.vectors.clear()
        self.vectors.update(loaded)
        logger.info(
            f"Skill Vector Library loaded from {path} ({len(self.vectors)} skills)"
        )
class ProcessVectorLibrary:
    """In-memory collection of processes, each an ordered list of step vectors."""

    def __init__(self) -> None:
        self.processes: Dict[str, List[RepresentationVector]] = {}

    def add_process(self, process_id: str, steps: List[RepresentationVector]) -> None:
        """Store *steps* under *process_id*, replacing any existing entry."""
        self.processes[process_id] = steps

    def get_process_step(
        self, process_id: str, step_idx: int
    ) -> RepresentationVector | None:
        """Return step *step_idx* of *process_id*.

        Returns ``None`` when the process is unknown or the index is outside
        ``0 .. len(steps) - 1``; negative indices are not supported.
        """
        steps = self.processes.get(process_id)
        if steps and 0 <= step_idx < len(steps):
            return steps[step_idx]
        return None

    def save(self, path: str | Path) -> None:
        """Serialise the library to *path* with ``torch.save``.

        Tensors are moved to CPU first so the file can be loaded on machines
        without the original device.
        """
        state = {
            p_id: [
                {
                    "skill_id": vec.skill_id,
                    "layer_vectors": {k: v.cpu() for k, v in vec.layer_vectors.items()},
                    "metadata": vec.metadata,
                }
                for vec in steps
            ]
            for p_id, steps in self.processes.items()
        }
        torch.save(state, path)
        logger.info(
            f"Process Vector Library saved to {path} ({len(self.processes)} processes)"
        )

    def load(self, path: str | Path) -> None:
        """Replace the library contents with the processes stored at *path*.

        The file is read and decoded completely before the current contents
        are touched, so a missing or malformed file leaves the library
        unchanged.

        Raises:
            Whatever ``torch.load`` raises for an unreadable file, and
            ``KeyError`` when a step lacks an expected field.
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only load libraries from
        # trusted sources.
        state = torch.load(path, map_location="cpu", weights_only=False)
        loaded = {
            p_id: [
                RepresentationVector(
                    skill_id=vec_state["skill_id"],
                    layer_vectors=vec_state["layer_vectors"],
                    metadata=vec_state["metadata"],
                )
                for vec_state in steps_state
            ]
            for p_id, steps_state in state.items()
        }
        # Swap contents in place so external references to ``self.processes``
        # keep seeing the current library.
        self.processes.clear()
        self.processes.update(loaded)
        logger.info(
            f"Process Vector Library loaded from {path} ({len(self.processes)} processes)"
        )