import torch
from PIL import Image
from transformers import (
    Blip2Processor,
    Blip2ForConditionalGeneration,
    BlipProcessor,
    BlipForConditionalGeneration,
)


class ImageCaptioner:
    def __init__(self, model_name="blip2-opt", device="cpu"):
        self.model_name = model_name
        self.device = device
        self.processor, self.model = self.initialize_model()

    def initialize_model(self):
        # Full precision on CPU; half precision saves memory on GPU.
        if self.device == "cpu":
            self.data_type = torch.float32
        else:
            self.data_type = torch.float16

        if self.model_name == "blip2-opt":
            processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b-coco")
            model = Blip2ForConditionalGeneration.from_pretrained(
                "Salesforce/blip2-opt-2.7b-coco", torch_dtype=self.data_type, low_cpu_mem_usage=True)
        elif self.model_name == "blip2-flan-t5":
            processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
            model = Blip2ForConditionalGeneration.from_pretrained(
                "Salesforce/blip2-flan-t5-xl", torch_dtype=self.data_type, low_cpu_mem_usage=True)
        elif self.model_name == "blip":
            # Smaller BLIP model, suitable for GPUs with limited memory.
            processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        else:
            raise NotImplementedError(f"{self.model_name} not implemented.")

        model.to(self.device)
        if self.device != "cpu":
            # Cast to fp16 on GPU; a no-op for the BLIP-2 models already loaded in fp16,
            # but it converts the base BLIP model, which is loaded in fp32.
            model.half()
        return processor, model

    def image_caption(self, image):
        # `image` is a PIL.Image; the processor returns pixel tensors, which are
        # moved to the model's device and dtype before generation.
        inputs = self.processor(images=image, return_tensors="pt").to(self.device, self.data_type)
        generated_ids = self.model.generate(**inputs)
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        return generated_text

    def image_caption_debug(self, image_src):
        # Stub that returns a fixed caption, useful for testing callers
        # without loading any model weights.
        return "A dish with salmon, broccoli, and something yellow."