import logging
import random

import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn

from minigpt4.common.registry import registry
from minigpt4.models.blip2 import Blip2Base, disabled_train
from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM

from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub

from transformers import LlamaTokenizer, CodeLlamaTokenizer, BitsAndBytesConfig

from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)

import time
import numpy as np

from minigpt4.models import policies


class MiniGPT4v(Blip2Base):
    """
    BLIP2 GPT-LLAMA model.
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain_vicuna": "configs/models/minigpt4.yaml",
    }

    def __init__(
        self,
        vit_model="eva_clip_g",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        llama_model="",
        prompt_path="",
        prompt_template="",
        max_txt_len=32,
        low_resource=False,  # use 8 bit and put vit in cpu
        end_sym='\n',
        lora_r=8,
        lora_target_modules=["q_proj", "v_proj"],
        lora_alpha=16,
        # lora_r = 16,
        # lora_target_modules = ["q_proj", "v_proj", "v_proj"],
        lora_dropout=0.05,
        ckpt_path="",
        system_prompt=False,
        chat_template=False,
        token_pooling=True,
        use_grad_checkpoint_llm=False,
        max_context_len=3800,
        remove_template=False,
    ):
        super().__init__()

        self.tokenizer = self.init_tokenizer()
        self.low_resource = low_resource
        self.token_pooling = token_pooling
        self.remove_template = remove_template

        print("token pooling", self.token_pooling)

        self.use_grad_checkpoint_llm = use_grad_checkpoint_llm
        self.max_context_len = max_context_len
        self.chat_template = chat_template

        # print('Loading VIT')
        # self.visual_encoder, self.ln_vision = self.init_vision_encoder(
        #     vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        # )

        print("vit precision", vit_precision)
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, 224, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        for name, param in self.visual_encoder.named_parameters():
            param.requires_grad = False
        self.visual_encoder = self.visual_encoder.eval()
        self.visual_encoder.train = disabled_train
        for name, param in self.ln_vision.named_parameters():
            param.requires_grad = False
        self.ln_vision = self.ln_vision.eval()
        self.ln_vision.train = disabled_train
        logging.info("freeze vision encoder")
        print("freeze the vision encoder")
        print('Loading VIT Done')

        # print("visual encoder shape", self.visual_encoder.pos_embed.shape)
        # assert False

        print('Loading LLAMA')

        self.B_SYS, self.E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

        if 'CodeLlama' in llama_model:
            self.llama_tokenizer = CodeLlamaTokenizer.from_pretrained(llama_model, use_fast=False)
            self.llama_tokenizer.pad_token = "$$"
        else:
            self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
            self.llama_tokenizer.pad_token = "$$"

        self.system_prompt = system_prompt

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

        self.llama_model = LlamaForCausalLM.from_pretrained(
            llama_model,
            quantization_config=bnb_config,
            device_map={"": 0},
        )

        # self.llama_model.gradient_checkpointing_enable()
        self.llama_model = prepare_model_for_kbit_training(self.llama_model)

        # self.llama_model.print_trainable_parameters()
        print('Loading LLAMA Done')
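
        # Each 224x224 crop yields a 16x16 grid of visual tokens from the EVA-CLIP-g
        # encoder (hidden size 1408). merge_n**2 neighboring tokens are folded into the
        # channel dimension in encode_img, so the projector input width is 1408 * merge_n**2.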
        self.merge_n = 3
        self.llama_proj = nn.Linear(
            1408 * self.merge_n ** 2, self.llama_model.config.hidden_size
        )

        self.max_txt_len = max_txt_len
        self.end_sym = end_sym

        if prompt_path:
            with open(prompt_path, 'r') as f:
                raw_prompts = f.read().splitlines()
            filtered_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
            self.prompt_list = [prompt_template.format(p) for p in filtered_prompts]
            print('Load {} training prompts'.format(len(self.prompt_list)))
            print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
        else:
            self.prompt_list = []

    def encode_img(self, image):
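        """
        Encode an image (or a batch of stacked sub-images) into LLaMA input embeddings.

        High-resolution inputs are split into 224x224 crops, each crop is encoded by the
        frozen ViT, the per-crop token grids are stitched back together, and groups of
        merge_n**2 visual tokens are concatenated and projected to the LLaMA hidden size.
        Returns the projected embeddings and an all-ones attention mask.
        """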
        device = image.device
        if len(image.shape) > 4:
            image = image.reshape(-1, *image.shape[-3:])

        bs, ch, w, h = image.shape
        assert w % 224 == 0
        bw = w // 224
        assert h % 224 == 0
        bh = h // 224
        image_patches = image.view(bs, ch, bw, 224, bh, 224).permute(0, 2, 4, 1, 3, 5)  # bs, bw, bh, ch, 224, 224
        image_patches = image_patches.reshape(bs * bw * bh, ch, 224, 224)

        with self.maybe_autocast():
            image_patch_embeds = self.ln_vision(self.visual_encoder(image_patches)).to(device)
            image_patch_embeds = image_patch_embeds[:, 1:, :].reshape(bs, bw, bh, 16, 16, image_patch_embeds.shape[-1])
            image_patch_embeds = image_patch_embeds.permute(0, 1, 3, 2, 4, 5)  # bs, bw, 16, bh, 16, hs
            image_embeds = image_patch_embeds.reshape(bs, bw * 16 * bh * 16, image_patch_embeds.shape[-1])

            bs, pn, hs = image_embeds.shape
            image_embeds = image_embeds.view(bs, int(pn / self.merge_n ** 2), int(hs * self.merge_n ** 2))

            inputs_llama = self.llama_proj(image_embeds)
            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama

    def get_context_emb(self, prompt, img_list):
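        """
        Interleave the token embeddings of a prompt with image embeddings.

        The prompt is split on the '<ImageHere>' placeholder; each text segment is embedded
        with the LLaMA embedding table and the image embeddings from img_list are spliced in
        between the segments. Only the first segment receives a BOS token.
        """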
        img_device = img_list[0].device
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=(i == 0)).to(img_device).input_ids  # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]

        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs

    def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
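        """
        Embed the instruction prompts and splice the image embeddings into the
        '<ImageHere>' slots, then right-pad the batch up to max_context_len.

        Returns the wrapped input embeddings and the matching attention mask.
        """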
        if prompts is None or len(prompts) == 0:
            # prompts is not provided, just return the original image embedding
            return img_embeds, atts_img
        elif img_embeds is None:
            # prompt is provided but there is no image embedding. return the prompt embedding in right padding
            self.llama_tokenizer.padding_side = "right"
            prompt_tokens = self.llama_tokenizer(
                prompts,
                return_tensors="pt",
                padding="longest",
                add_special_tokens=False
            ).to(self.device)
            prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
            atts_prompt = prompt_tokens.attention_mask
            return prompt_embeds, atts_prompt
        else:
            # return the multi-modal embedding in right padding
            emb_lists = []
            for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
                pn = each_img_embed.shape[-2]
                if lengths is not None:
                    each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
                    each_img_embed = each_img_embed[:lengths[idx] * pn]

                p_segs = each_prompt.split('<ImageHere>')
                interleave_emb = []
                for seg_idx, seg in enumerate(p_segs[:-1]):
                    p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                    p_embed = self.embed_tokens(p_tokens.input_ids)
                    interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, seg_idx * pn:(seg_idx + 1) * pn]], dim=1))

                wrapped_emb = torch.cat(interleave_emb, dim=1)
                p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                p_embed = self.embed_tokens(p_tokens.input_ids)
                wrapped_emb = torch.cat([wrapped_emb, p_embed], dim=1)
                emb_lists.append(wrapped_emb)

            emb_lens = [emb.shape[1] for emb in emb_lists]
            pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))

            max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
            wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
            wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)

            for i, emb in enumerate(emb_lists):
                length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
                wrapped_embs[i, :length] = emb[:, :length]
                wrapped_atts[i, :length] = 1

            return wrapped_embs, wrapped_atts

    def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
        """
        Concatenate the batched input embedding and batched output embedding together.
        Both the input and the output embedding should be right padded.
        """
        input_lens = []
        cat_embs = []
        cat_atts = []
        for i in range(input_embs.size(0)):
            input_len = input_atts[i].sum()
            input_lens.append(input_len)
            cat_embs.append(
                torch.cat([
                    input_embs[i][:input_len],
                    output_embs[i],
                    input_embs[i][input_len:]
                ])
            )
            cat_atts.append(
                torch.cat([
                    input_atts[i][:input_len],
                    output_atts[i],
                    input_atts[i][input_len:]
                ])
            )
            # print('===================================')
            # print('check input emb: ', input_embs[i][this_input_ones-2:this_input_ones])
            # print('check pad emb: ', input_embs[i][this_input_ones:this_input_ones+2])
            # print('check out emb: ', output_embs[i][:2])
            # print('check out pad emb: ', output_embs[i][-2:])
            # print('+++++++++++++++++++++++++++++++++++')
            #
            # print('check attn before: ', input_atts[i][:this_input_ones])
            # print('check attn after: ', input_atts[i][this_input_ones:])
            # print('check attn gt before: ', output_atts[i][:3])
            # print('check attn gt after: ', output_atts[i][-3:])
        cat_embs = torch.stack(cat_embs)
        cat_atts = torch.stack(cat_atts)
        return cat_embs, cat_atts, input_lens

    def get_conv_emb(self, conv_q, conv_a, conv_img):
        """Concatenate multi-turn conversations and make sure the model is only trained to regress the answers."""
        regress_embs_list = []
        targets_list = []

        batch_size = len(conv_q)
        for batch_idx in range(batch_size):
            questions, answers = conv_q[batch_idx], conv_a[batch_idx]
            assigned_imgs = conv_img[batch_idx]
            questions = [self.prompt_wrap(
                img_embeds=img,
                atts_img=None,
                prompts=[q],
                lengths=[img.shape[1]] if img is not None else None) for q, img in zip(questions, assigned_imgs)]
            q_embs = [emb for emb, _ in questions]

            answers = [self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device) for a in answers]
            cur_emb = []
            cur_target = []
            for i in range(len(questions)):
                cur_emb.append(q_embs[i])
                cur_target.append(torch.ones_like(q_embs[i][..., 0], dtype=torch.int) * -100)

                cur_emb.append(self.embed_tokens(answers[i].input_ids))
                cur_target.append(answers[i].input_ids)

            cur_emb = torch.cat(cur_emb, dim=1)
            cur_target = torch.cat(cur_target, dim=1)

            regress_embs_list.append(cur_emb)
            targets_list.append(cur_target)

        max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)

        regress_embeds = torch.zeros([batch_size, max_len, cur_emb.shape[-1]], device=self.device)
        regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device)
        targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100

        for batch_idx in range(batch_size):
            # clamp to max_len so the slice assignments below cannot overflow when a
            # conversation is longer than max_txt_len
            cur_len = min(regress_embs_list[batch_idx].shape[1], max_len)
            regress_embeds[batch_idx, :cur_len] = regress_embs_list[batch_idx][0, :cur_len]
            regress_attn[batch_idx, :cur_len] = 1
            targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :cur_len]

        return regress_embeds, regress_attn, targets

    def preparing_embedding(self, samples):
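        """
        Build the conditioning embeddings (image + instruction) and the regression
        embeddings/targets (answer tokens) for one training batch.

        Handles three layouts: plain instruction samples, multi-turn conversation
        samples ('conv_q'/'conv_a'), and multi-image samples with a 'length' field.
        """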
        def remove_special_tokens(data):
            # if "instruction_input" in data:
            data = [instruct.replace(" [caption]", "") for instruct in data]
            data = [instruct.replace(" [vqa]", "") for instruct in data]
            data = [instruct.replace(" [grounding]", "") for instruct in data]
            data = [instruct.replace(" [identify]", "") for instruct in data]
            data = [instruct.replace(" [refer]", "") for instruct in data]
            return data

        ### prepare input tokens
        if 'image' in samples:
            img_embeds, img_atts = self.encode_img(samples["image"])
        else:
            img_embeds = img_atts = None

        if 'conv_q' in samples:
            # handling conversation datasets
            conv_q, conv_a = samples['conv_q'], samples['conv_a']

            connect_sym = samples['connect_sym'][0]
            conv_q = [q.split(connect_sym) for q in conv_q]
            conv_a = [a.split(connect_sym) for a in conv_a]
            conv_img = assign_imgs(conv_q, img_embeds)

            if self.chat_template:
                conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]

            regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img)
            cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0]

        else:
            instruction = samples["instruction_input"] if "instruction_input" in samples else None

            # print("instruction before", instruction)
            if self.remove_template:
                instruction = remove_special_tokens(instruction)
            # print("instruction after", instruction)

            if self.chat_template:
                instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]

            if 'length' in samples:
                # the input is an image sequence (like videos)
                bsz, pn, hs = img_embeds.shape
                img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)
                cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
            else:
                cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)

            ### prepare target tokens
            self.llama_tokenizer.padding_side = "right"
            text = [t + self.end_sym for t in samples["answer"]]

            regress_tokens = self.llama_tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                truncation=True,
                max_length=self.max_txt_len,
                add_special_tokens=False
            ).to(self.device)

            regress_token_ids = regress_tokens.input_ids
            regress_atts = regress_tokens.attention_mask
            part_targets = regress_token_ids.masked_fill(
                regress_token_ids == self.llama_tokenizer.pad_token_id, -100
            )

            regress_embeds = self.embed_tokens(regress_token_ids)

        return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets

    def forward(self, samples, reduction="mean"):
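        """
        Training forward pass: prepare conditioning and regression embeddings,
        concatenate them (with a BOS embedding prepended), build the -100-masked
        label tensor, and return the language-modeling loss from the LLaMA model.
        """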
        # prepare the embedding to condition and the embedding to regress
        cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
            self.preparing_embedding(samples)

        # concat the embedding to condition and the embedding to regress
        inputs_embeds, attention_mask, input_lens = \
            self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)

        # get bos token embedding
        bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        bos_atts = attention_mask[:, :1]

        # add bos token at the beginning
        inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
        attention_mask = torch.cat([bos_atts, attention_mask], dim=1)

        # assemble the final targets
        targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
                             dtype=torch.long).to(self.device).fill_(-100)
        for i, target in enumerate(part_targets):
            targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target  # plus 1 for bos

        with self.maybe_autocast():
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
                reduction=reduction
            )
        loss = outputs.loss

        return {"loss": loss}

    def generate(
        self,
        images,
        texts,
        use_nucleus_sampling=False,
        num_beams=1,
        max_new_tokens=20,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1,
        length_penalty=1,
        temperature=1,
        do_sample=False,
        stop_words_ids=[2],
        lengths=None,
    ):
        '''
        Generation function used at inference/test time.
        '''
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
            stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])

        img_embeds, atts_img = self.encode_img(images.to(self.device))

        if lengths is not None:
            image_lists = []
            img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1])
            for idx, img_embed in enumerate(img_embeds):
                image_lists.append([img_embed[i][None] for i in range(lengths[idx])])
        else:
            image_lists = [[image_emb[None]] for image_emb in img_embeds]
        assert len(texts) == len(image_lists)
        batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]

        batch_size = len(batch_embs)
        max_len = max([emb.shape[1] for emb in batch_embs])
        emb_dim = batch_embs[0].shape[2]
        dtype = batch_embs[0].dtype
        device = batch_embs[0].device

        embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
        attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
        for i, emb in enumerate(batch_embs):
            emb_len = emb.shape[1]
            embs[i, -emb_len:] = emb[0]
            attn_mask[i, -emb_len:] = 1

        with self.maybe_autocast():
            outputs = self.llama_model.generate(
                inputs_embeds=embs,
                attention_mask=attn_mask,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                do_sample=do_sample,
                # stopping_criteria=stopping_criteria,
            )

        answers = []
        for output_token in outputs:
            if output_token[0] == 0:
                output_token = output_token[1:]
            output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
            output_texts = output_texts.split('</s>')[0]  # remove the stop sign </s>
            output_texts = output_texts.replace("<s>", "")
            output_texts = output_texts.split('[/INST]')[-1].strip()
            answers.append(output_texts)

        return answers

    def multi_select(self, images, texts, answers, num_cand=None):
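        """
        Rank candidate answers for each sample by their per-sample language-modeling
        loss (lower loss = better candidate). Returns, for every sample, the candidate
        indices sorted from most to least likely.
        """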
        all_losses = []
        for answer in answers:
            choice_samples = {
                'image': images,
                'instruction_input': texts,
                'answer': answer
            }
            loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
            all_losses.append(loss)
            torch.cuda.empty_cache()
        all_losses = torch.cat(all_losses, dim=-1)
        if num_cand is not None:
            for i in range(all_losses.shape[0]):
                all_losses[i, num_cand[i]:] = 9999

        output_class_ranks = torch.argsort(all_losses, dim=-1)
        return output_class_ranks.tolist()

    def predict_answers(
        self,
        samples,
        num_beams=5,
        inference_method="generate",
        max_len=10,
        min_len=1,
        num_ans_candidates=128,
        answer_list=None,
        prompt="",
        length_penalty=0,
        **kwargs
    ):
        '''
        function for open-ended VQA
        '''
        images = samples["image"].cuda()
        texts = samples["instruction_input"]

        output_text = self.generate(
            images=images,
            texts=texts,
            num_beams=num_beams,
            max_new_tokens=max_len,
            min_length=min_len,
            length_penalty=length_penalty
        )

        if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
            output_text = self._lemmatize(output_text)

        return output_text

    def predict_class(
        self,
        samples,
        num_beams=5,
        inference_method="generate",
        max_len=10,
        min_len=1,
        num_ans_candidates=5,
        answer_list=None,
        prompt="",
        length_penalty=0,
        **kwargs
    ):
        '''
        function for multi-choice VQA
        '''
        image = samples["image"].cuda()
        instruction = samples['instruction_input']
        answers = samples["choices"]
        num_cand = samples["num_choices"]

        ranks = self.multi_select(image, instruction, answers, num_cand)

        pred_ans = []
        for i, rank in enumerate(ranks):
            pred = answers[rank[0]][i]
            pred_ans.append(pred)
        return pred_ans

    def embed_tokens(self, token_ids):
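        """
        Look up token embeddings, handling both the PEFT-wrapped model layout
        (base_model.model.model.embed_tokens) and the plain HF layout (model.embed_tokens).
        """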
        try:
            embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
        except AttributeError:
            embeds = self.llama_model.model.embed_tokens(token_ids)
        return embeds

    @classmethod
    def from_config(cls, cfg):
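        """
        Build a MiniGPT4v model from a model config node and optionally load
        pretrained weights from the checkpoint path given by the 'ckpt' field.
        """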
        vit_model = cfg.get("vit_model", "eva_clip_g")
        q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
        img_size = cfg.get("image_size")
        num_query_token = cfg.get("num_query_token")
        llama_model = cfg.get("llama_model")

        drop_path_rate = cfg.get("drop_path_rate", 0)
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)
        freeze_qformer = cfg.get("freeze_qformer", True)
        low_resource = cfg.get("low_resource", False)

        prompt_path = cfg.get("prompt_path", "")
        prompt_template = cfg.get("prompt_template", "")
        max_txt_len = cfg.get("max_txt_len", 300)
        end_sym = cfg.get("end_sym", '\n')

        lora_r = cfg.get("lora_r", 64)
        lora_alpha = cfg.get("lora_alpha", 16)
        chat_template = cfg.get("chat_template", False)
        system_prompt = cfg.get("system_prompt", False)
        token_pooling = cfg.get("token_pooling", True)

        use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
        max_context_len = cfg.get("max_context_len", 3800)
        remove_template = cfg.get("remove_template", False)

        model = cls(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            llama_model=llama_model,
            prompt_path=prompt_path,
            prompt_template=prompt_template,
            max_txt_len=max_txt_len,
            low_resource=low_resource,
            end_sym=end_sym,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            chat_template=chat_template,
            system_prompt=system_prompt,
            token_pooling=token_pooling,
            use_grad_checkpoint_llm=use_grad_checkpoint_llm,
            max_context_len=max_context_len,
            remove_template=remove_template
        )

        ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
        if ckpt_path:
            print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
            ckpt = torch.load(ckpt_path, map_location="cpu")
            msg = model.load_state_dict(ckpt['model'], strict=False)

        return model


def assign_imgs(batched_instruct_list, batched_img_embeds):
    '''This function is used when the data is interleaved.
    The interleaved data is separated, and this function assigns the
    corresponding image embeddings to each segment.'''
    if len(batched_img_embeds.shape) == 3:
        batched_img_embeds = batched_img_embeds[:, None]

    batched_assigned = []

    for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds):
        img_idx = 0
        assigned_img = []
        n_assigned = []
        for instruct in instruct_list:
            n_img = instruct.count('<ImageHere>')
            if n_img > 0:  # this instruction includes images.
                assigned_img.append(img_embeds[None, img_idx:img_idx + n_img])
                img_idx += n_img
                n_assigned.append(n_img)
            else:  # this instruction doesn't include images
                assigned_img.append(None)
                n_assigned.append(None)
        batched_assigned.append(assigned_img)

    return batched_assigned