feat: Add playbook vector extraction and activation steering routing to FCES training pipeline
This commit is contained in:
251
tests/test_representation_engineering.py
Normal file
251
tests/test_representation_engineering.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""Unit tests for Representation Engineering and Vector Library Compilation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Make the "python" directory that sits next to this tests directory importable,
# so the project modules imported below resolve. The guard keeps repeated
# imports of this module from stacking duplicate sys.path entries.
_PYTHON_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "python"
)
if _PYTHON_DIR not in sys.path:
    sys.path.append(_PYTHON_DIR)
|
||||
|
||||
from representation_engineering import (
|
||||
PlaybookParser,
|
||||
PlaybookMetadata,
|
||||
RepresentationVector,
|
||||
RepresentationVectorExtractor,
|
||||
SkillVectorLibrary,
|
||||
ProcessVectorLibrary,
|
||||
)
|
||||
from adapter_moe_router import ExpertAdapterRouter
|
||||
|
||||
|
||||
# --- Mock Classes for Testing ---
|
||||
|
||||
|
||||
class SimpleTransformerLayer(nn.Module):  # type: ignore[misc]
    """Minimal residual layer used as a stand-in for a transformer block.

    It holds a single bias-free square linear projection whose weights are
    all set to 0.1, so its output is deterministic for a given input.
    """

    def __init__(self, hidden_dim: int) -> None:
        """Create the projection and overwrite its weights with 0.1.

        Args:
            hidden_dim: Size of both the input and output feature dimension.
        """
        super().__init__()
        self.linear = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # Constant weights keep results reproducible without seeding the RNG.
        with torch.no_grad():
            self.linear.weight.fill_(0.1)

    def forward(
        self, x: torch.Tensor, *args: Any, **kwargs: Any
    ) -> Tuple[torch.Tensor]:
        """Apply the residual projection.

        Args:
            x: Input tensor whose last dimension equals ``hidden_dim``.
            *args: Ignored; accepted so callers may pass extra positionals.
            **kwargs: Ignored; accepted so callers may pass extra keywords.

        Returns:
            A one-element tuple containing ``x + linear(x)``.
        """
        return (x + self.linear(x),)
|
||||
|
||||
|
||||
class SimpleModel(nn.Module):  # type: ignore[misc]
    """Toy model with two ``SimpleTransformerLayer`` blocks for hook tests.

    The forward pass broadcasts each token id across ``hidden_dim`` features,
    runs the result through every layer, and returns an object exposing the
    final activations as ``.logits``.
    """

    class Output:
        """Plain holder exposing the model result as ``logits``."""

        def __init__(self, logits: torch.Tensor) -> None:
            self.logits = logits

    def __init__(self, hidden_dim: int = 16) -> None:
        """Build the two-layer stack.

        Args:
            hidden_dim: Feature width of the emulated embeddings and layers.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        # We name layers matching .layers.\d+$ to verify registration
        self.layers = nn.ModuleList(
            [SimpleTransformerLayer(hidden_dim), SimpleTransformerLayer(hidden_dim)]
        )

    def forward(self, input_ids: torch.Tensor, *args: Any, **kwargs: Any) -> Any:
        """Run the emulated embedding and both layers.

        Args:
            input_ids: Integer tensor of shape ``[batch, seq_len]``.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            An ``Output`` whose ``logits`` has shape
            ``[batch, seq_len, hidden_dim]``.
        """
        batch_size, seq_len = input_ids.shape
        # Simple embedding emulation: every feature of a position equals its
        # token id. Allocate on the input's device so the multiplication does
        # not mix devices when the ids live off the default device.
        x = (
            torch.ones(
                batch_size,
                seq_len,
                self.hidden_dim,
                dtype=torch.float32,
                device=input_ids.device,
            )
            * input_ids.unsqueeze(-1).float()
        )
        for layer in self.layers:
            # Each layer returns a tuple; the activations are its first item.
            x = layer(x)[0]

        return self.Output(logits=x)
|
||||
|
||||
|
||||
class SimpleTokenizer:
|
||||
def __call__(
|
||||
self, text: str | List[str], return_tensors: str = "pt", **kwargs: Any
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
val = 2 if any("skill" in t.lower() for t in text) else 1
|
||||
return {"input_ids": torch.ones(len(text), 5, dtype=torch.long) * val}
|
||||
|
||||
|
||||
# --- Test Cases ---
|
||||
|
||||
|
||||
class TestRepresentationEngineering(unittest.TestCase):
    """Tests for playbook parsing, vector extraction, libraries and routing."""

    def setUp(self) -> None:
        """Create a fresh toy model, stub tokenizer and extractor per test."""
        self.model = SimpleModel(hidden_dim=16)
        self.tokenizer = SimpleTokenizer()
        self.extractor = RepresentationVectorExtractor(
            self.model, self.tokenizer, device="cpu"
        )

    def test_playbook_parser(self) -> None:
        """Parse a SKILL.md file and check name, description and bullet lists."""
        # Create a mock SKILL.md content.
        # NOTE(review): the two folded-description lines are indented as a YAML
        # ">-" block scalar requires; the original indentation was not visible
        # in the reviewed copy -- confirm against the repository.
        mock_content = """---
name: Mock Test Skill
description: >-
  This is a description
  with multiline text.
---
# Skill: Mock Test Skill

## Objective
- Analyze inputs dynamically.
- Perform steering optimization.

## Activation Trigger
- Use this when testing the parser.
- Trigger example two.
"""
        with tempfile.TemporaryDirectory() as tmpdir:
            file_path = Path(tmpdir) / "SKILL.md"
            file_path.write_text(mock_content, encoding="utf-8")

            pb = PlaybookParser.parse_file(file_path)
            # unittest assertion survives ``python -O``; the bare assert only
            # narrows the optional type for static checkers.
            self.assertIsNotNone(pb)
            assert pb is not None
            self.assertEqual(pb.name, "Mock Test Skill")
            self.assertEqual(
                pb.description, "This is a description with multiline text."
            )
            self.assertIn("Analyze inputs dynamically.", pb.objectives)
            self.assertIn("Use this when testing the parser.", pb.trigger_examples)
            self.assertIn("Trigger example two.", pb.trigger_examples)

    def test_vector_extraction(self) -> None:
        """Extract a steering vector for layer 1 and check id, shape and norm."""
        mock_pb = PlaybookMetadata(
            name="TestSteering",
            description="Steering description",
            objectives=["Objective 1"],
            trigger_examples=["Trigger 1"],
            file_path="mock_path.md",
        )
        # Extract from layer index 1
        vec = self.extractor.extract_steering_vector(mock_pb, layers_to_extract=[1])
        self.assertEqual(vec.skill_id, "teststeering")
        self.assertIn(1, vec.layer_vectors)
        self.assertEqual(vec.layer_vectors[1].shape, (16,))
        # Assert normalized to unit norm
        self.assertAlmostEqual(torch.norm(vec.layer_vectors[1]).item(), 1.0, places=5)

    def test_libraries_save_and_load(self) -> None:
        """Round-trip a skill library and a process library through disk."""
        skill_lib = SkillVectorLibrary()
        process_lib = ProcessVectorLibrary()

        # Add mock vectors
        v1 = RepresentationVector(
            skill_id="skill_a",
            layer_vectors={0: torch.ones(16), 1: torch.zeros(16)},
            metadata={"name": "Skill A"},
        )
        skill_lib.add_vector(v1)

        v2 = RepresentationVector(
            skill_id="step_1",
            layer_vectors={0: torch.ones(16) * 0.5},
            metadata={"name": "Step 1"},
        )
        process_lib.add_process("process_a", [v2])

        with tempfile.TemporaryDirectory() as tmpdir:
            skill_path = Path(tmpdir) / "skills.pt"
            proc_path = Path(tmpdir) / "processes.pt"

            skill_lib.save(skill_path)
            process_lib.save(proc_path)

            self.assertTrue(skill_path.exists())
            self.assertTrue(proc_path.exists())

            # Load into new libraries
            new_skill_lib = SkillVectorLibrary()
            new_proc_lib = ProcessVectorLibrary()

            new_skill_lib.load(skill_path)
            new_proc_lib.load(proc_path)

            loaded_v1 = new_skill_lib.get_vector("skill_a")
            self.assertIsNotNone(loaded_v1)
            assert loaded_v1 is not None  # type narrowing only
            self.assertEqual(loaded_v1.skill_id, "skill_a")
            self.assertTrue(torch.allclose(loaded_v1.layer_vectors[0], torch.ones(16)))

            loaded_v2 = new_proc_lib.get_process_step("process_a", 0)
            self.assertIsNotNone(loaded_v2)
            assert loaded_v2 is not None  # type narrowing only
            self.assertEqual(loaded_v2.skill_id, "step_1")
            self.assertTrue(
                torch.allclose(loaded_v2.layer_vectors[0], torch.ones(16) * 0.5)
            )

    def test_router_activation_steering(self) -> None:
        """Check that skill and process steering both change the model output."""
        # Create a skill library and add a steering vector
        skill_lib = SkillVectorLibrary()
        v = RepresentationVector(
            skill_id="skill_steer",
            layer_vectors={0: torch.ones(16) * 0.5, 1: torch.ones(16) * -0.5},
            metadata={"name": "Steer"},
        )
        skill_lib.add_vector(v)

        # Setup router
        router = ExpertAdapterRouter(
            base_model=self.model,
            skill_library=skill_lib,
            in_features=16,
            steering_alpha=2.0,
            steering_mode="prompt",
        )

        # Test 1: Forward without hooks (baseline)
        x = torch.ones(1, 4, dtype=torch.long)  # batch=1, seq_len=4
        with torch.no_grad():
            out_base = self.model(x).logits

        # Test 2: Register hooks and set active routing to skill_steer (index 0)
        router.register_hooks()
        try:
            self.assertEqual(len(router.hooks), 2)  # Should hook both layers

            priors = torch.tensor([1.0])
            router.set_active_routing(priors)

            with torch.no_grad():
                out_steered = self.model(x).logits

            # Verify that steering changed the output
            self.assertFalse(torch.allclose(out_base, out_steered))
        finally:
            # Remove the hooks even when an assertion above fails.
            router.unregister_hooks()

        # Test 3: Process sequential routing
        proc_lib = ProcessVectorLibrary()
        v_proc = RepresentationVector(
            skill_id="step_a",
            layer_vectors={0: torch.ones(16) * 10.0},  # large steering value
            metadata={"name": "Step A"},
        )
        proc_lib.add_process("my_workflow", [v_proc])

        router.process_library = proc_lib
        router.active_process_id = "my_workflow"
        router.active_process_step = 0
        router.register_hooks()
        try:
            with torch.no_grad():
                out_proc = self.model(x).logits

            self.assertFalse(torch.allclose(out_base, out_proc))
        finally:
            router.unregister_hooks()
|
||||
|
||||
|
||||
# Run the suite when this file is executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user