# Source: python/representation_engineering.py, recovered from git patch
# c6ba37dc398ed730823880ec7de193f37a63050b
# ("Fix mypy and ruff formatting in representation_engineering.py").
"""Representation Engineering and Vector Extraction from Playbooks.

Parses playbooks (SKILL.md), builds minimal prompt pairs from their trigger
examples, and computes steerable activation (steering) vectors that are
collected in Skill and Process Vector Libraries.
"""

from __future__ import annotations

import logging
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple

import torch
import torch.nn as nn

logger = logging.getLogger("representation_engineering")
# NOTE: configuring logging at import time is a module-level side effect; it is
# kept because existing callers may rely on INFO output being enabled.
logging.basicConfig(level=logging.INFO)

# Leading YAML frontmatter delimited by "---" lines.
_FRONTMATTER_RE = re.compile(r"^---\s*\n(.*?)\n---\s*\n", re.DOTALL)
# "[ \t]*" instead of "\s*": the value must stay on the key's own line, so a
# blank "name:" can no longer swallow the following line.
_FM_NAME_RE = re.compile(r"^name:[ \t]*(.*?)[ \t]*$", re.MULTILINE)
# First line of the value plus any indented, non-blank continuation lines
# (plain multi-line scalars and ">" / "|" block scalars).
_FM_DESCRIPTION_RE = re.compile(
    r"^description:[ \t]*(.*(?:\n[ \t]+\S.*)*)", re.MULTILINE
)
_BLOCK_INDICATOR_RE = re.compile(r"^[>|][-+0-9]*[ \t]*")
_TITLE_RE = re.compile(r"^#[ \t]+(.*?)$", re.MULTILINE)
_TITLE_PREFIX_RE = re.compile(r"^(Skill|SOP|Playbook):\s*", re.IGNORECASE)
# A Markdown list item: marker, mandatory whitespace, then the item text.
# Requiring the whitespace keeps "---" rules and "**bold**" text out.
_BULLET_RE = re.compile(r"^(?:[-*+]|\d+\.)\s+(.+)$")
_OBJECTIVE_KEYWORDS = ("objective", "ziel")
_TRIGGER_KEYWORDS = ("trigger", "when to use", "examples", "beispiele")
_SKILL_SUFFIX = "_SKILL"
# Module names such as "gpt_neox.layers.3" or "model.layer.0".
_LAYER_NAME_RE = re.compile(r".*layers?\.\d+$")

_ForwardHook = Callable[[nn.Module, Tuple[Any, ...], Any], None]


@dataclass
class PlaybookMetadata:
    """Metadata and examples parsed from a SKILL.md playbook."""

    # Skill name (frontmatter, file/directory name, or first "# " title).
    name: str
    # One-line description; multi-line frontmatter values are folded.
    description: str
    # Bullet items found under headings mentioning "objective" / "ziel".
    objectives: List[str] = field(default_factory=list)
    # Bullet items found under trigger/example headings, or generic fallbacks.
    trigger_examples: List[str] = field(default_factory=list)
    # Absolute path of the parsed file.
    file_path: str = ""


class PlaybookParser:
    """Parses SKILL.md files to extract structured metadata and trigger examples."""

    @staticmethod
    def _default_name(path: Path) -> str:
        """Return the name used when the frontmatter does not provide one.

        ``<prefix>_SKILL.md`` files are named after their prefix so several
        of them in one directory do not collide; otherwise the parent
        directory name is used, or ``"unknown"`` when that is empty (e.g. a
        bare relative ``SKILL.md``).
        """
        stem = path.stem
        if stem.endswith(_SKILL_SUFFIX) and len(stem) > len(_SKILL_SUFFIX):
            return stem[: -len(_SKILL_SUFFIX)]
        return path.parent.name or "unknown"

    @staticmethod
    def _fold_scalar(raw: str) -> str:
        """Collapse a (possibly multi-line) YAML scalar into a single line."""
        without_indicator = _BLOCK_INDICATOR_RE.sub("", raw, count=1)
        parts = (part.strip() for part in without_indicator.splitlines())
        return " ".join(part for part in parts if part)

    @staticmethod
    def _scan_sections(body: str) -> Tuple[List[str], List[str]]:
        """Collect objective and trigger bullets from the Markdown body.

        A heading line switches the active section(s) based on keywords in
        its text; list items below it are attributed to whichever sections
        are active (a heading may activate both).
        """
        objectives: List[str] = []
        triggers: List[str] = []
        in_objectives = False
        in_triggers = False

        for raw_line in body.splitlines():
            line = raw_line.strip()
            if not line:
                continue

            if line.startswith("#"):
                heading = line.lower()
                in_objectives = any(key in heading for key in _OBJECTIVE_KEYWORDS)
                in_triggers = any(key in heading for key in _TRIGGER_KEYWORDS)
                continue

            bullet = _BULLET_RE.match(line)
            if bullet is None:
                continue
            item = bullet.group(1).strip()

            if in_objectives:
                objectives.append(item)
            # Filter out very short items and generic "do not ..." instructions.
            if in_triggers and len(item) > 5:
                if not item.lower().startswith("do not"):
                    triggers.append(item)

        return objectives, triggers

    @staticmethod
    def parse_file(path: str | Path) -> PlaybookMetadata | None:
        """Parse a single SKILL.md file.

        Args:
            path: Location of the playbook file.

        Returns:
            The parsed metadata, or ``None`` (with a warning logged) when
            ``path`` is not an existing regular file.
        """
        path = Path(path)
        # is_file() rather than exists(): a directory would make read_text raise.
        if not path.is_file():
            logger.warning("File not found: %s", path)
            return None

        # "utf-8-sig" also accepts plain UTF-8 and drops a leading BOM that
        # would otherwise prevent the frontmatter from matching.
        content = path.read_text(encoding="utf-8-sig")

        name = PlaybookParser._default_name(path)
        description = ""
        body = content

        frontmatter_match = _FRONTMATTER_RE.match(content)
        if frontmatter_match:
            # Headings and bullets are only looked for after the frontmatter,
            # so YAML comments and lists are not mistaken for Markdown.
            body = content[frontmatter_match.end() :]
            fm_text = frontmatter_match.group(1)
            name_match = _FM_NAME_RE.search(fm_text)
            if name_match and name_match.group(1):
                name = name_match.group(1)
            desc_match = _FM_DESCRIPTION_RE.search(fm_text)
            if desc_match:
                description = PlaybookParser._fold_scalar(desc_match.group(1))

        # Fall back to the first "# " title when no usable name was found.
        if name == "unknown":
            title_match = _TITLE_RE.search(body)
            if title_match:
                # Remove prefixes like "Skill:" or "SOP:".
                title = _TITLE_PREFIX_RE.sub("", title_match.group(1).strip())
                name = title or "unknown"

        objectives, triggers = PlaybookParser._scan_sections(body)

        # Fallback if no triggers found
        if not triggers:
            triggers = [
                f"Apply the {name} skill to handle this task.",
                f"How do I use {name} here?",
            ]

        return PlaybookMetadata(
            name=name,
            description=description or f"Playbook for {name}",
            objectives=objectives,
            trigger_examples=triggers,
            file_path=str(path.absolute()),
        )

    @classmethod
    def parse_directory(cls, dir_path: str | Path) -> List[PlaybookMetadata]:
        """Parse every playbook found below ``dir_path``.

        ``SKILL.md`` files are searched recursively; ``*_SKILL.md`` files
        only directly inside ``dir_path``. Each group is visited in sorted
        path order so the result is deterministic.

        Returns:
            Parsed playbooks; empty when ``dir_path`` is not a directory.
        """
        root = Path(dir_path)
        if not root.is_dir():
            return []

        candidates = sorted(root.glob("**/SKILL.md"))
        candidates += sorted(root.glob("*_SKILL.md"))

        playbooks: List[PlaybookMetadata] = []
        for candidate in candidates:
            parsed = cls.parse_file(candidate)
            if parsed is not None:
                playbooks.append(parsed)
        return playbooks


@dataclass
class RepresentationVector:
    """A steering direction vector or set of vectors for representation engineering."""

    skill_id: str
    # Map from layer index to the mean win-minus-lose activation difference.
    layer_vectors: Dict[int, torch.Tensor] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)


class RepresentationVectorExtractor:
    """Extracts representation (steering) vectors from model activations.

    Vectors are computed from minimal prompt pairs: the same request with
    ("win") and without ("lose") the skill instructions.
    """

    def __init__(self, model: nn.Module, tokenizer: Any, device: str = "cpu") -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def _find_transformer_layers(self) -> List[Tuple[str, nn.Module]]:
        """Return ``(name, module)`` for submodules named like ``...layers.<n>``."""
        return [
            (name, module)
            for name, module in self.model.named_modules()
            if _LAYER_NAME_RE.match(name)
        ]

    @staticmethod
    def _build_prompt_pairs(
        skill_metadata: PlaybookMetadata,
    ) -> Tuple[List[str], List[str]]:
        """Build one win/lose prompt pair per trigger example."""
        win_prompts: List[str] = []
        lose_prompts: List[str] = []
        for trigger in skill_metadata.trigger_examples:
            # Win prompt guides the model to invoke the playbook skill/format.
            win_prompts.append(
                f"Instructions: You are acting with the following skill: "
                f"{skill_metadata.name}. "
                f"Description: {skill_metadata.description}\n"
                f"Request: {trigger}\n"
                f"Output:"
            )
            # Lose prompt asks the model to respond normally.
            lose_prompts.append(
                f"Instructions: Respond normally.\nRequest: {trigger}\nOutput:"
            )
        return win_prompts, lose_prompts

    @staticmethod
    def _make_hook(sink: List[torch.Tensor]) -> _ForwardHook:
        """Create a forward hook that appends the layer output to ``sink``."""

        def hook_fn(module: nn.Module, inputs: Tuple[Any, ...], output: Any) -> None:
            # Some layers return a tuple; its first element is used.
            hidden = output[0] if isinstance(output, tuple) else output
            # Detach and move to CPU to save accelerator memory.
            sink.append(hidden.detach().cpu())

        return hook_fn

    def _collect_activations(
        self, prompts: List[str], layer_modules: Dict[int, nn.Module]
    ) -> Dict[int, List[torch.Tensor]]:
        """Run ``prompts`` through the model and record the hooked outputs.

        Hooks are always removed again, even when tokenisation or the forward
        pass raises, so a failure cannot leave stale hooks on the model.
        """
        storage: Dict[int, List[torch.Tensor]] = {idx: [] for idx in layer_modules}
        handles: List[Any] = []
        try:
            for idx, module in layer_modules.items():
                hook = self._make_hook(storage[idx])
                handles.append(module.register_forward_hook(hook))
            with torch.no_grad():
                for prompt in prompts:
                    inputs = self.tokenizer(prompt, return_tensors="pt")
                    self.model(**inputs.to(self.device))
        finally:
            for handle in handles:
                handle.remove()
        return storage

    def extract_steering_vector(
        self,
        skill_metadata: PlaybookMetadata,
        layers_to_extract: List[int] | None = None,
    ) -> RepresentationVector:
        """Compute per-layer hidden-state differences for win vs lose prompts.

        Args:
            skill_metadata: Playbook whose trigger examples seed the pairs.
            layers_to_extract: Indices into the discovered layer list
                (negative indices count from the end). Indices outside the
                list are ignored. Defaults to the second half of the layers.

        Returns:
            A vector whose ``layer_vectors`` maps each requested, resolvable
            layer index to the sequence-averaged win-minus-lose difference,
            averaged over all pairs and scaled to unit norm (left unscaled
            when its norm is not above 1e-8). Empty when no layers or no
            trigger examples are available.
        """
        transformer_layers = self._find_transformer_layers()
        num_layers = len(transformer_layers)
        if not transformer_layers:
            logger.warning(
                "Could not automatically resolve transformer layers; "
                "no steering vectors will be extracted"
            )

        if layers_to_extract is None:
            # Default: extract from middle/late layers (last half of the network).
            layers_to_extract = list(range(num_layers // 2, num_layers))

        # Both bounds are checked so an out-of-range negative index is skipped
        # like a too-large one instead of raising IndexError.
        layer_modules: Dict[int, nn.Module] = {
            idx: transformer_layers[idx][1]
            for idx in layers_to_extract
            if -num_layers <= idx < num_layers
        }

        win_prompts, lose_prompts = self._build_prompt_pairs(skill_metadata)

        # Run in eval mode, then restore each module's previous training flag
        # so extraction does not silently change the caller's model.
        previous_modes = [(module, module.training) for module in self.model.modules()]
        self.model.eval()
        try:
            win_activations = self._collect_activations(win_prompts, layer_modules)
            lose_activations = self._collect_activations(lose_prompts, layer_modules)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training

        layer_vectors: Dict[int, torch.Tensor] = {}
        for idx in layer_modules:
            pairs = list(zip(win_activations[idx], lose_activations[idx]))
            if not pairs:
                continue

            # Average over dim 1 (the sequence axis for [1, seq, hidden]
            # outputs) so prompts of different lengths stay comparable.
            diffs = [
                win_t.mean(dim=1).squeeze(0) - lose_t.mean(dim=1).squeeze(0)
                for win_t, lose_t in pairs
            ]

            # Average difference vector across all minimal pairs.
            mean_diff = torch.stack(diffs).mean(dim=0)

            # Normalize to unit norm for consistent steering scales.
            norm = torch.norm(mean_diff)
            if norm > 1e-8:
                mean_diff = mean_diff / norm

            layer_vectors[idx] = mean_diff

        metadata = {
            "name": skill_metadata.name,
            "description": skill_metadata.description,
            "trigger_count": len(skill_metadata.trigger_examples),
        }

        logger.info(
            "Extracted representation vector for '%s' | %d layers",
            skill_metadata.name,
            len(layer_vectors),
        )
        return RepresentationVector(
            skill_id=skill_metadata.name.lower().replace(" ", "_"),
            layer_vectors=layer_vectors,
            metadata=metadata,
        )


def _vector_to_state(vec: RepresentationVector) -> Dict[str, Any]:
    """Convert a vector into the plain dict layout written by ``save``."""
    return {
        "skill_id": vec.skill_id,
        "layer_vectors": {k: v.cpu() for k, v in vec.layer_vectors.items()},
        "metadata": vec.metadata,
    }


def _vector_from_state(state: Dict[str, Any]) -> RepresentationVector:
    """Rebuild a vector from the dict layout produced by ``_vector_to_state``."""
    return RepresentationVector(
        skill_id=state["skill_id"],
        layer_vectors=state["layer_vectors"],
        metadata=state["metadata"],
    )


class SkillVectorLibrary:
    """Library containing extracted skill representation vectors."""

    def __init__(self) -> None:
        self.vectors: Dict[str, RepresentationVector] = {}

    def add_vector(self, vec: RepresentationVector) -> None:
        """Store ``vec`` under its ``skill_id``, replacing any previous entry."""
        self.vectors[vec.skill_id] = vec

    def get_vector(self, skill_id: str) -> RepresentationVector | None:
        """Return the vector for ``skill_id`` or ``None`` when unknown."""
        return self.vectors.get(skill_id)

    def save(self, path: str | Path) -> None:
        """Saves the library to disk."""
        state = {
            skill_id: _vector_to_state(vec) for skill_id, vec in self.vectors.items()
        }
        torch.save(state, path)
        logger.info(
            "Skill Vector Library saved to %s (%d skills)", path, len(self.vectors)
        )

    def load(self, path: str | Path) -> None:
        """Replace the library contents with those stored at ``path``.

        The current contents are only discarded after the file has been read
        and decoded successfully; a failing load leaves the library untouched.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects (it is
        # kept because metadata may hold arbitrary values). Only load files
        # from trusted sources.
        state = torch.load(path, map_location="cpu", weights_only=False)
        loaded = {
            skill_id: _vector_from_state(vec_state)
            for skill_id, vec_state in state.items()
        }
        self.vectors.clear()
        self.vectors.update(loaded)
        logger.info(
            "Skill Vector Library loaded from %s (%d skills)", path, len(self.vectors)
        )


class ProcessVectorLibrary:
    """Library containing sequential process step representation vectors."""

    def __init__(self) -> None:
        self.processes: Dict[str, List[RepresentationVector]] = {}

    def add_process(self, process_id: str, steps: List[RepresentationVector]) -> None:
        """Store the ordered ``steps`` under ``process_id``."""
        self.processes[process_id] = steps

    def get_process_step(
        self, process_id: str, step_idx: int
    ) -> RepresentationVector | None:
        """Return step ``step_idx`` of a process, or ``None`` if unavailable.

        ``None`` is returned for unknown processes and for indices outside
        ``0 <= step_idx < len(steps)`` (negative indices are not supported).
        """
        steps = self.processes.get(process_id)
        if steps and 0 <= step_idx < len(steps):
            return steps[step_idx]
        return None

    def save(self, path: str | Path) -> None:
        """Saves the library to disk."""
        state = {
            p_id: [_vector_to_state(vec) for vec in steps]
            for p_id, steps in self.processes.items()
        }
        torch.save(state, path)
        logger.info(
            "Process Vector Library saved to %s (%d processes)",
            path,
            len(self.processes),
        )

    def load(self, path: str | Path) -> None:
        """Replace the library contents with those stored at ``path``.

        The current contents are only discarded after the file has been read
        and decoded successfully; a failing load leaves the library untouched.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects (it is
        # kept because metadata may hold arbitrary values). Only load files
        # from trusted sources.
        state = torch.load(path, map_location="cpu", weights_only=False)
        loaded = {
            p_id: [_vector_from_state(vec_state) for vec_state in steps_state]
            for p_id, steps_state in state.items()
        }
        self.processes.clear()
        self.processes.update(loaded)
        logger.info(
            "Process Vector Library loaded from %s (%d processes)",
            path,
            len(self.processes),
        )