"""Representation Engineering and Vector Extraction from Playbooks. Parses playbooks (SKILL.md), extracts minimal pairs, and computes steerable activation vectors or QLoRA adapters to build Skill and Process Vector Libraries. """ from __future__ import annotations import logging import re from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Tuple import torch import torch.nn as nn logger = logging.getLogger("representation_engineering") logging.basicConfig(level=logging.INFO) @dataclass class PlaybookMetadata: """Metadata and examples parsed from a SKILL.md playbook.""" name: str description: str objectives: List[str] = field(default_factory=list) trigger_examples: List[str] = field(default_factory=list) file_path: str = "" class PlaybookParser: """Parses SKILL.md files to extract structured metadata and trigger examples.""" @staticmethod def parse_file(path: str | Path) -> PlaybookMetadata | None: """Parses a single SKILL.md file.""" path = Path(path) if not path.exists(): logger.warning(f"File not found: {path}") return None content = path.read_text(encoding="utf-8") # Parse frontmatter if present name = path.parent.name if path.parent else "unknown" description = "" frontmatter_match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL) if frontmatter_match: fm_text = frontmatter_match.group(1) name_match = re.search(r"^name:\s*(.*?)$", fm_text, re.MULTILINE) if name_match: name = name_match.group(1).strip() desc_match = re.search(r"^description:\s*(.*?)$", fm_text, re.MULTILINE) if desc_match: description = desc_match.group(1).strip() # Remove multiline YAML indicators if any description = re.sub(r"^>-?\s*", "", description) description = re.sub(r"\n\s+", " ", description) # Parse headers if no name/description in frontmatter if name == "unknown" or not description: title_match = re.search(r"^#\s+(.*?)$", content, re.MULTILINE) if title_match and name == "unknown": name = title_match.group(1).strip() # Remove prefixes like "Skill:" or "SOP:" name = re.sub( r"^(Skill|SOP|Playbook):\s*", "", name, flags=re.IGNORECASE ) # Objectives and triggers objectives: List[str] = [] triggers: List[str] = [] # Simple regex searches for bullet points lines = content.splitlines() in_objectives = False in_triggers = False for line in lines: line_str = line.strip() if not line_str: continue # Section detection if line_str.startswith("#"): in_objectives = ( "objective" in line_str.lower() or "ziel" in line_str.lower() ) in_triggers = any( x in line_str.lower() for x in ["trigger", "when to use", "examples", "beispiele"] ) continue if in_objectives and ( line_str.startswith("-") or line_str.startswith("*") or re.match(r"^\d+\.", line_str) ): clean_line = re.sub(r"^[-*\d\.]+\s*", "", line_str) objectives.append(clean_line) if in_triggers and ( line_str.startswith("-") or line_str.startswith("*") or re.match(r"^\d+\.", line_str) ): clean_line = re.sub(r"^[-*\d\.]+\s*", "", line_str) # Filter out generic instructions if len(clean_line) > 5 and not clean_line.lower().startswith("do not"): triggers.append(clean_line) # Fallback if no triggers found if not triggers: triggers = [ f"Apply the {name} skill to handle this task.", f"How do I use {name} here?", ] return PlaybookMetadata( name=name, description=description or f"Playbook for {name}", objectives=objectives, trigger_examples=triggers, file_path=str(path.absolute()), ) @classmethod def parse_directory(cls, dir_path: str | Path) -> List[PlaybookMetadata]: """Scans a directory for SKILL.md or matches files and parses them.""" dir_path = Path(dir_path) playbooks: List[PlaybookMetadata] = [] if not dir_path.exists(): return playbooks # Match any SKILL.md or *_SKILL.md for path in dir_path.glob("**/SKILL.md"): pb = cls.parse_file(path) if pb: playbooks.append(pb) for path in dir_path.glob("*_SKILL.md"): pb = cls.parse_file(path) if pb: playbooks.append(pb) return playbooks @dataclass class RepresentationVector: """A steering direction vector or set of vectors for representation engineering.""" skill_id: str # Map from layer index (e.g. 0 to L-1) to difference activation tensor [hidden_dim] layer_vectors: Dict[int, torch.Tensor] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=dict) class RepresentationVectorExtractor: """Extracts representation (steering) vectors from model activations using minimal pairs.""" def __init__(self, model: nn.Module, tokenizer: Any, device: str = "cpu") -> None: self.model = model self.tokenizer = tokenizer self.device = device def extract_steering_vector( self, skill_metadata: PlaybookMetadata, layers_to_extract: List[int] | None = None, ) -> RepresentationVector: """Computes the difference in hidden states for win vs lose prompts.""" self.model.eval() # Determine layers to hook # Pythia models store layers in model.gpt_neox.layers # Let's inspect model layout dynamically transformer_layers = [] for name, module in self.model.named_modules(): if re.match(r".*layers?\.\d+$", name): transformer_layers.append((name, module)) if not transformer_layers: # Fallback/guess for standard PyTorch modules logger.warning( "Could not automatically resolve transformer layers, will attempt default Hook paths" ) num_layers = len(transformer_layers) if layers_to_extract is None: # Default: extract from middle/late layers (e.g., last half of the network) layers_to_extract = list(range(num_layers // 2, num_layers)) # Generate minimal pairs from triggers win_prompts = [] lose_prompts = [] for trigger in skill_metadata.trigger_examples: # Win prompt guides the model to invoke the playbook skill/format win_prompts.append( f"Instructions: You are acting with the following skill: {skill_metadata.name}. " f"Description: {skill_metadata.description}\n" f"Request: {trigger}\n" f"Output:" ) # Lose prompt asks the model to respond normally lose_prompts.append( f"Instructions: Respond normally.\n" f"Request: {trigger}\n" f"Output:" ) # Temporary storage for hooked activations # Map: layer_idx -> list of tensors [seq_len, hidden_dim] win_activations: Dict[int, List[torch.Tensor]] = { idx: [] for idx in layers_to_extract } lose_activations: Dict[int, List[torch.Tensor]] = { idx: [] for idx in layers_to_extract } # Hook function builder def make_hook(layer_idx: int, storage: Dict[int, List[torch.Tensor]]) -> Any: def hook_fn( module: nn.Module, input_t: Tuple[torch.Tensor, ...], output_t: torch.Tensor, ) -> None: # output_t is typically [batch, seq_len, hidden_dim] or a tuple if isinstance(output_t, tuple): output_t = output_t[0] # Detach and move to CPU to save GPU memory storage[layer_idx].append(output_t.detach().cpu()) return hook_fn # Register hooks hooks = [] for idx in layers_to_extract: if idx < len(transformer_layers): _, layer_module = transformer_layers[idx] h = layer_module.register_forward_hook(make_hook(idx, win_activations)) hooks.append(h) # Run forward pass for win prompts for prompt in win_prompts: inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) with torch.no_grad(): self.model(**inputs) # Remove hooks and register them for lose activations for h in hooks: h.remove() hooks.clear() for idx in layers_to_extract: if idx < len(transformer_layers): _, layer_module = transformer_layers[idx] h = layer_module.register_forward_hook(make_hook(idx, lose_activations)) hooks.append(h) # Run forward pass for lose prompts for prompt in lose_prompts: inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) with torch.no_grad(): self.model(**inputs) # Remove hooks for h in hooks: h.remove() # Compute difference vectors layer_vectors: Dict[int, torch.Tensor] = {} for idx in layers_to_extract: win_tensors = win_activations[idx] lose_tensors = lose_activations[idx] if not win_tensors or not lose_tensors: continue diffs = [] for win_t, lose_t in zip(win_tensors, lose_tensors): # win_t and lose_t are [1, seq_len, hidden_dim] # We can average over sequence length or take the last token (representation at decision point) # Let's average over the sequence length for stability w_mean = win_t.mean(dim=1).squeeze(0) # [hidden_dim] l_mean = lose_t.mean(dim=1).squeeze(0) # [hidden_dim] diffs.append(w_mean - l_mean) # Average difference vector across all minimal pairs mean_diff = torch.stack(diffs).mean(dim=0) # Normalize vector to unit norm for consistent steering scales norm = torch.norm(mean_diff) if norm > 1e-8: mean_diff = mean_diff / norm layer_vectors[idx] = mean_diff metadata = { "name": skill_metadata.name, "description": skill_metadata.description, "trigger_count": len(skill_metadata.trigger_examples), } logger.info( f"Extracted representation vector for '{skill_metadata.name}' | {len(layer_vectors)} layers" ) return RepresentationVector( skill_id=skill_metadata.name.lower().replace(" ", "_"), layer_vectors=layer_vectors, metadata=metadata, ) class SkillVectorLibrary: """Library containing extracted skill representation vectors.""" def __init__(self) -> None: self.vectors: Dict[str, RepresentationVector] = {} def add_vector(self, vec: RepresentationVector) -> None: self.vectors[vec.skill_id] = vec def get_vector(self, skill_id: str) -> RepresentationVector | None: return self.vectors.get(skill_id) def save(self, path: str | Path) -> None: """Saves the library to disk.""" state = { skill_id: { "skill_id": vec.skill_id, "layer_vectors": {k: v.cpu() for k, v in vec.layer_vectors.items()}, "metadata": vec.metadata, } for skill_id, vec in self.vectors.items() } torch.save(state, path) logger.info( f"Skill Vector Library saved to {path} ({len(self.vectors)} skills)" ) def load(self, path: str | Path) -> None: """Loads the library from disk.""" self.vectors.clear() state = torch.load(path, map_location="cpu", weights_only=False) for skill_id, vec_state in state.items(): self.vectors[skill_id] = RepresentationVector( skill_id=vec_state["skill_id"], layer_vectors=vec_state["layer_vectors"], metadata=vec_state["metadata"], ) logger.info( f"Skill Vector Library loaded from {path} ({len(self.vectors)} skills)" ) class ProcessVectorLibrary: """Library containing sequential process step representation vectors.""" def __init__(self) -> None: self.processes: Dict[str, List[RepresentationVector]] = {} def add_process(self, process_id: str, steps: List[RepresentationVector]) -> None: self.processes[process_id] = steps def get_process_step( self, process_id: str, step_idx: int ) -> RepresentationVector | None: steps = self.processes.get(process_id) if steps and 0 <= step_idx < len(steps): return steps[step_idx] return None def save(self, path: str | Path) -> None: """Saves the library to disk.""" state = { p_id: [ { "skill_id": vec.skill_id, "layer_vectors": {k: v.cpu() for k, v in vec.layer_vectors.items()}, "metadata": vec.metadata, } for vec in steps ] for p_id, steps in self.processes.items() } torch.save(state, path) logger.info( f"Process Vector Library saved to {path} ({len(self.processes)} processes)" ) def load(self, path: str | Path) -> None: """Loads the library from disk.""" self.processes.clear() state = torch.load(path, map_location="cpu", weights_only=False) for p_id, steps_state in state.items(): steps = [] for vec_state in steps_state: steps.append( RepresentationVector( skill_id=vec_state["skill_id"], layer_vectors=vec_state["layer_vectors"], metadata=vec_state["metadata"], ) ) self.processes[p_id] = steps logger.info( f"Process Vector Library loaded from {path} ({len(self.processes)} processes)" )