from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torchvision.transforms

from .sam import create_sam_vit  # assumed sibling module; build_vision_tower needs it
from .siglip_vit import create_siglip_vit


class CLIPVisionTower(nn.Module):
    """Wraps a vision backbone (SigLIP, SAM, or a HuggingFace CLIP vision
    model) behind a single interface that maps images to patch features."""

    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: Union[Tuple[int, int], int] = 336,
        select_feature: str = "patch",
        select_layer: int = -2,
        select_layers: Optional[List[int]] = None,
        ckpt_path: str = "",
        pixel_mean: Optional[List[float]] = None,
        pixel_std: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__()

        self.model_name = model_name
        self.select_feature = select_feature
        self.select_layer = select_layer
        self.select_layers = select_layers

        vision_tower_params = {
            "model_name": model_name,
            "image_size": image_size,
            "ckpt_path": ckpt_path,
            "select_layer": select_layer,
        }
        vision_tower_params.update(kwargs)
        self.vision_tower, self.forward_kwargs = self.build_vision_tower(
            vision_tower_params
        )

        # Normalize inputs only when mean/std are provided; otherwise images
        # are assumed to be normalized upstream.
        if pixel_mean is not None and pixel_std is not None:
            image_norm = torchvision.transforms.Normalize(
                mean=pixel_mean, std=pixel_std
            )
        else:
            image_norm = None

        self.image_norm = image_norm

    def build_vision_tower(self, vision_tower_params):
        if self.model_name.startswith("siglip"):
            # SigLIP has no cls token, so feature selection becomes a no-op.
            self.select_feature = "same"
            vision_tower = create_siglip_vit(**vision_tower_params)
            forward_kwargs = dict()

        elif self.model_name.startswith("sam"):
            vision_tower = create_sam_vit(**vision_tower_params)
            forward_kwargs = dict()

        else:
            # Fall back to a HuggingFace CLIP vision encoder. The factory-only
            # keys are stripped first, since from_pretrained() expects a model
            # id as its first argument and does not accept them.
            from transformers import CLIPVisionModel

            hf_params = dict(vision_tower_params)
            model_name = hf_params.pop("model_name")
            for key in ("image_size", "ckpt_path", "select_layer"):
                hf_params.pop(key, None)
            vision_tower = CLIPVisionModel.from_pretrained(model_name, **hf_params)
            forward_kwargs = dict(output_hidden_states=True)

        return vision_tower, forward_kwargs

    def feature_select(self, image_forward_outs):
        if isinstance(image_forward_outs, torch.Tensor):
            # The tower already returned the selected layer's features.
            image_features = image_forward_outs
        else:
            image_features = image_forward_outs.hidden_states[self.select_layer]

        if self.select_feature == "patch":
            # Drop the leading cls token and keep only the patch tokens.
            image_features = image_features[:, 1:]
        elif self.select_feature in ("cls_patch", "same"):
            # Keep the features as-is (cls token included, if present).
            pass
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def forward(self, images):
        """
        Args:
            images (torch.Tensor): [b, 3, H, W]

        Returns:
            image_features (torch.Tensor): [b, n_patch, d]
        """
        if self.image_norm is not None:
            images = self.image_norm(images)

        image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
        image_features = self.feature_select(image_forward_outs)
        return image_features
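

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal smoke test, assuming the SigLIP factory can build a randomly
# initialized model when ckpt_path is empty. Run it as a module inside its
# package (python -m <package>.<this_module>) so the relative imports resolve.
if __name__ == "__main__":
    tower = CLIPVisionTower(
        model_name="siglip_large_patch16_384",
        image_size=384,
        pixel_mean=[0.5, 0.5, 0.5],
        pixel_std=[0.5, 0.5, 0.5],
    )
    dummy = torch.rand(2, 3, 384, 384)  # [b, 3, H, W], values in [0, 1]
    with torch.no_grad():
        feats = tower(dummy)
    print(feats.shape)  # expected: [b, n_patch, d]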