TensorFlow port of HF’s Paligemma

I assumed many people would have ported models like this to TF as a learning exercise, but I didn’t find any repos. Mine is GitHub - mohanr/Paligemma

It is supposed to be the same as transformers/src/transformers/models/siglip at main · huggingface/transformers · GitHub

The problem is that the generated tokens are wrong, even though they do vary for different images.

I did compare the weights for all layers, so it could be a computation problem that assigns slightly wrong logits to some tokens. Isn’t there a way to debug such complex models?

Has anyone done a port like this? I thought NNsight/NDIF could help, but the model may be too complex for that.

Are these Keras layers really so tricky to work with?

‘inference.sh’ is the driver. I did ask an LLM to help debug it, as I only replaced PyTorch with TF and didn’t write any new code at all.

I also wanted to ask about debugging techniques. How does one debug this? This is inference on a laptop.

I have executed and tested this so many times. I’d really appreciate any help. I am running on an M4 Mac.

Thanks

Seems there is a port on Kaggle…?


What “tokens are wrong but change with the image” usually means (in cross-framework ports)

When weights match layer-by-layer but generation is consistently “off,” the cause is almost always one of these:

  1. Input/preprocessing mismatch (most common in VLM ports)

    • Channel order / data format mismatch: HF PaliGemma expects pixel_values shaped like (batch, channels, H, W) in its PyTorch path. (Hugging Face)
      If your TF model is NHWC but you feed NCHW (or vice versa), you’ll get image-dependent outputs that are consistently wrong.
    • Normalization mismatch: HF SigLIP image processor defaults to rescale_factor=1/255, image_mean=[0.5]*3, image_std=[0.5]*3, and RGB conversion. (Hugging Face)
      If you normalize like CLIP (mean/std ~ ImageNet) or omit rescaling, logits can shift enough to flip argmax early.
  2. Masking / cache / position handling mismatch (very common)

    • PaliGemma uses full block attention over image tokens + input text tokens. (Hugging Face)
    • HF’s implementation also has special handling described as a bidirectional mask on prompt tokens (and causal behavior for generated tokens). (GitHub)
    • A real HF bug showed use_cache=True broke generation due to an attention-mask computation issue; outputs went wrong even though the model “worked.” (Hugging Face)
      If your TF port has a KV cache, an off-by-one in cache_position / position_ids / mask broadcasting is enough to derail tokens.
  3. Attention shape logic mismatch: Gemma uses “different Q vs K/V sizes”

    • In the reference architecture, q_proj has output 2048 while k_proj/v_proj are much smaller (e.g., 256). (Google Developers Blog)
      That implies grouped-query / multi-query attention behavior. A TF port that reshapes K/V as if they had the same head count as Q (or repeats the wrong axis) will produce systematically wrong logits while weights still “match.”
  4. Activation function mismatch (quiet but impactful)

    • Both SigLIP and Gemma commonly use GELU tanh approximation (gelu_pytorch_tanh / PytorchGELUTanh). (Hugging Face)
      Using “exact” GELU vs tanh-approx can change logits enough to flip tokens, especially early in decoding.
  5. Precision / backend differences (Mac Metal can amplify)

    • There are reports of inconsistent results on Mac M4 vs NVIDIA in Keras. (GitHub)
      This usually shouldn’t totally scramble tokens on its own, but it can make debugging harder if you’re already near decision boundaries.
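To see how much item 4 matters numerically, here is a small NumPy-only comparison of exact GELU against the tanh approximation (a sketch; the function names are mine, not HF’s):

```python
import math

import numpy as np

# Exact GELU (erf form) vs the tanh approximation (the form behind HF's
# gelu_pytorch_tanh). The gap is small but nonzero, which is enough to
# flip an argmax near a decision boundary.
def gelu_exact(x):
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))

def gelu_tanh(x):
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

x = np.linspace(-4.0, 4.0, 401)
max_gap = float(np.abs(gelu_exact(x) - gelu_tanh(x)).max())
print(f"max |exact - tanh| on [-4, 4]: {max_gap:.2e}")  # small but nonzero
```

If your port used one variant where the reference used the other, this is the size of per-activation drift you would see compounding through 20+ layers.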

The fastest way to debug your port (what I would do with your constraints)

The winning strategy is binary search on the forward pass, before you touch sampling/decoding.

Step 0 — Make it deterministic and remove “generation complexity”

Do this first:

  • Force greedy decoding (argmax), no temperature/top-p.

  • Run in float32 everywhere for debugging.

  • Temporarily disable KV cache (use_cache=False equivalent): recompute full forward each step.

    • If outputs become correct (or much closer), the bug is in cache/mask/positions (matches the HF use_cache failure mode). (Hugging Face)
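Step 0 in miniature: a greedy loop that recomputes the full forward pass every step (no cache, no sampling). `model_fn` is a stand-in for your TF forward; the toy below only demonstrates the control flow.

```python
import numpy as np

def greedy_decode(model_fn, input_ids, max_new_tokens):
    """Greedy decoding with no KV cache: every step re-runs the full
    forward pass over the whole sequence, then takes a pure argmax."""
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        logits = model_fn(np.array(ids))        # (T, vocab) full recompute
        ids.append(int(np.argmax(logits[-1])))  # greedy: argmax, no sampling
    return ids

# Toy stand-in "model": always predicts (last_token + 1) mod 5.
def toy_model(ids):
    logits = np.zeros((len(ids), 5))
    logits[-1, (int(ids[-1]) + 1) % 5] = 1.0
    return logits

print(greedy_decode(toy_model, [0], 3))  # [0, 1, 2, 3]
```

Slow, but it removes the cache and sampling as variables in one move.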

Step 1 — Lock inputs so preprocessing can’t be the culprit

In PyTorch/HF:

  • Use the official processor and save the exact tensors you feed the model:

    • input_ids, attention_mask
    • pixel_values (as produced by the processor)
  • Then in TF, load those saved arrays and run your TF model on them.

Why: this eliminates every difference in resizing/normalization/tokenization in one move. HF’s SigLIP processor defaults are easy to miss. (Hugging Face)

If TF outputs are still wrong using HF-produced pixel_values and input_ids, preprocessing is not the problem.
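A sketch of that handoff (the arrays are placeholders standing in for what the HF processor would actually return, so this runs anywhere; shapes assume the mix-448 checkpoint):

```python
import os
import tempfile

import numpy as np

# PyTorch/HF side: save exactly what the processor produced. The values
# below are placeholders -- in practice, dump the processor's real output.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "hf_inputs.npz")
    np.savez(path,
             input_ids=np.array([[2, 108, 2364]], dtype=np.int64),       # placeholder ids
             attention_mask=np.ones((1, 3), dtype=np.int64),
             pixel_values=np.zeros((1, 3, 448, 448), dtype=np.float32))  # HF's NCHW layout

    # TF side: load the exact arrays instead of re-running preprocessing,
    # transposing to NHWC only if your TF model expects channels-last.
    loaded = np.load(path)
    pixel_values_nhwc = np.transpose(loaded["pixel_values"], (0, 2, 3, 1))
    assert pixel_values_nhwc.shape == (1, 448, 448, 3)
```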

Step 2 — Compare intermediate activations (layerwise “tripwires”)

You want the first layer where TF diverges from PyTorch.

Do it in this order (cheap → expensive):

  1. Vision tower output (end of SigLIP)

  2. Multimodal projector output

  3. Text token embeddings (embedding table lookup)

  4. One decoder layer at a time:

    • input RMSNorm output
    • Q/K/V tensors (after projection, after reshape)
    • attention scores (pre-softmax)
    • attention probs (post-softmax)
    • attention output projection
    • MLP pre-activation, post-activation, output
  5. Final logits

This is exactly the “single forward pass validation / binary search” workflow TF recommends for migrations: narrow scope by checking equivalence at intermediate steps. (TensorFlow)

Practical note: for each checkpoint, compute:

  • max absolute diff
  • mean absolute diff
  • cosine similarity (for large vectors)
    and log shapes/dtypes.
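A small helper for those tripwires (the name and the 1e-4 tolerance are my own choices, not from any library):

```python
import numpy as np

def compare(name, ref, test, atol=1e-4):
    """Log max/mean absolute difference and cosine similarity between a
    reference (PyTorch) activation and the TF port's activation."""
    ref = np.asarray(ref, dtype=np.float64).ravel()
    test = np.asarray(test, dtype=np.float64).ravel()
    assert ref.shape == test.shape, f"{name}: shape mismatch {ref.shape} vs {test.shape}"
    diff = np.abs(ref - test)
    cos = ref @ test / (np.linalg.norm(ref) * np.linalg.norm(test) + 1e-12)
    print(f"{name}: max={diff.max():.3e} mean={diff.mean():.3e} cos={cos:.6f}")
    return bool(diff.max() <= atol)

a = np.random.default_rng(0).standard_normal(1024)
assert compare("identical", a, a)             # matches
assert not compare("perturbed", a, a + 1e-3)  # flags the divergence
```

Call it at each checkpoint in the list above; the first `False` marks where the port diverges.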

Step 3 — If divergence starts inside attention, check these specific traps

Given Gemma’s Q vs KV projection sizes (Google Developers Blog), I would audit:

  • Head math

    • q: (B, T, n_heads, head_dim)
    • k/v: (B, T, n_kv_heads, head_dim)
    • Then repeat/broadcast k/v across query heads (grouped-query logic).
  • Transpose conventions (TF often uses (B, heads, T, head_dim) vs (B, T, heads, head_dim))
    One wrong transpose produces “valid-looking” tensors and totally wrong logits.

  • Mask application point
    Mask must be added to attention scores before softmax with a large negative value.

  • RoPE / positions
    In HF forward signature you’ll see cache_position and position_ids concerns. (Hugging Face)
    With caching, position handling is the #1 off-by-one source.
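The head-math audit above can be exercised on a NumPy-only sketch of the grouped-query bookkeeping (head counts are illustrative, not read from the checkpoint):

```python
import numpy as np

def repeat_kv(x, n_rep):
    """Expand (B, n_kv_heads, T, head_dim) to (B, n_kv_heads * n_rep, T, head_dim)
    so each KV head serves its group of query heads. Repeating along the
    wrong axis here is exactly the "valid shapes, wrong logits" trap."""
    b, n_kv, t, d = x.shape
    if n_rep == 1:
        return x
    x = np.broadcast_to(x[:, :, None, :, :], (b, n_kv, n_rep, t, d))
    return x.reshape(b, n_kv * n_rep, t, d)

# Gemma-2B-like toy sizes: 8 query heads sharing 1 KV head of dim 256
# (q_proj out = 8 * 256 = 2048, k_proj/v_proj out = 1 * 256 = 256).
k = np.arange(2 * 1 * 5 * 4, dtype=np.float32).reshape(2, 1, 5, 4)
k_rep = repeat_kv(k, 8)
assert k_rep.shape == (2, 8, 5, 4)
assert np.array_equal(k_rep[:, 0], k_rep[:, 7])  # every query head sees the same KV
```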

Step 4 — Only after forward-pass matches, debug generation

Once a single forward pass matches closely, then:

  • Implement greedy generation without cache (slow but correct).

  • Add cache back.

  • Re-check equivalence step-by-step on:

    • step 0 logits
    • step 1 logits given past
    • step 2 logits given past, etc.

If cache breaks it, compare against the known HF failure mode: use_cache=True causing wrong outputs due to attention-mask computation. (Hugging Face)
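The cached-vs-uncached equivalence can be rehearsed on a toy single-head attention before touching the real model (pure NumPy; no PaliGemma specifics):

```python
import numpy as np

rng = np.random.default_rng(0)
D, T = 8, 5
Wq, Wk, Wv = (rng.standard_normal((D, D)) for _ in range(3))
x = rng.standard_normal((T, D))

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def full_forward(x):
    """Whole sequence at once, causal mask added to scores pre-softmax."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(D)
    scores += np.triu(np.full((len(x), len(x)), -1e9), k=1)
    return softmax(scores) @ v

def cached_forward(x):
    """One token at a time against a growing K/V cache (no mask needed:
    the cache only ever holds past positions)."""
    ks, vs, outs = [], [], []
    for t in range(len(x)):
        q = x[t:t + 1] @ Wq
        ks.append(x[t:t + 1] @ Wk)
        vs.append(x[t:t + 1] @ Wv)
        K, V = np.concatenate(ks), np.concatenate(vs)
        outs.append(softmax(q @ K.T / np.sqrt(D)) @ V)
    return np.concatenate(outs)

assert np.allclose(full_forward(x), cached_forward(x), atol=1e-10)
```

Your port's cache path should satisfy the same invariant step-by-step; the first step where it doesn't is where the off-by-one lives.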


“Similar cases” and issues worth reading (directly relevant)

Cache/mask issues in PaliGemma generation

  • HF issue/discussion: use_cache=True breaks PaliGemma generation (attention mask miscomputed; outputs wrong). (Hugging Face)
    This is extremely aligned with your symptoms if you have a KV cache.

Prompt formatting pitfalls (especially newline + ordering)

  • Keras PaliGemma model card examples include a trailing newline in prompts (e.g., "caption en\n"). (Hugging Face)
  • Google’s prompt-format guide highlights the ordering: image first, then prompt text.
  • HF forum discussion about newline tokenization behavior (easy to misunderstand when copying prompts between implementations). (Hugging Face Forums)

Weight conversion / shape pitfalls (good for “what can go wrong”)

  • HF issue about converting PaliGemma NPZ → HF hit a reshape mismatch in attention projection handling. (GitHub)
    Not your exact path, but it’s another data point that projection shapes/head reshaping are a common failure point.

Mac-specific “equivalence noise”

  • Keras issue reporting inconsistent results on Mac M4 vs NVIDIA. (GitHub)
    For debugging, run CPU/float32 first to avoid backend-specific numeric quirks.

“Has anyone ported it to TF?”

Yes—just not necessarily as a HF-style TF transformers model:

  • Google’s official “inference with Keras” guide uses PaliGemmaCausalLM from Keras Hub. (Google AI for Developers)
  • Kaggle hosts a Keras implementation of PaliGemma 2 that runs on JAX, TensorFlow, and PyTorch (Keras 3 multi-backend). (kaggle.com)
  • HF also hosts Keras-formatted checkpoints (your earlier link shows a “*-keras” variant). (Hugging Face)

If your goal is “TF inference on laptop,” using the Keras Hub model as a reference oracle is valuable even if you keep your own port.


Tooling: how to debug “complex TF models” in practice

1) TensorBoard Debugger V2 (for shapes/NaNs/Inf and execution traces)

TensorFlow provides a debugger workflow via tf.debugging.experimental.enable_dump_debug_info(...) to inspect tensor health, shapes, and execution history. (TensorFlow)
This won’t directly tell you “your transpose is wrong,” but it helps catch silent numeric pathologies and gives you visibility into what ran.

2) TF migration debugging playbook (applies directly to PyTorch→TF ports)

TF’s migration debugging guide explicitly recommends:

  • single forward-pass equivalence
  • disabling randomness
  • binary search over the model to localize divergence (TensorFlow)
    That workflow maps almost perfectly to debugging a transformer port.

3) NNsight / NDIF (useful, but not the first hammer here)

NNsight/NDIF is great for activation patching and interpretability workflows. (arXiv)
For your problem (“my port is numerically wrong”), you’ll usually get faster answers from layerwise numeric equivalence tests than from interpretability tooling.


The single most likely bug class in your specific case

Given:

  • weights match,
  • outputs are image-dependent but wrong,
  • you likely implemented KV cache + attention yourself (your repo mentions KVCache and custom Gemma bits),

…the highest-probability root cause is:

Grouped-query/multi-query attention + cache_position/mask handling (shape/broadcast/off-by-one).

The architectural clue is that Q and KV projection sizes differ (e.g., q_proj: 2048, k_proj/v_proj: 256). (Google Developers Blog)
If KV heads are mishandled, every subsequent token distribution will be “reasonable-looking” but wrong.

Next most likely:

  • pixel_values layout mismatch (NCHW vs NHWC). (Hugging Face)

If you want a very concrete next move

  1. Disable cache and do 1-step greedy decoding.

  2. Feed TF the exact pixel_values and input_ids saved from HF.

  3. Compare:

    • projector output
    • logits at step 0

That 3-step test will tell you, with high confidence, whether the bug is:

  • preprocessing/layout,
  • core forward pass (attention/MLP),
  • or cache/mask/positions.

Thanks. I believe I have tested some of these things already. The KVCache and RoPE were probably the most difficult to debug. Not sure if I can switch off the cache as it is entwined with the code. All the code is exactly the same as PyTorch’s Paligemma implementation, except for the errors introduced by me.
Porting itself was simple because the two code structures are similar.

Will report back if there is any improvement.

The links are helpful.

1 Like

This turned out to be a set of errors introduced by me, but there was one critical bug related to the KVCache: the layer prefill mechanism was skipped for every layer except layer 0.

Moreover, I was loading 24 vision layers instead of 27, and there was a mismatch between HF’s GemmaConfig and mine.

The prompt is “caption en What is this ?” and it generates the correct tokens for my test images.

import tensorflow as tf

class SiglipVisionConfig:
    def __init__(self,
                 hidden_size=1152,
                 intermediate_size=4304,
                 num_hidden_layers=27,   # 27, not 24 -- this was one of the bugs
                 num_attention_heads=16,
                 num_channels=3,
                 image_size=448,
                 patch_size=14,
                 layer_norm_eps=1e-6,
                 attention_dropout=0.0,
                 projection_dim=2048,
                 num_image_tokens: int = 1024,
                 max_position_embeddings=1024):
        super().__init__()
        self.projection_dim = projection_dim
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size
        self.layer_norm_eps = layer_norm_eps
        self.attention_dropout = attention_dropout
        self.num_image_tokens = num_image_tokens
        self.max_position_embeddings = max_position_embeddings

1 Like

Since PaliGemma seems to have strict prompt rules, just in case for future readers…


Key context: PaliGemma is prefix-LM (this drives many “mysterious” port bugs)

PaliGemma concatenates image tokens first, then BOS, then prefix text, then a SEP token implemented as \n, and generates the suffix autoregressively. The paper explicitly describes:

  • Full (unmasked) attention over image + prefix tokens (so image tokens can “look ahead” to the question).
  • Autoregressive masking over the suffix/output tokens.
  • The newline \n is the SEP token separating prefix/suffix.

That design means:

  • If your prefill is wrong (cache, position ids, attention masking, tokenization), you can get “close but wrong” first-token logits.
  • If your \n handling is off, you can shift where prefix ends and where suffix begins, which changes logits a lot.

Prompt correctness: your string must match what the model was trained on

Google’s official task syntax requires \n at the end of each command (e.g. "caption en\n" or "answer en <question>\n").

So:

  • If your intent is QA, use: "answer en What is this?\n"
  • If your intent is captioning, use: "caption en\n" (no question text)

Also note: PaliGemma expects image first, then text; reversing order often yields garbage. (Google AI for Developers)

Hugging Face’s processor logic (and derivatives) typically build the string as:

(image_token repeated) + BOS + prompt + "\n" (Hugging Face)

So if you bypass any processor, replicate that exactly.
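A sketch of that assembly in plain Python (the literal token strings and the token count are assumptions for the mix-448 setup, not the processor’s actual internals):

```python
IMAGE_TOKEN = "<image>"              # assumed placeholder string for the image token
BOS = "<bos>"                        # Gemma BOS rendered as a string, for illustration
NUM_IMAGE_TOKENS = (448 // 14) ** 2  # 32 * 32 = 1024 patches at mix-448

def build_prompt(prefix_text):
    """Image tokens first, then BOS, then the prefix, then the newline SEP."""
    return IMAGE_TOKEN * NUM_IMAGE_TOKENS + BOS + prefix_text + "\n"

prompt = build_prompt("caption en")
assert NUM_IMAGE_TOKENS == 1024
assert prompt.endswith("caption en\n")  # SEP newline present
assert prompt.startswith(IMAGE_TOKEN)   # image really comes first
```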


Two remaining “silent mismatch” zones that often survive your kind of fixes

A) Extra tokens & vocab size (loc/seg tokens)

PaliGemma adds 1024 <loc****> tokens and 128 <seg***> tokens for detection/segmentation-style outputs. If your tokenizer/model disagree about these, you can get:

  • wrong decoding,
  • off-by-N vocab sizing,
  • “valid but wrong” logits near the top-k.

What to do

  • Ensure the tokenizer vocab and embedding/lm_head vocab_size are consistent with the checkpoint config (don’t assume HF defaults).
  • Even if you only do QA/caption, still keep the vocab consistent; the model may produce these tokens in some contexts.
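The bookkeeping is easy to check with arithmetic (the base vocab size is the standard Gemma value; treat these constants as assumptions to verify against your checkpoint’s config):

```python
GEMMA_BASE_VOCAB = 256_000   # Gemma tokenizer vocab (assumed)
NUM_LOC_TOKENS = 1_024       # <loc0000> ... <loc1023>
NUM_SEG_TOKENS = 128         # <seg000> ... <seg127>

# The <image> placeholder id sits right after the added loc/seg tokens.
image_token_index = GEMMA_BASE_VOCAB + NUM_LOC_TOKENS + NUM_SEG_TOKENS
print(image_token_index)  # 257152, the image_token_index HF configs use

# The embedding / lm_head vocab_size in HF configs (257216) is padded
# slightly beyond this -- the two numbers are NOT interchangeable.
assert image_token_index == 257_152
```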

B) Image preprocessing details (resize + RGB + normalization)

SigLIP preprocessing is sensitive. HF’s SiglipImageProcessor defaults include:

  • resize with BICUBIC
  • rescale by 1/255
  • normalize by mean/std (commonly 0.5/0.5 per channel)
  • convert to RGB by default (GitHub)

Your TF pipeline is already doing the x/127.5 - 1 style scaling (equivalent to (x/255 - 0.5)/0.5), which matches common PaliGemma/Keras examples, but explicitly matching resampling + RGB conversion removes another major source of drift. (GitHub)
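That equivalence is quick to verify numerically (pure NumPy):

```python
import numpy as np

# SigLIP-style normalization: rescale by 1/255, then (x - 0.5) / 0.5.
x = np.arange(0, 256, dtype=np.float64)  # all possible 8-bit pixel values
siglip = (x / 255.0 - 0.5) / 0.5         # HF SiglipImageProcessor defaults
keras_style = x / 127.5 - 1.0            # the x/127.5 - 1 form
assert np.allclose(siglip, keras_style)  # identical mapping to [-1, 1]
assert siglip.min() == -1.0 and siglip.max() == 1.0
```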


“Has anyone ported this to TF?” Practical answer

Yes—there is an official-ish, maintained Keras implementation via Keras Hub (PaliGemmaCausalLM), and Google provides an inference tutorial using it. (Google AI for Developers)

This is extremely useful for you because it can serve as:

  • a TF-native reference for preprocessing/prompting,
  • a sanity baseline on Apple silicon (at least at the API level),
  • a way to validate prompt syntax like the required trailing \n. (GitHub)

Good “similar cases / pitfalls” threads worth reading (high signal)

Prompting & newline / command syntax

  • Google’s prompt/system instructions (task syntax + \n requirement + image-first ordering). (Google AI for Developers)
  • KerasHub examples/issues showing prompts like "answer en ...?\n" and "caption en\n". (GitHub)

Attention masking / prefix-LM behavior quirks

  • HF issue on PaliGemma causality/masking details (#30993). (GitHub)
  • HF model discussion on prefix-LM attention mask expectations for SFT. (Hugging Face)

Common integration errors (image token counts, etc.)

  • HF issue about “Number of images does not match number of special image tokens…” (#36008). (GitHub)
  • HF issue about padding/masking order changes (#35855). (GitHub)

Ground-truth design source

  • The PaliGemma technical report (prefix-LM mask figure, SEP=\n, added loc/seg tokens).

One targeted suggestion from your repo state (based on what you shared)

Given you already fixed:

  • cache prefill across layers,
  • vision layer count,
  • config mismatch,

the next highest-leverage checks are:

  1. Prompt is exactly "answer en What is this?\n" (QA) or "caption en\n" (caption)—not mixed. (Google AI for Developers)
  2. RGB + BICUBIC resize exactly matches SigLIP defaults. (GitHub)
  3. Config-driven vocab_size and image_token_index match the checkpoint (257216 / 257152 are common for mix-448). (Hugging Face)
  4. Tokenizer/model agree about the added <loc*> / <seg*> tokens.

If you do those four, and your cache invariants hold per-layer, you’re usually within tiny numeric drift of the reference for first-token logits—and decoding will match.

1 Like