[LLaVA-1.5] Implementing Control Barrier Functions (LCBF) via Attention Hooking – Persistent AttributeError: 'LlamaAttention' object has no attribute 'rotary_emb'

Context & Objective

I am implementing a research experiment on Linearized Control Barrier Functions (LCBF) to reduce hallucinations in VLM generation (specifically llava-hf/llava-1.5-7b-hf).

  • Goal: Implement a “Steering Controller” that intervenes during the forward pass of specific transformer layers (e.g., Layer 20).

  • Mechanism:

    1. Intercept the Query (Q) and Key (K) tensors after projection and RoPE, but before the attention dot product.

    2. Calculate a “Safety Value” h(Q) defined as the attention mass on visual tokens (Top-K heads).

    3. If h(Q) < τ (unsafe drift), calculate ∇_Q h(Q) and apply a closed-form steering vector θ* to Q.

    4. Resume the forward pass with the steered Q_new.
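
Concretely, the closed-form update implemented in Step D of the code below is:

\theta^* = \frac{\tau - h(Q)}{\lVert \nabla_Q h(Q) \rVert^2 + \epsilon}\,\nabla_Q h(Q), \qquad Q_{\mathrm{new}} = Q + \alpha\,\theta^*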

The Implementation Strategy

To achieve this without retraining, I am monkey-patching the forward method of LlamaAttention modules. My wrapper function attempts to replicate the standard attention logic (Projection → RoPE → Cache → Dot Product) but inserts my steering logic right before the dot product.

The Problem

I am encountering persistent AttributeErrors related to accessing internal modules of LlamaAttention, specifically rotary_emb. It appears the attributes exposed by the self object inside the patched method do not match the expected LlamaAttention definition.

I previously faced similar errors for self.num_heads and self.hidden_size, which I resolved by deriving dimensions directly from the input tensors. However, rotary_emb is a generic module (likely LlamaRotaryEmbedding), and I cannot derive it from tensors. The relevant traceback:

File "experiment2_drift.py", line 149, in forward
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
AttributeError: 'LlamaAttention' object has no attribute 'rotary_emb'

My Code

Below is the reproduction script. It includes the custom_forward_wrapper where I attempt to recreate the forward pass logic.

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

import torch.nn as nn
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import types
import math
import sys

# --- Robust Imports for Llama Utilities ---
try:
    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
except ImportError:
    # Fallback for newer transformers versions where path might differ
    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv

# --- Setup Paths ---
PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PROJECT_DIR)
from data_setup import load_llava_model

# --- Configuration ---
STEERING_CONFIG = {
    "tau": 0.20,             # Safety Threshold
    "alpha": 1.0,            # Step size
    "epsilon": 1e-6,         # Stability constant
    "top_k_heads": 5,        # Dynamic Top-K aggregation
    "steer_layers": [20],    # Steer ONLY Layer 20
    "is_active": False,      # Global Toggle
    "img_start": 0,          # Dynamic
    "img_end": 0             # Dynamic
}




# --- 1. The LCBF Steering Logic ---

def compute_barrier_value(attn_weights, img_start, img_end, top_k):
    """
    Calculates h(q): Mean Attention Mass of the Top-K visual heads.
    attn_weights shape: (bsz, num_heads, q_len, k_len)
    """
    # Extract attention on visual tokens for the last query step
    # Shape: (num_heads, visual_tokens)
    # We use batch index 0 (assuming bsz=1)
    visual_attn = attn_weights[0, :, -1, img_start:img_end]
    # Sum across visual tokens -> Mass per Head
    mass_per_head = visual_attn.sum(dim=-1)
    # Select Top-K heads
    k = min(top_k, mass_per_head.shape[0])
    top_masses, _ = torch.topk(mass_per_head, k=k)
    # Barrier value is the mean of these Top-K heads
    return top_masses.mean()


def apply_lcbf_steering(query_states, key_states, attention_mask, layer_idx):
    """
    The Intervention Loop:
    1. Observe Q
    2. Check h(Q) < tau
    3. If unsafe, compute grad and steer Q -> Q_new
    """
    tau = STEERING_CONFIG["tau"]
    alpha = STEERING_CONFIG["alpha"]
    eps = STEERING_CONFIG["epsilon"]
    img_start = STEERING_CONFIG["img_start"]
    img_end = STEERING_CONFIG["img_end"]
    top_k = STEERING_CONFIG["top_k_heads"]

    # Step A: Observation
    q_temp = query_states.clone().detach().requires_grad_(True)

    # Step B: Safety Check
    head_dim = q_temp.shape[-1]
    k_transposed = key_states.transpose(-1, -2)
    attn_scores = torch.matmul(q_temp, k_transposed) / math.sqrt(head_dim)
    if attention_mask is not None:
        attn_scores = attn_scores + attention_mask
    attn_weights = nn.functional.softmax(attn_scores, dim=-1)
    h_val = compute_barrier_value(attn_weights, img_start, img_end, top_k)
    current_h = h_val.item()
    if current_h >= tau:
        return query_states, current_h

    # Step C: Gradient
    grads = torch.autograd.grad(h_val, q_temp, retain_graph=False)[0]

    # Step D: Steering Update (LiSeCo)
    grad_norm_sq = torch.sum(grads * grads)
    violation = tau - current_h
    theta_star = (violation / (grad_norm_sq + eps)) * grads
    q_new = q_temp + (alpha * theta_star)

    # Step E: Correction
    return q_new.detach(), current_h




# --- 2. Monkey-Patching LlamaAttention ---

def custom_forward_wrapper(original_forward, layer_idx):
    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None,
                output_attentions=False, use_cache=False, **kwargs):
        # --- ROBUST DIMENSION EXTRACTION ---
        # FIX: Get hidden_size from input tensor, not 'self'
        bsz, q_len, hidden_size = hidden_states.size()

        # Get num_heads safely
        if hasattr(self, "num_heads"):
            num_heads = self.num_heads
        elif hasattr(self, "num_attention_heads"):
            num_heads = self.num_attention_heads
        else:
            num_heads = self.config.num_attention_heads

        # Get head_dim safely
        if hasattr(self, "head_dim"):
            head_dim = self.head_dim
        else:
            head_dim = hidden_size // num_heads

        # Get num_key_value_heads safely
        if hasattr(self, "num_key_value_heads"):
            num_key_value_heads = self.num_key_value_heads
        else:
            # Fallback to config or default to num_heads
            num_key_value_heads = getattr(self.config, "num_key_value_heads", num_heads)

        # 1. Standard Projection
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        # 2. Apply RoPE
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        # 3. Handle Cache
        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states) if use_cache else None

        # Repeat KV for GQA (if needed)
        key_states = repeat_kv(key_states, num_heads // num_key_value_heads)
        value_states = repeat_kv(value_states, num_heads // num_key_value_heads)

        # --- 4. INTERVENTION ---
        if (STEERING_CONFIG["is_active"] and
                layer_idx in STEERING_CONFIG["steer_layers"] and
                q_len == 1):
            with torch.enable_grad():
                q_steered, recorded_h = apply_lcbf_steering(
                    query_states, key_states, attention_mask, layer_idx
                )
            query_states = q_steered

        # --- 5. Attention ---
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    return forward




# --- 3. Experiment Loop ---

class Experiment2_TemporalDrift:
    def __init__(self, model, processor, img_dir):
        self.model = model
        self.processor = processor
        self.img_dir = img_dir
        print("Injecting LCBF Steering Hooks into Layer 20...")
        # Robust Layer Access for LLaVA
        if hasattr(self.model, "language_model"):
            self.layers = self.model.language_model.layers
        elif hasattr(self.model, "model") and hasattr(self.model.model, "layers"):
            self.layers = self.model.model.layers
        else:
            self.layers = self.model.layers

        for i, layer in enumerate(self.layers):
            if hasattr(layer, "self_attn"):
                layer.self_attn.forward = types.MethodType(
                    custom_forward_wrapper(layer.self_attn.forward, i),
                    layer.self_attn
                )

    def get_img_indices(self, input_ids):
        indices = (input_ids[0] == 32000).nonzero(as_tuple=True)[0]
        if len(indices) == 0:
            return 0, 0
        return indices[0].item(), indices[-1].item() + 1

    def run_pass(self, image, prompt, steering_active=False):
        STEERING_CONFIG["is_active"] = steering_active
        inputs = self.processor(text=prompt, images=image, return_tensors="pt")
        input_ids = inputs.input_ids.to(self.model.device)
        pixel_values = inputs.pixel_values.to(self.model.device, dtype=torch.float16)
        with torch.inference_mode():
            prefill = self.model(input_ids, pixel_values=pixel_values, use_cache=True, return_dict=True)
            past_key_values = prefill.past_key_values
        start, end = self.get_img_indices(input_ids)
        STEERING_CONFIG["img_start"] = start
        STEERING_CONFIG["img_end"] = end
        next_token = torch.argmax(prefill.logits[:, -1, :], dim=-1).unsqueeze(1)
        h_values = []
        for t in range(100):
            with torch.inference_mode():
                outputs = self.model(
                    input_ids=next_token,
                    past_key_values=past_key_values,
                    use_cache=True,
                    output_attentions=True
                )
            # Measure at Layer 20
            attns = outputs.attentions[20]
            # Re-calculate h(q) for plot
            h_val = compute_barrier_value(attns, start, end, STEERING_CONFIG["top_k_heads"]).item()
            h_values.append(h_val)
            past_key_values = outputs.past_key_values
            next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(1)
            if next_token.item() == self.processor.tokenizer.eos_token_id:
                break
        return h_values

    def run_experiment(self, num_images=10):
        img_files = sorted([f for f in os.listdir(self.img_dir) if f.endswith('.jpg')])[:num_images]
        prompt = "USER: <image>\nDescribe this image in minute detail, mentioning every small object, texture, and color. ASSISTANT:"
        baseline_traces = []
        steered_traces = []
        print(f"Running Analysis on {len(img_files)} images...")
        for img_file in tqdm(img_files):
            try:
                image = Image.open(os.path.join(self.img_dir, img_file)).convert("RGB")
            except Exception:
                continue
            h_base = self.run_pass(image, prompt, steering_active=False)
            baseline_traces.append(h_base)
            h_steer = self.run_pass(image, prompt, steering_active=True)
            steered_traces.append(h_steer)
        return baseline_traces, steered_traces

    def visualize(self, baseline_traces, steered_traces):
        max_len = 100
        def pad(t): return t + [np.nan] * (max_len - len(t))
        base_mean = np.nanmean(np.array([pad(t) for t in baseline_traces]), axis=0)
        steer_mean = np.nanmean(np.array([pad(t) for t in steered_traces]), axis=0)
        plt.figure(figsize=(10, 6))
        x = range(max_len)
        plt.plot(x, base_mean, label='Baseline', color='blue')
        plt.plot(x, steer_mean, label='LCBF Steered (L20)', color='orange', linewidth=2)
        plt.axhline(y=STEERING_CONFIG["tau"], color='red', linestyle=':')
        plt.title('Experiment 2: Temporal Drift')
        plt.xlabel('Token Position')
        plt.ylabel('Visual Attention Mass h(q)')
        plt.legend()
        plt.savefig(os.path.join(PROJECT_DIR, "results_exp2_drift.png"))
        print("Saved results_exp2_drift.png")


if __name__ == "__main__":
    model_id = "llava-hf/llava-1.5-7b-hf"
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        attn_implementation="eager",
    ).to("cuda")
    processor = AutoProcessor.from_pretrained(model_id)
    img_dir = os.path.join(os.getcwd(), 'data/mscoco/val2014')
    if not os.path.exists(img_dir):
        sys.exit("Data not found")
    exp = Experiment2_TemporalDrift(model, processor, img_dir)
    b_traces, s_traces = exp.run_experiment(num_images=20)
    exp.visualize(b_traces, s_traces)

Questions:

  1. Attribute Access: Why does self.rotary_emb (and previously num_heads/hidden_size) fail inside the monkey-patched forward method for LlamaAttention in LLaVA-1.5? Is there a different standard attribute for RoPE in recent transformers versions?

  2. Alternative Approach: Since I need to calculate gradients w.r.t Q to steer it, standard forward hooks (which operate on outputs) are insufficient. Is there a cleaner, more stable way to intercept the input to the dot-product attention (post-RoPE) without rewriting the entire forward pass?


I tested this with Transformers 5.10, but it likely happens with other versions as well.


Background: why this keeps happening in modern Transformers

The LLaVA-1.5 text backbone is a LLaMA-family decoder. In recent Transformers releases, the LLaMA stack has undergone a large refactor around:

  • where RoPE is computed (model-level vs per-layer),
  • how KV caching works (tuple caches vs Cache.update(...)),
  • which attention backend runs (eager / SDPA / Flash / Flex),
  • and which attributes live on the attention module vs config/parent.

Those changes make “re-implement LlamaAttention.forward by copying old code” fragile: small version differences translate into missing attributes and signature mismatches.


1) Why self.rotary_emb (and num_heads/hidden_size) fails inside your monkey-patched forward

rotary_emb is not guaranteed to exist on LlamaAttention

In newer LLaMA implementations, RoPE is often computed once at the model level and passed down to each attention call as position_embeddings=(cos, sin) rather than being computed in every attention layer. The reference code shows attention preferring the passed-in position_embeddings and only using self.rotary_emb as a fallback during the transition. (Hugging Face)

As a result, depending on the exact Transformers version / model variant, LlamaAttention may not have a rotary_emb attribute—so any patched forward that unconditionally calls self.rotary_emb(...) will raise an AttributeError. This exact error pattern is reported in upstream issues. (GitHub)

Your wrapper also assumes older RoPE call signatures

You call:

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
apply_rotary_pos_emb(..., cos, sin, position_ids)

But modern code paths typically:

  • compute RoPE using position_ids (not seq_len=), and
  • call apply_rotary_pos_emb(q, k, cos, sin) without passing position_ids (because cos/sin are already aligned). (GitHub)

So there are two failure modes: missing attribute and/or incompatible signature.
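
For contrast, a sketch of the new-style call path (based on current upstream modeling_llama; exact signatures may drift again, and in LLaVA the text backbone is nested differently, e.g. under model.language_model):

# RoPE is owned by the decoder model and handed down as (cos, sin):
cos, sin = model.model.rotary_emb(hidden_states, position_ids)  # no seq_len kwarg
# cos/sin are already gathered per position, so position_ids is not passed here:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)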

num_heads / hidden_size are also not stable “instance attributes”

Transformers increasingly stores head counts and sizes on the config or computes them indirectly, especially with grouped-query attention and multiple backends. There are multiple reports of num_heads disappearing from LlamaAttention across versions. (GitHub)

Takeaway: the patched self is not “wrong”; the internal contract you’re assuming is no longer stable.


2) Cleaner, stable ways to intercept post-RoPE, pre-dot-product (and still compute gradients w.r.t. Q)

Recommended: use AttentionInterface (custom attention function)

Transformers provides an official way to replace only the attention computation (the part that consumes query, key, value, attention_mask) while keeping projection/RoPE/cache logic inside the model unchanged. (Hugging Face)

Why it matches your controller:

  • You receive Q/K/V right where dot-product attention happens (the seam you want).
  • You can compute your “visual-attention mass” metric from scores/weights.
  • You can compute gradients w.r.t. query locally and then call the reference eager attention implementation afterward.

One critical pitfall: mask plumbing.
If your custom attention needs a mask, you may also need to register the corresponding mask-creation function; otherwise, mask handling can change or be dropped. This is called out in practice and discussed in an issue. (GitHub)
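
Here is a minimal sketch of that wiring, reusing apply_lcbf_steering and STEERING_CONFIG from your script. AttentionInterface and eager_attention_forward exist in recent Transformers releases, but treat the exact kwargs (scaling, dropout, layer_idx) as assumptions to verify against your installed version:

import torch
from transformers import AttentionInterface
from transformers.models.llama.modeling_llama import eager_attention_forward


def lcbf_attention(module, query, key, value, attention_mask, scaling=None, dropout=0.0, **kwargs):
    # query/key/value arrive here post-projection, post-RoPE, post-cache: the seam you want.
    # Note: under GQA, `key` may have fewer heads than `query`; for llava-1.5-7b (plain MHA) they match.
    layer_idx = getattr(module, "layer_idx", None)
    if (STEERING_CONFIG["is_active"]
            and layer_idx in STEERING_CONFIG["steer_layers"]
            and query.shape[2] == 1):              # steer only single-token decode steps
        with torch.enable_grad():                  # requires the outer pass to be no_grad, not inference_mode
            query, _h = apply_lcbf_steering(query, key, attention_mask, layer_idx)
    if scaling is None:
        scaling = query.shape[-1] ** -0.5
    # Hand the (possibly steered) Q back to the reference eager implementation.
    return eager_attention_forward(module, query, key, value, attention_mask,
                                   scaling=scaling, dropout=dropout, **kwargs)


AttentionInterface.register("lcbf_eager", lcbf_attention)
# Then load the model with this backend, e.g.:
# model = LlavaForConditionalGeneration.from_pretrained(model_id, attn_implementation="lcbf_eager", ...)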

If you insist on monkey-patching forward: patch the new signature, not the legacy one

To keep forward monkey-patching viable across versions, your patched forward must:

  • accept position_embeddings (and use it when present),
  • accept cache_position,
  • support cache objects that want .update(...) with {"sin","cos","cache_position"} in cache_kwargs. (Hugging Face)
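
In current upstream code that cache step looks roughly like this (a sketch; the sin/cos/cache_position keys are what the LLaMA layers pass today and may change):

# Sketch of new-style cache handling inside a patched forward
# (past_key_values is a Cache object such as DynamicCache, not a tuple):
if past_key_values is not None:
    cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
    key_states, value_states = past_key_values.update(
        key_states, value_states, self.layer_idx, cache_kwargs
    )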

This is substantially more work than it looks, which is why patching the attention function (AttentionInterface) tends to be more stable than patching the attention module's forward. (The fixed_forward_patch in the verification script at the end of this post shows a minimal version of the new-signature approach, without cache handling.)


Two additional “gotchas” that directly affect your experiment

A) Gradients won’t behave as intended under torch.inference_mode()

If you run the model forward in torch.inference_mode(), tensors produced inside it cannot participate in autograd later. For local gradient computations during generation, use torch.no_grad() for the outer call and torch.enable_grad() only around the small steering computation. (PyTorch docs)
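
A sketch of the pattern, using the decode loop from your run_pass (variable names are yours):

import torch

# Outer decode step: plain no_grad, NOT inference_mode, so that a local
# enable_grad() inside the steering code can still build a graph.
with torch.no_grad():
    outputs = model(input_ids=next_token,
                    past_key_values=past_key_values,
                    use_cache=True,
                    output_attentions=True)

# ...while inside the patched attention / custom attention function,
# only the steering math is grad-enabled:
#   with torch.enable_grad():
#       q_local = query_states.clone().requires_grad_(True)
#       h = compute_barrier_value(...)
#       grad_q = torch.autograd.grad(h, q_local)[0]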

B) “Visual token span” in LLaVA is not just the <image> placeholder position

LLaVA expands the <image> placeholder into hundreds of image-feature tokens (documentation notes it’s usually ~500 per image). If you compute your “visual attention mass” using indices based only on the placeholder token location, it can be misaligned with the actual KV sequence the model attends over. (Hugging Face)

This matters because your barrier signal needs to reflect attention to the expanded image-token region, not the placeholder marker.
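
A sketch of a span computation that covers both processor behaviors (the image-token id and the 576-patch count are assumptions for llava-hf/llava-1.5-7b-hf; verify them against your processor and config):

import torch

IMAGE_TOKEN_ID = 32000      # <image> for llava-1.5-7b-hf (assumption: check the processor/tokenizer)
NUM_IMAGE_TOKENS = 576      # (336 / 14) ** 2 patches for the CLIP ViT-L/14-336 tower (assumption)

def image_token_span(input_ids: torch.Tensor):
    """Return (start, end) of the expanded image-feature region in the KV sequence."""
    positions = (input_ids[0] == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]
    if len(positions) == 0:
        return 0, 0
    if len(positions) > 1:
        # Newer processors already expand <image> into one placeholder per patch in input_ids.
        return positions[0].item(), positions[-1].item() + 1
    # Older behavior: a single placeholder; the model splices NUM_IMAGE_TOKENS features in at that position.
    start = positions[0].item()
    return start, start + NUM_IMAGE_TOKENS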


Practical recommendation for your exact goal (minimal rewriting, maximum stability)

  1. Move the intervention into a custom attention function via AttentionInterface. (Hugging Face)
  2. Keep attn_implementation="eager" initially (predictable semantics), and ensure mask behavior is preserved (register mask creation if needed). (GitHub)
  3. Run steered passes without inference_mode if you need ∂h/∂Q. (PyTorch docs)
  4. Compute the correct image-token span based on LLaVA’s expansion/merge behavior rather than only the placeholder index. (Hugging Face)

BTW, code for verification.


"""
Repro + Fix demo: "AttributeError: 'LlamaAttention' object has no attribute 'rotary_emb'"
(single file, no argparse)

What this script does
1) Loads a small LLaMA model (tiny) so it runs on CPU/GPU (T4-safe).
2) Monkey-patches ONE attention layer with a legacy forward that assumes `self.rotary_emb` exists.
   -> On newer Transformers, this often raises AttributeError.
3) Monkey-patches again with a forward that is compatible with the refactor:
   - Prefer `position_embeddings=(cos, sin)` passed from the model (new style).
   - Fallback to model-level `model.model.rotary_emb` if `position_embeddings` is None.
   -> Forward succeeds.

Dependencies
  pip install torch transformers accelerate

Notes / references (URLs requested)
- Why this breaks: RoPE moved to model-level `rotary_emb`, attention gets `position_embeddings=(cos,sin)`:
  https://raw.githubusercontent.com/huggingface/transformers/main/src/transformers/models/llama/modeling_llama.py
- Report of the exact error:
  https://github.com/huggingface/transformers/issues/36758
- Stable interception alternative (recommended for research): AttentionInterface
  https://huggingface.co/docs/transformers/v4.53.1/attention_interface
- Gradient note: torch.inference_mode is not compatible with later autograd usage:
  https://docs.pytorch.org/docs/stable/generated/torch.autograd.grad_mode.inference_mode.html
"""

import os
import sys
import types
import torch

def pick_device_and_dtype():
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.float16
    return torch.device("cpu"), torch.float32  # float32 on CPU (requested)

@torch.no_grad()
def run_one_forward(model, tokenizer, device):
    text = "Hello from the repro script."
    inputs = tokenizer(text, return_tensors="pt").to(device)
    out = model(**inputs)  # just to exercise attention
    # print a small checksum so you can confirm "it ran"
    logits = out.logits
    print(f"  forward ok. logits checksum={float(logits.float().mean()):.6f}")

def legacy_forward_repro():
    """
    Old-style patch that assumes:
      - attention module has self.rotary_emb
      - rotary_emb signature accepts seq_len kwarg
    Newer Transformers refactors often remove attention-local rotary_emb.
    """
    import math
    import torch.nn as nn
    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None,
                output_attentions=False, use_cache=False, **kwargs):
        # Minimal re-impl (enough to trigger the rotary_emb access)
        bsz, q_len, hidden_size = hidden_states.shape

        # derive heads from config if needed
        num_heads = getattr(self, "num_heads", None) or getattr(self, "num_attention_heads", None) or self.config.num_attention_heads
        head_dim = getattr(self, "head_dim", hidden_size // num_heads)

        num_kv_heads = getattr(self, "num_key_value_heads", getattr(self.config, "num_key_value_heads", num_heads))

        q = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)

        kv_seq_len = k.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        # --- This line is the repro: fails if `self.rotary_emb` does not exist ---
        cos, sin = self.rotary_emb(v, seq_len=kv_seq_len)

        # (If it did exist, next line can still mismatch due to signature drift)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

        # repeat kv for GQA
        k = repeat_kv(k, num_heads // num_kv_heads)
        v = repeat_kv(v, num_heads // num_kv_heads)

        attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
        if attention_mask is not None:
            attn = attn + attention_mask
        attn = nn.functional.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(bsz, q_len, hidden_size)
        out = self.o_proj(out)

        if not output_attentions:
            attn = None

        return out, attn, (k, v) if use_cache else None

    return forward

def fixed_forward_patch(model_rotary_emb):
    """
    Fix patch:
      - supports the refactor where attention receives position_embeddings=(cos,sin)
      - if position_embeddings is missing, fall back to model-level rotary_emb (preferred location now)

    This matches current upstream direction where LlamaAttention.forward unpacks `position_embeddings`
    and LlamaModel owns `self.rotary_emb`.
    """
    import torch.nn as nn
    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward

    def forward(
        self,
        hidden_states,
        position_embeddings=None,       # NEW STYLE: (cos, sin)
        attention_mask=None,
        past_key_values=None,           # NEW STYLE cache object or None
        cache_position=None,
        **kwargs,
    ):
        input_shape = hidden_states.shape[:-1]  # (bsz, q_len)
        head_dim = getattr(self, "head_dim", None) or (hidden_states.shape[-1] // self.config.num_attention_heads)
        hidden_shape = (*input_shape, -1, head_dim)

        q = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        k = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        v = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Prefer externally provided RoPE (cos,sin)
        if position_embeddings is None:
            # Fallback: model-level rotary embedding (the new standard location)
            # Try both common calling conventions.
            position_ids = kwargs.get("position_ids", None)
            if model_rotary_emb is None:
                raise RuntimeError(
                    "position_embeddings was None and model_rotary_emb was not found. "
                    "Your Transformers version/model likely requires position_embeddings."
                )
            try:
                position_embeddings = model_rotary_emb(hidden_states, position_ids=position_ids)
            except TypeError:
                position_embeddings = model_rotary_emb(hidden_states, position_ids)

        cos, sin = position_embeddings

        # apply_rotary_pos_emb signature differs across versions; try new-style first
        try:
            q, k = apply_rotary_pos_emb(q, k, cos, sin)
        except TypeError:
            # older signature may include position_ids
            q, k = apply_rotary_pos_emb(q, k, cos, sin, kwargs.get("position_ids", None))

        # Keep this demo simple: ignore cache updates (past_key_values) and just run eager attention.
        # (For real generation+cache, follow upstream cache.update(...) patterns.)
        attn_out, attn_w = eager_attention_forward(
            self,
            q, k, v,
            attention_mask=attention_mask,
            scaling=getattr(self, "scaling", head_dim ** -0.5),
            dropout=0.0 if not self.training else getattr(self, "attention_dropout", 0.0),
            **kwargs,
        )
        attn_out = attn_out.reshape(*input_shape, -1).contiguous()
        attn_out = self.o_proj(attn_out)
        return attn_out, attn_w

    return forward

def main():
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import transformers

    device, dtype = pick_device_and_dtype()
    print(f"transformers={transformers.__version__}")
    print(f"device={device}, dtype={dtype}")

    # Small model so it runs quickly on CPU and fits easily on T4
    model_id = os.environ.get("MODEL_ID", "hf-internal-testing/tiny-random-LlamaForCausalLM")
    print(f"model_id={model_id}")

    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device).eval()
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    except Exception:
        # fallback tokenizer sometimes needed for tiny internal models
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", use_fast=True)

    # Baseline forward
    print("\n[0] Baseline (no patch)")
    run_one_forward(model, tokenizer, device)

    # Locate one LlamaAttention module (layer 0 self_attn)
    # Works for standard LlamaForCausalLM: model.model.layers[i].self_attn
    attn = model.model.layers[0].self_attn
    print(f"\nFound attention module class: {attn.__class__.__name__}")
    print(f"hasattr(attn, 'rotary_emb') = {hasattr(attn, 'rotary_emb')}")

    # -------------------------------------------------------------------------
    # [1] Repro: patch with legacy forward that calls self.rotary_emb(...)
    # -------------------------------------------------------------------------
    print("\n[1] Repro patch: legacy forward assumes self.rotary_emb exists")
    orig_forward = attn.forward
    attn.forward = types.MethodType(legacy_forward_repro(), attn)

    try:
        run_one_forward(model, tokenizer, device)
        print("  (If you did NOT see an exception, your current Transformers build still exposes rotary_emb "
              "or accepted the legacy call; the fix below is still the robust pattern.)")
    except Exception as e:
        print("  Expected failure:")
        print(f"    {type(e).__name__}: {e}")

    # Restore
    attn.forward = orig_forward

    # -------------------------------------------------------------------------
    # [2] Fix: patch with forward that consumes position_embeddings or uses model.model.rotary_emb
    # -------------------------------------------------------------------------
    print("\n[2] Fix patch: use position_embeddings (cos,sin) or model-level rotary_emb fallback")
    model_rotary_emb = getattr(model.model, "rotary_emb", None)
    attn.forward = types.MethodType(fixed_forward_patch(model_rotary_emb), attn)

    try:
        run_one_forward(model, tokenizer, device)
        print("  Fix succeeded.")
    except Exception as e:
        print("  Fix failed in your environment:")
        print(f"    {type(e).__name__}: {e}")
        print("  If this happens, your installed Transformers/model signature may differ; "
              "use AttentionInterface (link at top) to avoid forward re-implementation.")

if __name__ == "__main__":
    torch.manual_seed(0)
    main()

References for the refactor and stability rationale: the upstream LLaMA implementation shows attention receiving position_embeddings=(cos,sin) and model-level rotary_emb ownership. (GitHub) The missing-rotary_emb symptom is documented in a Transformers issue. (GitHub) The stable alternative interception point is AttentionInterface. (huggingface.co)