feat: Add playbook vector extraction and activation steering routing to FCES training pipeline

This commit is contained in:
AI-anonymous
2026-05-23 08:33:27 +02:00
parent e0d8a32823
commit 4c9a550f8b
4 changed files with 884 additions and 46 deletions

View File

@@ -8,11 +8,12 @@ learnable token-level gates.
from __future__ import annotations
import re
from typing import Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn as nn
from parasitic_qlora import ExpertAdapter
from representation_engineering import SkillVectorLibrary, ProcessVectorLibrary
class LearnableGate(nn.Module): # type: ignore[misc, unused-ignore]
@@ -41,25 +42,53 @@ class LearnableGate(nn.Module): # type: ignore[misc, unused-ignore]
class ExpertAdapterRouter:
"""Manages dynamic MoE-style routing over a library of LoRA adapters."""
"""Manages dynamic MoE-style routing over a library of LoRA adapters and representation vectors."""
# Construct the router: keep references to the base model and the optional
# adapter / skill / process libraries, and create one LearnableGate per
# non-empty library (adapters -> self.gate, skills -> self.skill_gate).
# NOTE(review): this span is a rendered diff without +/- markers, so
# superseded and replacement lines appear back to back (e.g. the two
# `adapter_library` parameter lines) -- confirm against the real file.
def __init__(
self,
base_model: nn.Module,
adapter_library: List[ExpertAdapter],
adapter_library: Optional[List[ExpertAdapter]] = None,
in_features: int = 768, # Match model hidden dim (e.g. Pythia-70m)
skill_library: Optional[SkillVectorLibrary] = None,
process_library: Optional[ProcessVectorLibrary] = None,
# Scalar multiplied into the steering vector before it is added to a
# layer's output by the steering hook.
steering_alpha: float = 1.0,
# "token" feeds the layer output to the skill gate as-is; any other value
# takes the pooled path (mean over dim 1 for 3-D outputs) in the steering hook.
steering_mode: str = "token", # "token" or "prompt"
) -> None:
self.base_model = base_model
self.adapter_library = adapter_library
self.num_adapters = len(adapter_library)
# A missing (None) adapter library is normalised to an empty list, which
# makes num_adapters 0 and disables the adapter gate below.
self.adapter_library = adapter_library or []
self.num_adapters = len(self.adapter_library)
self.skill_library = skill_library
self.process_library = process_library
self.steering_alpha = steering_alpha
self.steering_mode = steering_mode
# Handles returned by register_forward_hook; filled in by register_hooks().
self.hooks: List[torch.utils.hooks.RemovableHandle] = []
# When active_process_id is set (and a process library exists) the steering
# hook applies that process step's vector instead of skill routing;
# a step of None is treated as step 0 by the hook.
self.active_process_id: Optional[str] = None
self.active_process_step: Optional[int] = None
# Learnable gating network
self.gate = LearnableGate(in_features, self.num_adapters).to(
next(base_model.parameters()).device
# Sorted list of skill IDs for index-based routing
# (position i in this list is column i of the skill gate's output)
self.skill_ids = (
sorted(list(self.skill_library.vectors.keys()))
if self.skill_library
else []
)
# Active weights for current forward pass (batch size × num_adapters)
# Learnable gating network for adapters
# The gate is placed on the device of the base model's first parameter.
if self.num_adapters > 0:
self.gate = LearnableGate(in_features, self.num_adapters).to(
next(base_model.parameters()).device
)
else:
self.gate = None
# Learnable gating network for skills
# One output per entry of self.skill_ids; None when there are no skills.
if len(self.skill_ids) > 0:
self.skill_gate = LearnableGate(in_features, len(self.skill_ids)).to(
next(base_model.parameters()).device
)
else:
self.skill_gate = None
# Active weights for current forward pass (batch size × num_adapters/skills)
# None means "no fixed weights": hooks fall back to the learnable gates.
self.current_gate_weights: Optional[torch.Tensor] = None
def compute_fuzzy_priors(self, text: str) -> torch.Tensor:
@@ -97,30 +126,46 @@ class ExpertAdapterRouter:
self.current_gate_weights = fuzzy_priors
# NOTE(review): this span is a rendered diff without +/- markers, so the
# superseded and replacement versions of several statements (docstring,
# adapter-matching loop, hook factory call) appear interleaved -- confirm
# against the real file before relying on exact statement order.
def register_hooks(self) -> None:
"""Attaches forward hooks to linear layers present in the adapter library."""
"""Attaches forward hooks to linear layers (adapters) and transformer blocks (steering)."""
# Start from a clean slate; presumably this removes every handle stored in
# self.hooks (unregister_hooks is not visible here -- TODO confirm).
self.unregister_hooks()
# Find all layers in the base model that have adapters
adapter_layers: set[str] = set()
for adapter in self.adapter_library:
adapter_layers.update(adapter.layers.keys())
# 1. Bind adapter hooks if adapters are present
if self.num_adapters > 0:
# Union of the layer keys of every adapter in the library.
adapter_layers: set[str] = set()
for adapter in self.adapter_library:
adapter_layers.update(adapter.layers.keys())
# Bind hooks dynamically
for name, module in self.base_model.named_modules():
# Check if this specific module has an adapter
# We match using suffix to support model wrapping/prefixes
matching_adapter_name = None
for layer_name in adapter_layers:
clean_layer_name = layer_name.replace(".weight", "").replace(
".bias", ""
)
if name.endswith(clean_layer_name) or name == clean_layer_name:
matching_adapter_name = layer_name
break
for name, module in self.base_model.named_modules():
matching_adapter_name = None
for layer_name in adapter_layers:
# Adapter keys may carry a parameter suffix; strip it so the key can be
# compared with module names from named_modules().
clean_layer_name = layer_name.replace(".weight", "").replace(
".bias", ""
)
# NOTE(review): a pure suffix test also matches longer names that merely
# end with the key (e.g. "1.dense" vs "11.dense"); the first key that
# matches wins because of the break.
if name.endswith(clean_layer_name) or name == clean_layer_name:
matching_adapter_name = layer_name
break
if matching_adapter_name and isinstance(module, nn.Linear):
# Only nn.Linear modules receive an adapter hook.
if matching_adapter_name and isinstance(module, nn.Linear):
hook = module.register_forward_hook(
self._make_hook_fn(matching_adapter_name)
)
self.hooks.append(hook)
# 2. Bind steering hooks if skill_library or process_library is present
if self.skill_library or self.process_library:
transformer_layers = []
for name, module in self.base_model.named_modules():
# Select modules whose qualified name ends in "layer.<n>" or "layers.<n>";
# the captured integer becomes the layer index used to look up vectors.
# NOTE(review): if the model has several such module lists, every module
# with the same trailing index gets a hook with the same layer_idx.
match = re.match(r".*layers?\.(\d+)$", name)
if match:
layer_idx = int(match.group(1))
transformer_layers.append((layer_idx, name, module))
# Sort by layer_idx to ensure consistent mapping
transformer_layers.sort(key=lambda x: x[0])
for layer_idx, name, module in transformer_layers:
hook = module.register_forward_hook(
self._make_hook_fn(matching_adapter_name)
self._make_steering_hook_fn(layer_idx)
)
# Keep the handle so the hook can be removed later.
self.hooks.append(hook)
@@ -152,38 +197,100 @@ class ExpertAdapterRouter:
)
else:
# Compute dynamically per token via learnable gate
# We pool over sequence length or route per token
# Let's route token-wise: gate_logits has shape [batch, seq_len, num_adapters]
gate_logits = self.gate(x)
weights = torch.softmax(gate_logits, dim=-1)
if self.gate is not None:
gate_logits = self.gate(x)
weights = torch.softmax(gate_logits, dim=-1)
else:
return output_tensor
# Compute combined low-rank contribution
# Y_lora = sum_i g_i * (x @ A_i.t()) @ B_i.t()
adapter_output = torch.zeros_like(output_tensor)
for i, adapter in enumerate(self.adapter_library):
if layer_name in adapter.layers:
lm = adapter.layers[layer_name]
# Ensure tensors are on the correct device
lora_A = lm.lora_A.to(device=x.device, dtype=x.dtype)
lora_B = lm.lora_B.to(device=x.device, dtype=x.dtype)
# Dynamic scaling: gate_weight for this adapter
# weights has shape [batch, num_adapters] or [batch, seq_len, num_adapters]
if len(weights.shape) == 3:
# Token-level routing: shape [batch, seq_len, 1]
g = weights[..., i : i + 1]
else:
# Batch-level routing: shape [batch, 1, 1]
g = weights[:, i].view(-1, 1, 1)
# Low-rank projection
x_proj = torch.matmul(x, lora_A.t())
y_proj = torch.matmul(x_proj, lora_B.t())
# Accumulate scaled delta
adapter_output += g * y_proj
return output_tensor + adapter_output
return hook_fn
def _make_steering_hook_fn(self, layer_idx: int) -> Callable[..., Any]:
    """Create a forward hook that adds steering vectors to one layer's output.

    The returned hook works on the module output (or on element 0 when the
    module returns a tuple; the remaining elements are passed through
    unchanged). Two mutually exclusive modes are supported:

    * Process steering: when ``self.active_process_id`` is set and a process
      library exists, the vector of the active step (``active_process_step``,
      ``None`` meaning step 0) for ``layer_idx`` is scaled by
      ``self.steering_alpha`` and added. If that step has no vector for this
      layer the output is returned untouched and skill routing is skipped.
    * Skill routing: otherwise every skill vector available for ``layer_idx``
      is mixed using per-skill weights and the mix, scaled by
      ``self.steering_alpha``, is added. Weights come from
      ``self.current_gate_weights`` when set, else from ``self.skill_gate``
      (softmax over skills), fed either the raw output
      (``steering_mode == "token"``) or the output averaged over dim 1.

    Args:
        layer_idx: Index used to look up ``layer_vectors`` entries.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook
        signature that returns the (possibly steered) output.

    Note:
        Activations are assumed to be ``[batch, seq, hidden]`` or
        ``[batch, hidden]`` -- TODO confirm against the hooked modules.
        ``current_gate_weights`` may be 1-D (``[num_skills]``, shared by the
        whole batch) or 2-D (``[batch or 1, num_skills]``).
    """

    def _repack(original: Any, steered: torch.Tensor) -> Any:
        # Preserve the module's return structure: only element 0 is replaced.
        if isinstance(original, tuple):
            return (steered,) + original[1:]
        return steered

    def hook_fn(
        module: nn.Module,
        input_tensor: Tuple[torch.Tensor, ...],
        output_tensor: Any,
    ) -> Any:
        x = output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor

        # Sequential process/workflow steering takes precedence over skills.
        if self.active_process_id is not None and self.process_library is not None:
            step_idx = self.active_process_step or 0
            step_vector = self.process_library.get_process_step(
                self.active_process_id, step_idx
            )
            if step_vector and layer_idx in step_vector.layer_vectors:
                v = step_vector.layer_vectors[layer_idx].to(
                    device=x.device, dtype=x.dtype
                )
                return _repack(output_tensor, x + self.steering_alpha * v)
            return output_tensor

        # Dynamic skill routing.
        if not (self.skill_library and len(self.skill_ids) > 0):
            return output_tensor

        weights: Optional[torch.Tensor] = None
        if self.current_gate_weights is not None:
            priors = self.current_gate_weights.to(x.device)
            # A 1-D prior is shared by every sample; a 2-D prior already
            # carries a (possibly singleton) batch dimension.
            if priors.dim() == 1:
                priors = priors.unsqueeze(0)
            weights = priors.expand(x.shape[0], -1)
        elif self.skill_gate is not None:
            if self.steering_mode == "token":
                gate_input = x
            else:
                # Pool over dim 1 when there is one; 2-D input is used as-is.
                gate_input = x.mean(dim=1) if x.dim() == 3 else x
            weights = torch.softmax(self.skill_gate(gate_input), dim=-1)

        if weights is None:
            return output_tensor

        # Weights with the same rank as x carry one weight per position;
        # otherwise they are per-sample and must be reshaped to x's rank so
        # they broadcast over all remaining dimensions (works for 2-D and 3-D x).
        per_position = weights.dim() == x.dim()
        batch_shape = (-1,) + (1,) * (x.dim() - 1)

        steer_contribution = torch.zeros_like(x)
        for i, skill_id in enumerate(self.skill_ids):
            vec = self.skill_library.get_vector(skill_id)
            if vec and layer_idx in vec.layer_vectors:
                v = vec.layer_vectors[layer_idx].to(device=x.device, dtype=x.dtype)
                if per_position:
                    g = weights[..., i : i + 1]
                else:
                    g = weights[:, i].reshape(batch_shape)
                steer_contribution += g * v

        return _repack(output_tensor, x + self.steering_alpha * steer_contribution)

    return hook_fn