import os
import numpy as np
import json
import torch
import torch.nn as nn
import clip
import pytorch_lightning as pl
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
from timm.models.vision_transformer import resize_pos_embed
from cog import BasePredictor, Path, Input

import captioning.utils.opts as opts
import captioning.models as models
import captioning.utils.misc as utils
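
# Cog predictor wrapping the CLIP RN50 visual backbone and the transformer
# captioning model from this repo: setup() loads the COCO vocabulary and CLIP
# once, while predict() builds the captioner for the chosen reward, loads its
# checkpoint, and decodes a caption for the input image.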
class Predictor(BasePredictor):
    def setup(self):
        import __main__

        # Lightning checkpoints may pickle ModelCheckpoint references via
        # __main__; expose the class there so torch.load can resolve them.
        __main__.ModelCheckpoint = pl.callbacks.ModelCheckpoint

        self.device = torch.device("cuda:0")

        self.dict_json = json.load(open("./data/cocotalk.json"))
        self.ix_to_word = self.dict_json["ix_to_word"]
        self.vocab_size = len(self.ix_to_word)

        self.clip_model, self.clip_transform = clip.load(
            "RN50", jit=False, device=self.device
        )

        self.preprocess = Compose(
            [
                Resize((448, 448), interpolation=Image.BICUBIC),
                CenterCrop((448, 448)),
                ToTensor(),
            ]
        )

    def predict(
        self,
        image: Path = Input(
            description="Input image.",
        ),
        reward: str = Input(
            choices=["mle", "cider", "clips", "cider_clips", "clips_grammar"],
            default="clips_grammar",
            description="Choose a reward criterion.",
        ),
    ) -> str:
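        # Note: the attributes below were already prepared in setup(); the same
        # initialization is repeated here.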
        self.device = torch.device("cuda:0")

        self.dict_json = json.load(open("./data/cocotalk.json"))
        self.ix_to_word = self.dict_json["ix_to_word"]
        self.vocab_size = len(self.ix_to_word)

        self.clip_model, self.clip_transform = clip.load(
            "RN50", jit=False, device=self.device
        )

        self.preprocess = Compose(
            [
                Resize((448, 448), interpolation=Image.BICUBIC),
                CenterCrop((448, 448)),
                ToTensor(),
            ]
        )
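        # Phase-1 configs cover plain MLE training; every RL-finetuned reward
        # (cider, clips, cider_clips, clips_grammar) has a matching phase-2 config.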
        cfg = (
            f"configs/phase1/clipRN50_{reward}.yml"
            if reward == "mle"
            else f"configs/phase2/clipRN50_{reward}.yml"
        )
        print("Loading cfg from", cfg)
        opt = opts.parse_opt(parse=False, cfg=cfg)

        print("vocab size:", self.vocab_size)

        seq_length = 1
        opt.vocab_size = self.vocab_size
        opt.seq_length = seq_length
        opt.batch_size = 1
        opt.vocab = self.ix_to_word

        print(opt.caption_model)
        model = models.setup(opt)
        del opt.vocab

        ckpt_path = opt.checkpoint_path + "-last.ckpt"
        print("Loading checkpoint from", ckpt_path)
        raw_state_dict = torch.load(ckpt_path, map_location=self.device)

        strict = True
        state_dict = raw_state_dict["state_dict"]

        if "_vocab" in state_dict:
            model.vocab = utils.deserialize(state_dict["_vocab"])
            del state_dict["_vocab"]
        elif strict:
            raise KeyError

        if "_opt" in state_dict:
            saved_model_opt = utils.deserialize(state_dict["_opt"])
            del state_dict["_opt"]
            # Make sure the saved opt is compatible with the current opt
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                if getattr(saved_model_opt, checkme) in [
                    "updown",
                    "topdown",
                ] and getattr(opt, checkme) in ["updown", "topdown"]:
                    continue
                assert getattr(saved_model_opt, checkme) == getattr(opt, checkme), (
                    "Command line argument and saved model disagree on '%s'" % checkme
                )
        elif strict:
            raise KeyError

        res = model.load_state_dict(state_dict, strict)
        print(res)
        model = model.to(self.device)
        model.eval()
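        # CLIP's published RGB normalization constants; applied manually because
        # the preprocess pipeline above only converts pixels to [0, 1].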
        image_mean = (
            torch.Tensor([0.48145466, 0.4578275, 0.40821073])
            .to(self.device)
            .reshape(3, 1, 1)
        )
        image_std = (
            torch.Tensor([0.26862954, 0.26130258, 0.27577711])
            .to(self.device)
            .reshape(3, 1, 1)
        )
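        # Swap in a positional embedding sized for the 14x14 (= 196) patch grid
        # produced by the 448x448 input; resize_pos_embed interpolates the
        # pretrained embedding to the new grid size.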
        num_patches = 196  # 14 x 14 patches for a 448x448 input at stride 32
        pos_embed = nn.Parameter(
            torch.zeros(
                1,
                num_patches + 1,
                self.clip_model.visual.attnpool.positional_embedding.shape[-1],
                device=self.device,
            ),
        )
        pos_embed.weight = resize_pos_embed(
            self.clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed
        )
        self.clip_model.visual.attnpool.positional_embedding = pos_embed
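        # Encode the image with CLIP. encode_image here returns both the spatial
        # feature grid and a pooled vector; only the grid is kept and later
        # reshaped into 196 x 2048 attention features for the captioner.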
        with torch.no_grad():
            image = self.preprocess(Image.open(str(image)).convert("RGB"))
            image = torch.tensor(np.stack([image])).to(self.device)
            image -= image_mean
            image /= image_std

            tmp_att, tmp_fc = self.clip_model.encode_image(image)
            tmp_att = tmp_att[0].permute(1, 2, 0)
            att_feat = tmp_att
        # Inference configurations
        eval_kwargs = {}
        eval_kwargs.update(vars(opt))

        with torch.no_grad():
            fc_feats = torch.zeros((1, 0)).to(self.device)
            att_feats = att_feat.view(1, 196, 2048).float().to(self.device)
            att_masks = None

            # Forward the model to also get generated samples for each image.
            # Only keep one feature per image, in case of duplicate samples.
            tmp_eval_kwargs = eval_kwargs.copy()
            tmp_eval_kwargs.update({"sample_n": 1})
            seq, seq_logprobs = model(
                fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode="sample"
            )
            seq = seq.data

        sents = utils.decode_sequence(model.vocab, seq)
        return sents[0]
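
# A hypothetical local smoke test (outside of the Cog runtime), assuming the
# trained checkpoints, configs, and ./data/cocotalk.json are in place:
#
#   predictor = Predictor()
#   predictor.setup()
#   print(predictor.predict(image=Path("example.jpg"), reward="clips_grammar"))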