Context & Objective
I am implementing a research experiment on Linearized Control Barrier Functions (LCBF) to reduce hallucinations in VLM generation (specifically llava-hf/llava-1.5-7b-hf).

Goal: Implement a "Steering Controller" that intervenes during the forward pass of specific transformer layers (e.g., Layer 20).

Mechanism:
- Intercept the Query (Q) and Key (K) tensors after projection and RoPE, but before the attention dot product.
- Calculate a "Safety Value" h(Q), defined as the attention mass on visual tokens (mean over the Top-K heads).
- If h(Q) < \tau (unsafe drift), calculate \nabla_Q h(Q) and apply a closed-form steering vector \theta^* to Q.
- Resume the forward pass with the steered Q_{new}.
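Concretely, the closed-form update I use is the LiSeCo-style step implemented in apply_lcbf_steering below:

\theta^* = \frac{\tau - h(Q)}{\|\nabla_Q h(Q)\|^2 + \epsilon} \, \nabla_Q h(Q), \qquad Q_{new} = Q + \alpha \, \theta^*

so the correction is only triggered when h(Q) < \tau, and its magnitude scales with the size of the violation \tau - h(Q).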
The Implementation Strategy
To achieve this without retraining, I am monkey-patching the forward method of the LlamaAttention modules. My wrapper function attempts to replicate the standard attention logic (Projection → RoPE → Cache → Dot Product) but inserts my steering logic right before the dot product.
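The patching pattern itself is just rebinding each attention module's forward method (a simplified sketch of the injection loop in the full script below; the steering details are omitted here, and decoder_layers is a placeholder for the LLaMA decoder layers inside LLaVA):

import types

def make_patched_forward(orig_forward, layer_idx):
    # orig_forward is the already-bound LlamaAttention.forward of this module
    def forward(self, hidden_states, **kwargs):
        # ...recreate Projection -> RoPE -> Cache, steer Q, finish attention...
        return orig_forward(hidden_states, **kwargs)  # placeholder in this sketch
    return forward

# for i, layer in enumerate(decoder_layers):
#     layer.self_attn.forward = types.MethodType(
#         make_patched_forward(layer.self_attn.forward, i), layer.self_attn)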
The Problem
I am encountering persistent AttributeErrors when accessing internal modules of LlamaAttention, specifically rotary_emb. The attributes exposed by self inside the patched method do not match the LlamaAttention definition I expected.
I previously hit similar errors for self.num_heads and self.hidden_size, which I resolved by deriving the dimensions directly from the input tensors. However, rotary_emb is a module (likely LlamaRotaryEmbedding), not a scalar dimension, so I cannot derive it from the tensors:
File "experiment2_drift.py", line 149, in forward
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
AttributeError: 'LlamaAttention' object has no attribute 'rotary_emb'
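For reference, this is the kind of probe I can run after loading the model to see what the patched module actually exposes (the attribute paths are assumptions and depend on the transformers version; model is the loaded LlavaForConditionalGeneration):

# Hypothetical probe (separate from the failing script): inspect what the
# attention module and its parent decoder actually expose.
lm = model.language_model
decoder = getattr(lm, "model", lm)      # LlamaModel, whichever layout this version uses
attn = decoder.layers[20].self_attn     # the module I monkey-patch
print(type(attn).__name__, [name for name, _ in attn.named_children()])
print([name for name in dir(attn) if "rotary" in name or "head" in name])
print("rotary_emb on decoder:", hasattr(decoder, "rotary_emb"))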
My Code
Below is the reproduction script. It includes the custom_forward_wrapper where I attempt to recreate the forward pass logic.
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
import torch.nn as nn
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import types
import math
import sys
# --- Robust Imports for Llama Utilities ---
try:
    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
except ImportError:
    # Fallback for newer transformers versions where the path might differ
    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
# --- Setup Paths ---
PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PROJECT_DIR)
from data_setup import load_llava_model
# --- Configuration ---
STEERING_CONFIG = {
    "tau": 0.20,           # Safety Threshold
    "alpha": 1.0,          # Step size
    "epsilon": 1e-6,       # Stability constant
    "top_k_heads": 5,      # Dynamic Top-K aggregation
    "steer_layers": [20],  # Steer ONLY Layer 20
    "is_active": False,    # Global Toggle
    "img_start": 0,        # Dynamic
    "img_end": 0           # Dynamic
}
# --- 1. The LCBF Steering Logic ---
def compute_barrier_value(attn_weights, img_start, img_end, top_k):
    """
    Calculates h(q): Mean Attention Mass of the Top-K visual heads.
    attn_weights shape: (bsz, num_heads, q_len, k_len)
    """
    # Extract attention on visual tokens for the last query step.
    # Shape: (num_heads, visual_tokens); batch index 0 (assuming bsz=1).
    visual_attn = attn_weights[0, :, -1, img_start:img_end]
    # Sum across visual tokens -> Mass per Head
    mass_per_head = visual_attn.sum(dim=-1)
    # Select Top-K heads
    k = min(top_k, mass_per_head.shape[0])
    top_masses, _ = torch.topk(mass_per_head, k=k)
    # Barrier value is the mean of these Top-K heads
    return top_masses.mean()
def apply_lcbf_steering(query_states, key_states, attention_mask, layer_idx):
    """
    The Intervention Loop:
    1. Observe Q
    2. Check h(Q) < tau
    3. If unsafe, compute grad and steer Q -> Q_new
    """
    tau = STEERING_CONFIG["tau"]
    alpha = STEERING_CONFIG["alpha"]
    eps = STEERING_CONFIG["epsilon"]
    img_start = STEERING_CONFIG["img_start"]
    img_end = STEERING_CONFIG["img_end"]
    top_k = STEERING_CONFIG["top_k_heads"]
    # Step A: Observation
    q_temp = query_states.clone().detach().requires_grad_(True)
    # Step B: Safety Check
    head_dim = q_temp.shape[-1]
    k_transposed = key_states.transpose(-1, -2)
    attn_scores = torch.matmul(q_temp, k_transposed) / math.sqrt(head_dim)
    if attention_mask is not None:
        attn_scores = attn_scores + attention_mask
    attn_weights = nn.functional.softmax(attn_scores, dim=-1)
    h_val = compute_barrier_value(attn_weights, img_start, img_end, top_k)
    current_h = h_val.item()
    if current_h >= tau:
        return query_states, current_h
    # Step C: Gradient
    grads = torch.autograd.grad(h_val, q_temp, retain_graph=False)[0]
    # Step D: Steering Update (LiSeCo)
    grad_norm_sq = torch.sum(grads * grads)
    violation = tau - current_h
    theta_star = (violation / (grad_norm_sq + eps)) * grads
    q_new = q_temp + (alpha * theta_star)
    # Step E: Correction
    return q_new.detach(), current_h
# --- 2. Monkey-Patching LlamaAttention ---
def custom_forward_wrapper(original_forward, layer_idx):
    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False, **kwargs):
        # --- ROBUST DIMENSION EXTRACTION ---
        # FIX: Get hidden_size from the input tensor, not 'self'
        bsz, q_len, hidden_size = hidden_states.size()
        # Get num_heads safely
        if hasattr(self, "num_heads"):
            num_heads = self.num_heads
        elif hasattr(self, "num_attention_heads"):
            num_heads = self.num_attention_heads
        else:
            num_heads = self.config.num_attention_heads
        # Get head_dim safely
        if hasattr(self, "head_dim"):
            head_dim = self.head_dim
        else:
            head_dim = hidden_size // num_heads
        # Get num_key_value_heads safely
        if hasattr(self, "num_key_value_heads"):
            num_key_value_heads = self.num_key_value_heads
        else:
            # Fallback to config, or default to num_heads
            num_key_value_heads = getattr(self.config, "num_key_value_heads", num_heads)
        # 1. Standard Projection
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        # 2. Apply RoPE
        # NOTE: this is the line that raises the AttributeError shown in the traceback above.
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # 3. Handle Cache
        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states) if use_cache else None
        # Repeat KV for GQA (if needed)
        key_states = repeat_kv(key_states, num_heads // num_key_value_heads)
        value_states = repeat_kv(value_states, num_heads // num_key_value_heads)
        # --- 4. INTERVENTION ---
        if (STEERING_CONFIG["is_active"] and
                layer_idx in STEERING_CONFIG["steer_layers"] and
                q_len == 1):
            with torch.enable_grad():
                q_steered, recorded_h = apply_lcbf_steering(
                    query_states, key_states, attention_mask, layer_idx
                )
            query_states = q_steered
        # --- 5. Attention ---
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
        attn_output = self.o_proj(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
    return forward
# --- 3. Experiment Loop ---
class Experiment2_TemporalDrift:
    def __init__(self, model, processor, img_dir):
        self.model = model
        self.processor = processor
        self.img_dir = img_dir
        print("Injecting LCBF Steering Hooks into Layer 20...")
        # Robust Layer Access for LLaVA (the path differs across transformers versions)
        if hasattr(self.model, "language_model"):
            self.layers = self.model.language_model.layers
        elif hasattr(self.model, "model") and hasattr(self.model.model, "layers"):
            self.layers = self.model.model.layers
        else:
            self.layers = self.model.layers
        for i, layer in enumerate(self.layers):
            if hasattr(layer, "self_attn"):
                layer.self_attn.forward = types.MethodType(
                    custom_forward_wrapper(layer.self_attn.forward, i),
                    layer.self_attn
                )

    def get_img_indices(self, input_ids):
        # 32000 is the <image> placeholder token id for llava-1.5-hf
        indices = (input_ids[0] == 32000).nonzero(as_tuple=True)[0]
        if len(indices) == 0:
            return 0, 0
        return indices[0].item(), indices[-1].item() + 1

    def run_pass(self, image, prompt, steering_active=False):
        STEERING_CONFIG["is_active"] = steering_active
        inputs = self.processor(text=prompt, images=image, return_tensors="pt")
        input_ids = inputs.input_ids.to(self.model.device)
        pixel_values = inputs.pixel_values.to(self.model.device, dtype=torch.float16)
        with torch.inference_mode():
            prefill = self.model(input_ids, pixel_values=pixel_values, use_cache=True, return_dict=True)
            past_key_values = prefill.past_key_values
        start, end = self.get_img_indices(input_ids)
        STEERING_CONFIG["img_start"] = start
        STEERING_CONFIG["img_end"] = end
        next_token = torch.argmax(prefill.logits[:, -1, :], dim=-1).unsqueeze(1)
        h_values = []
        for t in range(100):
            with torch.inference_mode():
                outputs = self.model(
                    input_ids=next_token,
                    past_key_values=past_key_values,
                    use_cache=True,
                    output_attentions=True
                )
            # Measure at Layer 20
            attns = outputs.attentions[20]
            # Re-calculate h(q) for the plot
            h_val = compute_barrier_value(attns, start, end, STEERING_CONFIG["top_k_heads"]).item()
            h_values.append(h_val)
            past_key_values = outputs.past_key_values
            next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(1)
            if next_token.item() == self.processor.tokenizer.eos_token_id:
                break
        return h_values

    def run_experiment(self, num_images=10):
        img_files = sorted([f for f in os.listdir(self.img_dir) if f.endswith('.jpg')])[:num_images]
        prompt = "USER: <image>\nDescribe this image in minute detail, mentioning every small object, texture, and color. ASSISTANT:"
        baseline_traces = []
        steered_traces = []
        print(f"Running Analysis on {len(img_files)} images...")
        for img_file in tqdm(img_files):
            try:
                image = Image.open(os.path.join(self.img_dir, img_file)).convert("RGB")
            except Exception:
                continue
            h_base = self.run_pass(image, prompt, steering_active=False)
            baseline_traces.append(h_base)
            h_steer = self.run_pass(image, prompt, steering_active=True)
            steered_traces.append(h_steer)
        return baseline_traces, steered_traces

    def visualize(self, baseline_traces, steered_traces):
        max_len = 100
        def pad(t):
            return t + [np.nan] * (max_len - len(t))
        base_mean = np.nanmean(np.array([pad(t) for t in baseline_traces]), axis=0)
        steer_mean = np.nanmean(np.array([pad(t) for t in steered_traces]), axis=0)
        plt.figure(figsize=(10, 6))
        x = range(max_len)
        plt.plot(x, base_mean, label='Baseline', color='blue')
        plt.plot(x, steer_mean, label='LCBF Steered (L20)', color='orange', linewidth=2)
        plt.axhline(y=STEERING_CONFIG["tau"], color='red', linestyle=':')
        plt.title('Experiment 2: Temporal Drift')
        plt.xlabel('Token Position')
        plt.ylabel('Visual Attention Mass h(q)')
        plt.legend()
        plt.savefig(os.path.join(PROJECT_DIR, "results_exp2_drift.png"))
        print("Saved results_exp2_drift.png")
if __name__ == "__main__":
    model_id = "llava-hf/llava-1.5-7b-hf"
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        attn_implementation="eager",
    ).to("cuda")
    processor = AutoProcessor.from_pretrained(model_id)
    img_dir = os.path.join(os.getcwd(), 'data/mscoco/val2014')
    if not os.path.exists(img_dir):
        sys.exit("Data not found")
    exp = Experiment2_TemporalDrift(model, processor, img_dir)
    b_traces, s_traces = exp.run_experiment(num_images=20)
    exp.visualize(b_traces, s_traces)
Questions:
- Attribute Access: Why does self.rotary_emb (and previously num_heads / hidden_size) fail inside the monkey-patched forward method for LlamaAttention in LLaVA-1.5? Is there a different standard attribute for RoPE in recent transformers versions?
- Alternative Approach: Since I need to calculate gradients w.r.t. Q to steer it, standard forward hooks (which operate on outputs) are insufficient. Is there a cleaner, more stable way to intercept the input to the dot-product attention (post-RoPE) without rewriting the entire forward pass?
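To illustrate the second question: the interception point I can reach with standard PyTorch hooks is not the one I need. A forward pre-hook on self_attn only sees the module inputs (hidden_states before q_proj and RoPE), and a forward hook only sees the output after o_proj, so neither exposes the post-RoPE Q that h(Q) is defined on. A sketch (attn_module is a placeholder for one patched LlamaAttention):

# Sketch only: shows what a pre-hook can observe, not a working intervention.
def inspect_attn_inputs(module, args, kwargs):
    # args/kwargs carry hidden_states, attention_mask, position_ids, ... --
    # i.e. tensors *before* the Q/K projections and RoPE are applied.
    print({k: getattr(v, "shape", type(v)) for k, v in kwargs.items()})
    return None  # returning None leaves the inputs unchanged

handle = attn_module.register_forward_pre_hook(inspect_attn_inputs, with_kwargs=True)
# ... run generation ...
handle.remove()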