Fix mypy and ruff formatting in representation_engineering.py
This commit is contained in:
413
python/representation_engineering.py
Normal file
413
python/representation_engineering.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""Representation Engineering and Vector Extraction from Playbooks.
|
||||
|
||||
Parses playbooks (SKILL.md), extracts minimal pairs, and computes steerable
|
||||
activation vectors or QLoRA adapters to build Skill and Process Vector Libraries.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Module-level logger used by the parser, extractor and library classes below.
logger = logging.getLogger("representation_engineering")
# NOTE(review): basicConfig runs at import time and configures the root logger
# at INFO level if the application has not configured logging yet. Confirm
# this import-time side effect is intended for library use.
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
@dataclass
class PlaybookMetadata:
    """Structured record describing one SKILL.md playbook.

    Attributes:
        name: Human-readable playbook / skill name.
        description: Short description of the playbook.
        objectives: Objective items collected from the playbook; empty by
            default.
        trigger_examples: Example requests for the playbook; empty by
            default.
        file_path: Location of the source file as a string; empty when not
            set.
    """

    # Required fields come first so they can be passed positionally.
    name: str
    description: str

    # Each instance gets its own fresh list (never a shared default).
    objectives: List[str] = field(default_factory=list)
    trigger_examples: List[str] = field(default_factory=list)

    file_path: str = ""
|
||||
|
||||
|
||||
class PlaybookParser:
    """Parses SKILL.md files to extract structured metadata and trigger examples.

    The parser is deliberately lightweight: frontmatter and Markdown sections
    are recognised with regular expressions rather than a YAML/Markdown
    library, so only simple ``key: value`` frontmatter and bullet lists are
    understood.
    """

    # Optional frontmatter: a block fenced by ``---`` lines at the very start
    # of the file. The closing fence may also be the last line of the file.
    _FRONTMATTER_RE = re.compile(r"^---\s*\n(.*?)\n---\s*(?:\n|$)", re.DOTALL)
    # ``[ \t]*`` (not ``\s*``) so an empty value never swallows the next key.
    _NAME_RE = re.compile(r"^name:[ \t]*(.*)$", re.MULTILINE)
    # The value plus any indented continuation lines (folded/multi-line scalars).
    _DESCRIPTION_RE = re.compile(
        r"^description:[ \t]*(.*(?:\n[ \t]+.*)*)", re.MULTILINE
    )
    # Leading YAML block-scalar indicator such as ``>``, ``>-`` or ``|``.
    _BLOCK_INDICATOR_RE = re.compile(r"^[>|][+-]?\s*")
    _TITLE_RE = re.compile(r"^#[ \t]+(.*?)$", re.MULTILINE)
    _TITLE_PREFIX_RE = re.compile(r"^(Skill|SOP|Playbook):\s*", re.IGNORECASE)
    # A list item starts with ``-``, ``*`` or ``<digits>.``.
    _BULLET_RE = re.compile(r"^(?:[-*]|\d+\.)")
    _BULLET_MARKER_RE = re.compile(r"^[-*\d\.]+\s*")
    # Case-insensitive heading keywords (English and German).
    _OBJECTIVE_HEADINGS = ("objective", "ziel")
    _TRIGGER_HEADINGS = ("trigger", "when to use", "examples", "beispiele")

    @staticmethod
    def _unquote(value: str) -> str:
        """Strip one pair of matching surrounding quotes from a YAML scalar."""
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            return value[1:-1]
        return value

    @staticmethod
    def parse_file(path: str | Path) -> PlaybookMetadata | None:
        """Parse a single SKILL.md file.

        Args:
            path: Location of the playbook file.

        Returns:
            The parsed metadata, or ``None`` (after logging a warning) when
            ``path`` is not an existing regular file.

        The name is taken from the frontmatter ``name:`` key, else from the
        parent directory name, else from the first ``# `` heading (with a
        leading ``Skill:``/``SOP:``/``Playbook:`` removed), else "unknown".
        When no usable trigger bullet is found, two generic trigger prompts
        are synthesised from the name.
        """
        path = Path(path)
        # is_file() also rejects directories, which exists() would let through
        # only to fail later in read_text().
        if not path.is_file():
            logger.warning("File not found: %s", path)
            return None

        # utf-8-sig transparently drops a BOM that would otherwise hide the
        # leading ``---`` fence from the frontmatter pattern.
        content = path.read_text(encoding="utf-8-sig")

        # A bare relative path ("SKILL.md") has an empty parent name; fall
        # back to the sentinel so the heading-based fallback below can run.
        name = path.parent.name or "unknown"
        description = ""
        body = content

        frontmatter = PlaybookParser._FRONTMATTER_RE.match(content)
        if frontmatter:
            fm_text = frontmatter.group(1)
            # Only the Markdown after the frontmatter is scanned for headings.
            body = content[frontmatter.end():]

            name_match = PlaybookParser._NAME_RE.search(fm_text)
            if name_match:
                fm_name = PlaybookParser._unquote(name_match.group(1).strip())
                if fm_name:
                    name = fm_name

            desc_match = PlaybookParser._DESCRIPTION_RE.search(fm_text)
            if desc_match:
                raw = desc_match.group(1).strip()
                raw = PlaybookParser._BLOCK_INDICATOR_RE.sub("", raw)
                # Fold continuation lines into a single line.
                raw = re.sub(r"\s*\n\s*", " ", raw).strip()
                description = PlaybookParser._unquote(raw)

        # Use the first top-level heading when no better name is known.
        if name == "unknown":
            title_match = PlaybookParser._TITLE_RE.search(body)
            if title_match:
                title = title_match.group(1).strip()
                name = PlaybookParser._TITLE_PREFIX_RE.sub("", title) or name

        objectives: List[str] = []
        triggers: List[str] = []
        in_objectives = False
        in_triggers = False

        for raw_line in body.splitlines():
            line_str = raw_line.strip()
            if not line_str:
                continue

            # Every heading re-evaluates which section we are in.
            if line_str.startswith("#"):
                heading = line_str.lower()
                in_objectives = any(
                    key in heading for key in PlaybookParser._OBJECTIVE_HEADINGS
                )
                in_triggers = any(
                    key in heading for key in PlaybookParser._TRIGGER_HEADINGS
                )
                continue

            if not (in_objectives or in_triggers):
                continue
            if not PlaybookParser._BULLET_RE.match(line_str):
                continue

            item = PlaybookParser._BULLET_MARKER_RE.sub("", line_str)

            # Horizontal rules ("---") and bare markers reduce to "" and must
            # not be recorded as objectives.
            if in_objectives and item:
                objectives.append(item)

            # Filter out very short items and generic "do not ..." instructions.
            if (
                in_triggers
                and len(item) > 5
                and not item.lower().startswith("do not")
            ):
                triggers.append(item)

        # Fallback if no triggers found
        if not triggers:
            triggers = [
                f"Apply the {name} skill to handle this task.",
                f"How do I use {name} here?",
            ]

        return PlaybookMetadata(
            name=name,
            description=description or f"Playbook for {name}",
            objectives=objectives,
            trigger_examples=triggers,
            file_path=str(path.absolute()),
        )

    @classmethod
    def parse_directory(cls, dir_path: str | Path) -> List[PlaybookMetadata]:
        """Parse every playbook found under a directory.

        Files named ``SKILL.md`` are searched recursively; files matching
        ``*_SKILL.md`` are searched in ``dir_path`` itself only.

        Args:
            dir_path: Directory to scan.

        Returns:
            Parsed playbooks, ``SKILL.md`` matches first and then
            ``*_SKILL.md`` matches, each group sorted by path so the result
            does not depend on filesystem ordering. Empty when ``dir_path``
            is not a directory.
        """
        root = Path(dir_path)
        if not root.is_dir():
            return []

        candidates = sorted(root.glob("**/SKILL.md"))
        candidates += sorted(root.glob("*_SKILL.md"))

        parsed = (cls.parse_file(candidate) for candidate in candidates)
        return [playbook for playbook in parsed if playbook is not None]
|
||||
|
||||
|
||||
@dataclass
class RepresentationVector:
    """Steering direction(s) for one skill, keyed by layer.

    Attributes:
        skill_id: Identifier of the skill the vectors belong to.
        layer_vectors: Mapping from layer index to that layer's difference
            activation tensor (presumably shaped ``[hidden_dim]``; confirm
            against the producer). Empty by default.
        metadata: Free-form extra information about the vector. Empty by
            default.
    """

    skill_id: str

    # Both mappings use a factory so instances never share a default dict.
    layer_vectors: Dict[int, torch.Tensor] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class RepresentationVectorExtractor:
    """Extracts representation (steering) vectors from model activations using minimal pairs.

    For every trigger example a "win" prompt (skill name and description
    injected) and a "lose" prompt (plain instructions) are run through the
    model. The per-layer difference of the mean hidden states, averaged over
    all pairs and scaled to unit norm, becomes the steering vector.
    """

    # Module names ending in "layer.<n>" or "layers.<n>" are treated as
    # transformer blocks (e.g. Pythia keeps them in model.gpt_neox.layers).
    _LAYER_NAME_RE = re.compile(r".*layers?\.\d+$")

    def __init__(self, model: nn.Module, tokenizer: Any, device: str = "cpu") -> None:
        """Store the model, its tokenizer and the device tokenized inputs are moved to."""
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def extract_steering_vector(
        self,
        skill_metadata: PlaybookMetadata,
        layers_to_extract: List[int] | None = None,
    ) -> RepresentationVector:
        """Compute the difference in hidden states for win vs lose prompts.

        Args:
            skill_metadata: Playbook whose name, description and trigger
                examples are used to build the prompt pairs.
            layers_to_extract: Indices into the discovered transformer layers.
                Defaults to the second half of the network. Indices outside
                the discovered range are skipped with a warning.

        Returns:
            A ``RepresentationVector`` holding one unit-norm vector per layer
            for which activations were captured. The mapping is empty when no
            layers were found or the playbook has no trigger examples.

        Side effects:
            Puts the model into eval mode. Forward hooks are registered only
            for the duration of the forward passes and are always removed,
            even if a forward pass raises.
        """
        self.model.eval()

        transformer_layers = self._find_transformer_layers()
        if not transformer_layers:
            # Fallback/guess for standard PyTorch modules
            logger.warning(
                "Could not automatically resolve transformer layers, will attempt default Hook paths"
            )

        num_layers = len(transformer_layers)
        if layers_to_extract is None:
            # Default: extract from middle/late layers (last half of the network)
            layers_to_extract = list(range(num_layers // 2, num_layers))

        # Drop duplicates (keeping order) so a layer is never hooked twice.
        layer_indices = list(dict.fromkeys(layers_to_extract))

        win_prompts, lose_prompts = self._build_minimal_pairs(skill_metadata)

        win_activations = self._collect_activations(
            win_prompts, transformer_layers, layer_indices
        )
        lose_activations = self._collect_activations(
            lose_prompts, transformer_layers, layer_indices
        )

        layer_vectors: Dict[int, torch.Tensor] = {}
        for idx in layer_indices:
            direction = self._mean_difference(
                win_activations[idx], lose_activations[idx]
            )
            if direction is not None:
                layer_vectors[idx] = direction

        metadata = {
            "name": skill_metadata.name,
            "description": skill_metadata.description,
            "trigger_count": len(skill_metadata.trigger_examples),
        }

        logger.info(
            "Extracted representation vector for '%s' | %d layers",
            skill_metadata.name,
            len(layer_vectors),
        )
        return RepresentationVector(
            skill_id=skill_metadata.name.lower().replace(" ", "_"),
            layer_vectors=layer_vectors,
            metadata=metadata,
        )

    def _find_transformer_layers(self) -> List[Tuple[str, nn.Module]]:
        """Return (name, module) pairs whose qualified name looks like a layer."""
        return [
            (name, module)
            for name, module in self.model.named_modules()
            if self._LAYER_NAME_RE.match(name)
        ]

    @staticmethod
    def _build_minimal_pairs(
        skill_metadata: PlaybookMetadata,
    ) -> Tuple[List[str], List[str]]:
        """Build aligned win/lose prompt lists, one pair per trigger example."""
        win_prompts: List[str] = []
        lose_prompts: List[str] = []
        for trigger in skill_metadata.trigger_examples:
            # Win prompt guides the model to invoke the playbook skill/format
            win_prompts.append(
                f"Instructions: You are acting with the following skill: {skill_metadata.name}. "
                f"Description: {skill_metadata.description}\n"
                f"Request: {trigger}\n"
                f"Output:"
            )
            # Lose prompt asks the model to respond normally
            lose_prompts.append(
                f"Instructions: Respond normally.\n"
                f"Request: {trigger}\n"
                f"Output:"
            )
        return win_prompts, lose_prompts

    def _collect_activations(
        self,
        prompts: List[str],
        transformer_layers: List[Tuple[str, nn.Module]],
        layer_indices: List[int],
    ) -> Dict[int, List[torch.Tensor]]:
        """Run ``prompts`` through the model and record the hooked layer outputs.

        Returns a mapping from each requested layer index to one CPU tensor
        per prompt. Indices that do not address a discovered layer map to an
        empty list.
        """
        storage: Dict[int, List[torch.Tensor]] = {idx: [] for idx in layer_indices}
        num_layers = len(transformer_layers)

        def make_hook(layer_idx: int) -> Any:
            # Factory so each hook captures its own layer index.
            def hook_fn(
                module: nn.Module,
                input_t: Tuple[torch.Tensor, ...],
                output_t: Any,
            ) -> None:
                # Layers may return the hidden state alone or as the first
                # element of a tuple.
                hidden = output_t[0] if isinstance(output_t, tuple) else output_t
                # Detach and move to CPU to save GPU memory
                storage[layer_idx].append(hidden.detach().cpu())

            return hook_fn

        handles = []
        try:
            for idx in layer_indices:
                if not -num_layers <= idx < num_layers:
                    logger.warning(
                        "Layer index %d is out of range for %d layers; skipping",
                        idx,
                        num_layers,
                    )
                    continue
                _, layer_module = transformer_layers[idx]
                handles.append(layer_module.register_forward_hook(make_hook(idx)))

            with torch.no_grad():
                for prompt in prompts:
                    inputs = self.tokenizer(prompt, return_tensors="pt").to(
                        self.device
                    )
                    self.model(**inputs)
        finally:
            # Always detach the hooks so a failing forward pass cannot leave
            # them attached to the caller's model.
            for handle in handles:
                handle.remove()

        return storage

    @staticmethod
    def _mean_difference(
        win_tensors: List[torch.Tensor],
        lose_tensors: List[torch.Tensor],
    ) -> torch.Tensor | None:
        """Return the unit-norm mean of (win - lose) over all prompt pairs.

        Each activation is averaged over dim 1 and squeezed on dim 0, i.e. it
        is assumed to be ``[1, seq_len, hidden_dim]``. Returns ``None`` when
        either side has no activations.
        """
        if not win_tensors or not lose_tensors:
            return None

        # Average over the sequence length for stability (rather than using
        # only the last token).
        diffs = [
            win_t.mean(dim=1).squeeze(0) - lose_t.mean(dim=1).squeeze(0)
            for win_t, lose_t in zip(win_tensors, lose_tensors)
        ]

        # Average difference vector across all minimal pairs
        mean_diff = torch.stack(diffs).mean(dim=0)

        # Normalize to unit norm for consistent steering scales; a (near-)zero
        # vector is left untouched to avoid dividing by ~0.
        norm = torch.norm(mean_diff)
        if norm > 1e-8:
            mean_diff = mean_diff / norm
        return mean_diff
|
||||
|
||||
|
||||
class SkillVectorLibrary:
    """Library containing extracted skill representation vectors.

    Vectors are stored in ``self.vectors`` keyed by their ``skill_id``.
    """

    def __init__(self) -> None:
        """Create an empty library."""
        self.vectors: Dict[str, RepresentationVector] = {}

    def add_vector(self, vec: RepresentationVector) -> None:
        """Store ``vec`` under its ``skill_id``, replacing any previous entry."""
        self.vectors[vec.skill_id] = vec

    def get_vector(self, skill_id: str) -> RepresentationVector | None:
        """Return the vector stored for ``skill_id``, or ``None`` if absent."""
        return self.vectors.get(skill_id)

    def save(self, path: str | Path) -> None:
        """Save the library to disk with ``torch.save``.

        Each entry is written as a plain dict (``skill_id``, ``layer_vectors``
        and ``metadata``), with tensors moved to CPU first.
        """
        state = {
            skill_id: {
                "skill_id": vec.skill_id,
                "layer_vectors": {
                    layer: tensor.cpu()
                    for layer, tensor in vec.layer_vectors.items()
                },
                "metadata": vec.metadata,
            }
            for skill_id, vec in self.vectors.items()
        }
        torch.save(state, path)
        logger.info(
            "Skill Vector Library saved to %s (%d skills)", path, len(self.vectors)
        )

    def load(self, path: str | Path) -> None:
        """Load the library from disk, replacing the current contents.

        The current contents are replaced only after the file has been read
        and every entry rebuilt, so a missing or malformed file raises
        without wiping the in-memory library.

        Raises:
            Whatever ``torch.load`` raises for an unreadable file, or
            ``KeyError`` when an entry lacks an expected field.
        """
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects
        # and can execute code from the file. Only load libraries from trusted
        # sources.
        state = torch.load(path, map_location="cpu", weights_only=False)

        loaded = {
            skill_id: RepresentationVector(
                skill_id=vec_state["skill_id"],
                layer_vectors=vec_state["layer_vectors"],
                metadata=vec_state["metadata"],
            )
            for skill_id, vec_state in state.items()
        }

        # Mutate in place so existing references to self.vectors stay valid.
        self.vectors.clear()
        self.vectors.update(loaded)
        logger.info(
            "Skill Vector Library loaded from %s (%d skills)",
            path,
            len(self.vectors),
        )
|
||||
|
||||
|
||||
class ProcessVectorLibrary:
    """Library containing sequential process step representation vectors.

    Each process id maps to an ordered list of step vectors in
    ``self.processes``.
    """

    def __init__(self) -> None:
        """Create an empty library."""
        self.processes: Dict[str, List[RepresentationVector]] = {}

    def add_process(self, process_id: str, steps: List[RepresentationVector]) -> None:
        """Store ``steps`` under ``process_id``, replacing any previous entry."""
        self.processes[process_id] = steps

    def get_process_step(
        self, process_id: str, step_idx: int
    ) -> RepresentationVector | None:
        """Return step ``step_idx`` of a process.

        Returns ``None`` when the process is unknown or ``step_idx`` is
        outside ``0 <= step_idx < len(steps)``. Negative indices are not
        accepted.
        """
        steps = self.processes.get(process_id)
        if steps and 0 <= step_idx < len(steps):
            return steps[step_idx]
        return None

    def save(self, path: str | Path) -> None:
        """Save the library to disk with ``torch.save``.

        Each step is written as a plain dict (``skill_id``, ``layer_vectors``
        and ``metadata``), with tensors moved to CPU first.
        """
        state = {
            p_id: [
                {
                    "skill_id": vec.skill_id,
                    "layer_vectors": {
                        layer: tensor.cpu()
                        for layer, tensor in vec.layer_vectors.items()
                    },
                    "metadata": vec.metadata,
                }
                for vec in steps
            ]
            for p_id, steps in self.processes.items()
        }
        torch.save(state, path)
        logger.info(
            "Process Vector Library saved to %s (%d processes)",
            path,
            len(self.processes),
        )

    def load(self, path: str | Path) -> None:
        """Load the library from disk, replacing the current contents.

        The current contents are replaced only after the file has been read
        and every step rebuilt, so a missing or malformed file raises without
        wiping the in-memory library.

        Raises:
            Whatever ``torch.load`` raises for an unreadable file, or
            ``KeyError`` when a step lacks an expected field.
        """
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects
        # and can execute code from the file. Only load libraries from trusted
        # sources.
        state = torch.load(path, map_location="cpu", weights_only=False)

        loaded = {
            p_id: [
                RepresentationVector(
                    skill_id=vec_state["skill_id"],
                    layer_vectors=vec_state["layer_vectors"],
                    metadata=vec_state["metadata"],
                )
                for vec_state in steps_state
            ]
            for p_id, steps_state in state.items()
        }

        # Mutate in place so existing references to self.processes stay valid.
        self.processes.clear()
        self.processes.update(loaded)
        logger.info(
            "Process Vector Library loaded from %s (%d processes)",
            path,
            len(self.processes),
        )
|
||||
Reference in New Issue
Block a user