Spaces:
Build error
Build error
| import os | |
| import argparse | |
| import pdb | |
| import time | |
| from PIL import Image | |
| import cv2 | |
| import numpy as np | |
| from caption_anything.captioner import build_captioner, BaseCaptioner | |
| from caption_anything.segmenter import build_segmenter | |
| from caption_anything.text_refiner import build_text_refiner | |
class CaptionAnything:
    """Caption a prompted image region.

    Chains three components: a segmenter that turns a prompt into a mask, a
    captioner that describes the masked region, and an optional text refiner
    that rewrites the raw caption according to language controls.
    """

    def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
        """Wire up the pipeline components.

        Args:
            args: Parsed options. This class reads ``captioner``, ``segmenter``,
                ``text_refiner``, ``device``, ``disable_gpt``,
                ``enable_morphologyex``, ``enable_reduce_tokens``,
                ``seg_crop_mode``, ``clip_filter``, ``disable_regular_box`` and
                ``context_captions`` from it.
            api_key: Key forwarded to ``build_text_refiner`` when a refiner has
                to be built here.
            captioner: Pre-built captioner; built from ``args`` when ``None``.
            segmenter: Pre-built segmenter; built from ``args`` when ``None``.
            text_refiner: Pre-built refiner; only used when
                ``args.disable_gpt`` is false.
        """
        self.args = args
        self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
        self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
        self.text_refiner = None
        if not args.disable_gpt:
            if text_refiner is not None:
                self.text_refiner = text_refiner
            else:
                self.init_refiner(api_key)

    # The three properties below are thin proxies onto the segmenter state, so
    # that ``setup`` (and callers) can read and restore it through this object.

    @property
    def image_embedding(self):
        """Image embedding currently held by the segmenter."""
        return self.segmenter.image_embedding

    @image_embedding.setter
    def image_embedding(self, image_embedding):
        self.segmenter.image_embedding = image_embedding

    @property
    def original_size(self):
        """``original_size`` of the segmenter's predictor."""
        return self.segmenter.predictor.original_size

    @original_size.setter
    def original_size(self, original_size):
        self.segmenter.predictor.original_size = original_size

    @property
    def input_size(self):
        """``input_size`` of the segmenter's predictor."""
        return self.segmenter.predictor.input_size

    @input_size.setter
    def input_size(self, input_size):
        self.segmenter.predictor.input_size = input_size

    def setup(self, image_embedding, original_size, input_size, is_image_set):
        """Restore segmenter/predictor state in one call.

        The first three values go through the property setters above and land
        on the segmenter and its predictor; ``is_image_set`` is written to the
        predictor directly.
        """
        self.image_embedding = image_embedding
        self.original_size = original_size
        self.input_size = input_size
        self.segmenter.predictor.is_image_set = is_image_set

    def init_refiner(self, api_key):
        """Build the text refiner, leaving it as ``None`` if that fails.

        Best-effort by design: any error while building or probing the refiner
        is reported on stdout and disables refinement instead of propagating.
        """
        try:
            self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
            # Probe call: a failure here is handled below and disables the refiner.
            self.text_refiner.llm('hi')
        except Exception:
            # ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt / SystemExit are not swallowed.
            self.text_refiner = None
            print('OpenAI GPT is not available')

    def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
        """Segment the prompted region, caption it and optionally refine the caption.

        Side effect: the mask is saved as a PNG under ``result/``.

        Args:
            image: Passed unchanged to the segmenter and the captioner.
            prompt: Segmentation prompt, passed to ``segmenter.inference``.
            controls: Language controls, passed to ``text_refiner.inference``.
            disable_gpt: Skip refinement for this call even if a refiner exists.
            enable_wiki: Forwarded to ``text_refiner.inference``.

        Returns:
            dict with keys ``generated_captions`` (the refiner output, or
            ``{'raw_caption': caption}`` when refinement is skipped),
            ``crop_save_path``, ``mask_save_path``, ``mask`` (RGB PIL image)
            and ``context_captions``.
        """
        # TODO: Add support to multiple seg masks.
        # segment with prompt
        print("CA prompt: ", prompt, "CA controls", controls)
        # Only the first entry along the leading axis is used.
        seg_mask = self.segmenter.inference(image, prompt)[0, ...]
        if self.args.enable_morphologyex:
            # Open then close with a 6x6 kernel on a 0/255 three-channel copy,
            # then collapse back to a boolean mask.
            kernel = np.ones((6, 6), np.uint8)
            seg_mask = 255 * seg_mask.astype(np.uint8)
            seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1)
            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel=kernel)
            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel=kernel)
            seg_mask = seg_mask[:, :, 0] > 0

        mask_save_path = f'result/mask_{time.time()}.png'
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(os.path.dirname(mask_save_path), exist_ok=True)
        seg_mask_img = Image.fromarray(seg_mask.astype('int') * 255.)
        # Normalise to RGB before saving, whatever mode fromarray picked.
        if seg_mask_img.mode != 'RGB':
            seg_mask_img = seg_mask_img.convert('RGB')
        seg_mask_img.save(mask_save_path)
        print('seg_mask path: ', mask_save_path)
        print("seg_mask.shape: ", seg_mask.shape)

        # captioning with mask
        if self.args.enable_reduce_tokens:
            caption_fn = self.captioner.inference_with_reduced_tokens
        else:
            caption_fn = self.captioner.inference_seg
        caption, crop_save_path = caption_fn(image, seg_mask,
                                             crop_mode=self.args.seg_crop_mode,
                                             filter=self.args.clip_filter,
                                             disable_regular_box=self.args.disable_regular_box)

        # refining with TextRefiner
        context_captions = []
        if self.args.context_captions:
            # Caption of the whole image, handed to the refiner as context.
            context_captions.append(self.captioner.inference(image))
        if not disable_gpt and self.text_refiner is not None:
            refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions,
                                                          enable_wiki=enable_wiki)
        else:
            refined_caption = {'raw_caption': caption}

        out = {'generated_captions': refined_caption,
               'crop_save_path': crop_save_path,
               'mask_save_path': mask_save_path,
               'mask': seg_mask_img,
               'context_captions': context_captions}
        return out
if __name__ == "__main__":
    # Demo: caption two click-prompted regions of a sample image.
    from caption_anything.utils.parser import parse_augment

    args = parse_augment()
    # alternative sample: 'test_images/img3.jpg'
    image_path = 'test_images/img1.jpg'
    # Prompts handed to the segmenter, one inference call per entry.
    prompts = [
        {
            "prompt_type": ["click"],
            "input_point": [[500, 300], [200, 500]],
            "input_label": [1, 0],
            "multimask_output": "True",
        },
        {
            "prompt_type": ["click"],
            "input_point": [[300, 800]],
            "input_label": [1],
            "multimask_output": "True",
        }
    ]
    # Language controls handed to the text refiner.
    controls = {
        "length": "30",
        "sentiment": "positive",
        "imagination": "False",
        "language": "English",
    }
    # A missing key must not crash the demo: CaptionAnything accepts an empty
    # key and falls back to unrefined captions when the refiner cannot start.
    model = CaptionAnything(args, os.environ.get('OPENAI_API_KEY', ''))
    for prompt in prompts:
        print('*' * 30)
        print('Image path: ', image_path)
        # Opened only to print its summary; inference receives the path itself.
        with Image.open(image_path) as image:
            print(image)
        print('Visual controls (SAM prompt):\n', prompt)
        print('Language controls:\n', controls)
        out = model.inference(image_path, prompt, controls)