# Prepare the models to speed up loading them later
import torch
from torch import nn, Tensor
import os
from tqdm import tqdm
import json

from .utils import load

# Map CLIP model names to filesystem-friendly identifiers used in file names.
model_name_map = {
    "ViT-B/16": "vit_b_16",
    "ViT-L/14": "vit_l_14",
}


class CLIPTextEncoderTemp(nn.Module):
    """Collects the text-tower submodules of a full CLIP model so that their
    weights can be exported as a standalone state dict."""

    def __init__(
        self,
        clip: nn.Module,
    ) -> None:
        super().__init__()
        self.context_length = clip.context_length
        self.vocab_size = clip.vocab_size
        self.dtype = clip.dtype
        self.token_embedding = clip.token_embedding
        self.positional_embedding = clip.positional_embedding
        self.transformer = clip.transformer
        self.ln_final = clip.ln_final
        self.text_projection = clip.text_projection

    def forward(self, text: Tensor) -> None:
        # Intentionally a no-op: this wrapper only exists to export weights.
        pass


def prepare() -> None:
    """Load each CLIP variant once and save the full-model, image-encoder-only,
    and text-encoder-only weights, plus matching JSON configs, next to this file."""
    print("Preparing CLIP models...")
    curr_dir = os.path.dirname(os.path.abspath(__file__))
    weight_dir = os.path.join(curr_dir, "weights")
    config_dir = os.path.join(curr_dir, "configs")
    os.makedirs(weight_dir, exist_ok=True)
    os.makedirs(config_dir, exist_ok=True)
    device = torch.device("cpu")
    for model_name in tqdm(["ViT-B/16", "ViT-L/14"]):
        model = load(model_name, device=device).to(device)
        image_encoder = model.visual.to(device)
        text_encoder = CLIPTextEncoderTemp(model).to(device)

        # Save the full model and the two encoder-only state dicts.
        torch.save(model.state_dict(), os.path.join(weight_dir, f"clip_{model_name_map[model_name]}.pth"))
        torch.save(image_encoder.state_dict(), os.path.join(weight_dir, f"clip_image_encoder_{model_name_map[model_name]}.pth"))
        torch.save(text_encoder.state_dict(), os.path.join(weight_dir, f"clip_text_encoder_{model_name_map[model_name]}.pth"))

        model_config = {
            "embed_dim": model.embed_dim,
            # vision
            "image_resolution": model.image_resolution,
            "vision_layers": model.vision_layers,
            "vision_width": model.vision_width,
            "vision_patch_size": model.vision_patch_size,
            # text
            "context_length": model.context_length,
            "vocab_size": model.vocab_size,
            "transformer_width": model.transformer_width,
            "transformer_heads": model.transformer_heads,
            "transformer_layers": model.transformer_layers,
        }
        image_encoder_config = {
            "embed_dim": model.embed_dim,
            "image_resolution": model.image_resolution,
            "vision_layers": model.vision_layers,
            "vision_width": model.vision_width,
            "vision_patch_size": model.vision_patch_size,
            "vision_heads": model.vision_heads,
        }
        text_encoder_config = {
            "embed_dim": model.embed_dim,
            "context_length": model.context_length,
            "vocab_size": model.vocab_size,
            "transformer_width": model.transformer_width,
            "transformer_heads": model.transformer_heads,
            "transformer_layers": model.transformer_layers,
        }

        # Write one JSON config per saved state dict.
        with open(os.path.join(config_dir, f"clip_{model_name_map[model_name]}.json"), "w") as f:
            json.dump(model_config, f, indent=4)
        with open(os.path.join(config_dir, f"clip_image_encoder_{model_name_map[model_name]}.json"), "w") as f:
            json.dump(image_encoder_config, f, indent=4)
        with open(os.path.join(config_dir, f"clip_text_encoder_{model_name_map[model_name]}.json"), "w") as f:
            json.dump(text_encoder_config, f, indent=4)
    print("Done!")