Spaces: Running on Zero
| import torch | |
| from resampler import Resampler | |
| from transformers import CLIPVisionModel | |
# Vision tower whose features are fed to the Resampler.
IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"

# Resampler configuration.
NUM_QUERIES = 8  # passed to Resampler(num_queries=...)
NUM_LATENTS_MEAN_POOLED = 4  # 0 for no mean pooling (previous behavior)
APPLY_POS_EMB = True  # False for no positional embeddings (previous behavior)
OUTPUT_DIM = 1280  # expected size of the last dimension of the Resampler output

# Number of random dummy images pushed through the pipeline.
BATCH_SIZE = 2
def main():
    """Smoke-test that a Resampler turns CLIP vision features into the expected token shape.

    Loads the CLIP vision tower named by ``IMAGE_ENCODER_NAME_OR_PATH`` and builds a
    ``Resampler`` whose input width matches the encoder's hidden size. It pushes a
    batch of random images through both models and checks the projected token
    tensor has shape ``(BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)``.

    Raises:
        AssertionError: if the Resampler output shape differs from the expected one.
    """
    image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
    config = image_encoder.config
    embedding_dim = config.hidden_size
    print(f"image_encoder hidden size: {embedding_dim}")

    # Derive the encoder's sequence length from its config instead of hard-coding 257.
    # The sequence has one token per (image_size // patch_size)^2 patch plus one class
    # token. For ViT-H-14 at 224px this is 16 * 16 + 1 = 257, the previous constant.
    num_patches = (config.image_size // config.patch_size) ** 2
    max_seq_len = num_patches + 1

    image_proj_model = Resampler(
        dim=1024,
        depth=2,
        dim_head=64,
        heads=16,
        num_queries=NUM_QUERIES,
        embedding_dim=embedding_dim,
        output_dim=OUTPUT_DIM,
        ff_mult=2,
        max_seq_len=max_seq_len,
        apply_pos_emb=APPLY_POS_EMB,
        num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,
    )

    # Random pixels sized to what the encoder's config declares (previously fixed at 3x224x224).
    dummy_images = torch.randn(
        BATCH_SIZE, config.num_channels, config.image_size, config.image_size
    )

    # Inference only: no autograd graph is needed for either model.
    with torch.no_grad():
        # Use the second-to-last entry of the encoder's returned hidden states as features.
        image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]
        print("image_embeds shape: ", image_embeds.shape)
        ip_tokens = image_proj_model(image_embeds)
        print("ip_tokens shape:", ip_tokens.shape)

    expected_shape = (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)
    assert ip_tokens.shape == expected_shape, (
        f"unexpected ip_tokens shape {tuple(ip_tokens.shape)}, expected {expected_shape}"
    )
if __name__ == "__main__":
    # Run the shape smoke test only when executed as a script, not when imported.
    main()