Fix mypy and ruff formatting in representation_engineering.py

This commit is contained in:
AI-anonymous
2026-05-23 00:09:51 +02:00
parent 306372bb5b
commit c6ba37dc39

View File

@@ -0,0 +1,413 @@
"""Representation Engineering and Vector Extraction from Playbooks.
Parses playbooks (SKILL.md), extracts minimal pairs, and computes steerable
activation vectors or QLoRA adapters to build Skill and Process Vector Libraries.
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Tuple
import torch
import torch.nn as nn
logger = logging.getLogger("representation_engineering")
logging.basicConfig(level=logging.INFO)
@dataclass
class PlaybookMetadata:
"""Metadata and examples parsed from a SKILL.md playbook."""
name: str
description: str
objectives: List[str] = field(default_factory=list)
trigger_examples: List[str] = field(default_factory=list)
file_path: str = ""
class PlaybookParser:
"""Parses SKILL.md files to extract structured metadata and trigger examples."""
@staticmethod
def parse_file(path: str | Path) -> PlaybookMetadata | None:
"""Parses a single SKILL.md file."""
path = Path(path)
if not path.exists():
logger.warning(f"File not found: {path}")
return None
content = path.read_text(encoding="utf-8")
# Parse frontmatter if present
name = path.parent.name if path.parent else "unknown"
description = ""
frontmatter_match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL)
if frontmatter_match:
fm_text = frontmatter_match.group(1)
name_match = re.search(r"^name:\s*(.*?)$", fm_text, re.MULTILINE)
if name_match:
name = name_match.group(1).strip()
desc_match = re.search(r"^description:\s*(.*?)$", fm_text, re.MULTILINE)
if desc_match:
description = desc_match.group(1).strip()
# Remove multiline YAML indicators if any
description = re.sub(r"^>-?\s*", "", description)
description = re.sub(r"\n\s+", " ", description)
# Parse headers if no name/description in frontmatter
if name == "unknown" or not description:
title_match = re.search(r"^#\s+(.*?)$", content, re.MULTILINE)
if title_match and name == "unknown":
name = title_match.group(1).strip()
# Remove prefixes like "Skill:" or "SOP:"
name = re.sub(
r"^(Skill|SOP|Playbook):\s*", "", name, flags=re.IGNORECASE
)
# Objectives and triggers
objectives: List[str] = []
triggers: List[str] = []
# Simple regex searches for bullet points
lines = content.splitlines()
in_objectives = False
in_triggers = False
for line in lines:
line_str = line.strip()
if not line_str:
continue
# Section detection
if line_str.startswith("#"):
in_objectives = (
"objective" in line_str.lower() or "ziel" in line_str.lower()
)
in_triggers = any(
x in line_str.lower()
for x in ["trigger", "when to use", "examples", "beispiele"]
)
continue
if in_objectives and (
line_str.startswith("-")
or line_str.startswith("*")
or re.match(r"^\d+\.", line_str)
):
clean_line = re.sub(r"^[-*\d\.]+\s*", "", line_str)
objectives.append(clean_line)
if in_triggers and (
line_str.startswith("-")
or line_str.startswith("*")
or re.match(r"^\d+\.", line_str)
):
clean_line = re.sub(r"^[-*\d\.]+\s*", "", line_str)
# Filter out generic instructions
if len(clean_line) > 5 and not clean_line.lower().startswith("do not"):
triggers.append(clean_line)
# Fallback if no triggers found
if not triggers:
triggers = [
f"Apply the {name} skill to handle this task.",
f"How do I use {name} here?",
]
return PlaybookMetadata(
name=name,
description=description or f"Playbook for {name}",
objectives=objectives,
trigger_examples=triggers,
file_path=str(path.absolute()),
)
@classmethod
def parse_directory(cls, dir_path: str | Path) -> List[PlaybookMetadata]:
"""Scans a directory for SKILL.md or matches files and parses them."""
dir_path = Path(dir_path)
playbooks: List[PlaybookMetadata] = []
if not dir_path.exists():
return playbooks
# Match any SKILL.md or *_SKILL.md
for path in dir_path.glob("**/SKILL.md"):
pb = cls.parse_file(path)
if pb:
playbooks.append(pb)
for path in dir_path.glob("*_SKILL.md"):
pb = cls.parse_file(path)
if pb:
playbooks.append(pb)
return playbooks
@dataclass
class RepresentationVector:
"""A steering direction vector or set of vectors for representation engineering."""
skill_id: str
# Map from layer index (e.g. 0 to L-1) to difference activation tensor [hidden_dim]
layer_vectors: Dict[int, torch.Tensor] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)
class RepresentationVectorExtractor:
"""Extracts representation (steering) vectors from model activations using minimal pairs."""
def __init__(self, model: nn.Module, tokenizer: Any, device: str = "cpu") -> None:
self.model = model
self.tokenizer = tokenizer
self.device = device
def extract_steering_vector(
self,
skill_metadata: PlaybookMetadata,
layers_to_extract: List[int] | None = None,
) -> RepresentationVector:
"""Computes the difference in hidden states for win vs lose prompts."""
self.model.eval()
# Determine layers to hook
# Pythia models store layers in model.gpt_neox.layers
# Let's inspect model layout dynamically
transformer_layers = []
for name, module in self.model.named_modules():
if re.match(r".*layers?\.\d+$", name):
transformer_layers.append((name, module))
if not transformer_layers:
# Fallback/guess for standard PyTorch modules
logger.warning(
"Could not automatically resolve transformer layers, will attempt default Hook paths"
)
num_layers = len(transformer_layers)
if layers_to_extract is None:
# Default: extract from middle/late layers (e.g., last half of the network)
layers_to_extract = list(range(num_layers // 2, num_layers))
# Generate minimal pairs from triggers
win_prompts = []
lose_prompts = []
for trigger in skill_metadata.trigger_examples:
# Win prompt guides the model to invoke the playbook skill/format
win_prompts.append(
f"Instructions: You are acting with the following skill: {skill_metadata.name}. "
f"Description: {skill_metadata.description}\n"
f"Request: {trigger}\n"
f"Output:"
)
# Lose prompt asks the model to respond normally
lose_prompts.append(
f"Instructions: Respond normally.\n" f"Request: {trigger}\n" f"Output:"
)
# Temporary storage for hooked activations
# Map: layer_idx -> list of tensors [seq_len, hidden_dim]
win_activations: Dict[int, List[torch.Tensor]] = {
idx: [] for idx in layers_to_extract
}
lose_activations: Dict[int, List[torch.Tensor]] = {
idx: [] for idx in layers_to_extract
}
# Hook function builder
def make_hook(layer_idx: int, storage: Dict[int, List[torch.Tensor]]) -> Any:
def hook_fn(
module: nn.Module,
input_t: Tuple[torch.Tensor, ...],
output_t: torch.Tensor,
) -> None:
# output_t is typically [batch, seq_len, hidden_dim] or a tuple
if isinstance(output_t, tuple):
output_t = output_t[0]
# Detach and move to CPU to save GPU memory
storage[layer_idx].append(output_t.detach().cpu())
return hook_fn
# Register hooks
hooks = []
for idx in layers_to_extract:
if idx < len(transformer_layers):
_, layer_module = transformer_layers[idx]
h = layer_module.register_forward_hook(make_hook(idx, win_activations))
hooks.append(h)
# Run forward pass for win prompts
for prompt in win_prompts:
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.no_grad():
self.model(**inputs)
# Remove hooks and register them for lose activations
for h in hooks:
h.remove()
hooks.clear()
for idx in layers_to_extract:
if idx < len(transformer_layers):
_, layer_module = transformer_layers[idx]
h = layer_module.register_forward_hook(make_hook(idx, lose_activations))
hooks.append(h)
# Run forward pass for lose prompts
for prompt in lose_prompts:
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.no_grad():
self.model(**inputs)
# Remove hooks
for h in hooks:
h.remove()
# Compute difference vectors
layer_vectors: Dict[int, torch.Tensor] = {}
for idx in layers_to_extract:
win_tensors = win_activations[idx]
lose_tensors = lose_activations[idx]
if not win_tensors or not lose_tensors:
continue
diffs = []
for win_t, lose_t in zip(win_tensors, lose_tensors):
# win_t and lose_t are [1, seq_len, hidden_dim]
# We can average over sequence length or take the last token (representation at decision point)
# Let's average over the sequence length for stability
w_mean = win_t.mean(dim=1).squeeze(0) # [hidden_dim]
l_mean = lose_t.mean(dim=1).squeeze(0) # [hidden_dim]
diffs.append(w_mean - l_mean)
# Average difference vector across all minimal pairs
mean_diff = torch.stack(diffs).mean(dim=0)
# Normalize vector to unit norm for consistent steering scales
norm = torch.norm(mean_diff)
if norm > 1e-8:
mean_diff = mean_diff / norm
layer_vectors[idx] = mean_diff
metadata = {
"name": skill_metadata.name,
"description": skill_metadata.description,
"trigger_count": len(skill_metadata.trigger_examples),
}
logger.info(
f"Extracted representation vector for '{skill_metadata.name}' | {len(layer_vectors)} layers"
)
return RepresentationVector(
skill_id=skill_metadata.name.lower().replace(" ", "_"),
layer_vectors=layer_vectors,
metadata=metadata,
)
class SkillVectorLibrary:
"""Library containing extracted skill representation vectors."""
def __init__(self) -> None:
self.vectors: Dict[str, RepresentationVector] = {}
def add_vector(self, vec: RepresentationVector) -> None:
self.vectors[vec.skill_id] = vec
def get_vector(self, skill_id: str) -> RepresentationVector | None:
return self.vectors.get(skill_id)
def save(self, path: str | Path) -> None:
"""Saves the library to disk."""
state = {
skill_id: {
"skill_id": vec.skill_id,
"layer_vectors": {k: v.cpu() for k, v in vec.layer_vectors.items()},
"metadata": vec.metadata,
}
for skill_id, vec in self.vectors.items()
}
torch.save(state, path)
logger.info(
f"Skill Vector Library saved to {path} ({len(self.vectors)} skills)"
)
def load(self, path: str | Path) -> None:
"""Loads the library from disk."""
self.vectors.clear()
state = torch.load(path, map_location="cpu", weights_only=False)
for skill_id, vec_state in state.items():
self.vectors[skill_id] = RepresentationVector(
skill_id=vec_state["skill_id"],
layer_vectors=vec_state["layer_vectors"],
metadata=vec_state["metadata"],
)
logger.info(
f"Skill Vector Library loaded from {path} ({len(self.vectors)} skills)"
)
class ProcessVectorLibrary:
"""Library containing sequential process step representation vectors."""
def __init__(self) -> None:
self.processes: Dict[str, List[RepresentationVector]] = {}
def add_process(self, process_id: str, steps: List[RepresentationVector]) -> None:
self.processes[process_id] = steps
def get_process_step(
self, process_id: str, step_idx: int
) -> RepresentationVector | None:
steps = self.processes.get(process_id)
if steps and 0 <= step_idx < len(steps):
return steps[step_idx]
return None
def save(self, path: str | Path) -> None:
"""Saves the library to disk."""
state = {
p_id: [
{
"skill_id": vec.skill_id,
"layer_vectors": {k: v.cpu() for k, v in vec.layer_vectors.items()},
"metadata": vec.metadata,
}
for vec in steps
]
for p_id, steps in self.processes.items()
}
torch.save(state, path)
logger.info(
f"Process Vector Library saved to {path} ({len(self.processes)} processes)"
)
def load(self, path: str | Path) -> None:
"""Loads the library from disk."""
self.processes.clear()
state = torch.load(path, map_location="cpu", weights_only=False)
for p_id, steps_state in state.items():
steps = []
for vec_state in steps_state:
steps.append(
RepresentationVector(
skill_id=vec_state["skill_id"],
layer_vectors=vec_state["layer_vectors"],
metadata=vec_state["metadata"],
)
)
self.processes[p_id] = steps
logger.info(
f"Process Vector Library loaded from {path} ({len(self.processes)} processes)"
)