ttengwang committed
Commit 108f2df · Parent(s): b7e072a

share ocr_reader to accelerate inference

Files changed:
- app.py +11 -3
- caption_anything/captioner/blip2.py +2 -2
- caption_anything/model.py +8 -5
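
The change is a simple dependency injection: the easyocr.Reader, which loads its detection and recognition models when constructed, is now built once at app startup and handed to every per-session CaptionAnything instead of being rebuilt inside each __init__. A minimal sketch of the pattern, with a hypothetical Session class standing in for CaptionAnything:

    import easyocr

    # Built once at startup; easyocr.Reader loads its OCR models here, which is
    # the per-session cost this commit removes.
    shared_ocr_reader = easyocr.Reader(["ch_tra", "en"])

    class Session:  # hypothetical stand-in for CaptionAnything
        def __init__(self, ocr_reader=None):
            # Reuse the injected reader; only build a fresh one if none was passed in.
            self.ocr_reader = ocr_reader if ocr_reader is not None else easyocr.Reader(["ch_tra", "en"])

    # New sessions share the same reader instead of reloading the models each time.
    session_a = Session(ocr_reader=shared_ocr_reader)
    session_b = Session(ocr_reader=shared_ocr_reader)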
    	
app.py  CHANGED

@@ -17,7 +17,7 @@ from caption_anything.text_refiner import build_text_refiner
 from caption_anything.segmenter import build_segmenter
 from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
 from segment_anything import sam_model_registry
-
+import easyocr
 
 args = parse_augment()
 args.segmenter = "huge"
@@ -30,6 +30,8 @@ else:
 
 shared_captioner = build_captioner(args.captioner, args.device, args)
 shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
+ocr_lang = ["ch_tra", "en"]
+shared_ocr_reader = easyocr.Reader(ocr_lang)
 tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
 shared_chatbot_tools = build_chatbot_tools(tools_dict)
 
@@ -57,13 +59,13 @@ class ImageSketcher(gr.Image):
         return super().preprocess(x)
 
 
-def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None,
+def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, ocr_reader=None, text_refiner=None,
                                        session_id=None):
     segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
     captioner = captioner
     if session_id is not None:
         print('Init caption anything for session {}'.format(session_id))
-    return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, text_refiner=text_refiner)
+    return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, ocr_reader=ocr_reader, text_refiner=text_refiner)
 
 
 def init_openai_api_key(api_key=""):
@@ -146,6 +148,7 @@ def upload_callback(image_input, state, visual_chatgpt=None):
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         session_id=iface.app_id
     )
     model.segmenter.set_image(image_input)
@@ -154,6 +157,7 @@ def upload_callback(image_input, state, visual_chatgpt=None):
     input_size = model.input_size
 
     if visual_chatgpt is not None:
+        print('upload_callback: add caption to chatGPT memory')
         new_image_path = get_new_image_name('chat_image', func_name='upload')
         image_input.save(new_image_path)
         visual_chatgpt.current_image = new_image_path
@@ -192,6 +196,7 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         text_refiner=text_refiner,
         session_id=iface.app_id
     )
@@ -213,6 +218,7 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
     x, y = input_points[-1]
 
     if visual_chatgpt is not None:
+        print('inference_click: add caption to chatGPT memory')
         new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
         Image.open(out["crop_save_path"]).save(new_crop_save_path)
         point_prompt = f'You should primarly use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
@@ -273,6 +279,7 @@ def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuali
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         text_refiner=text_refiner,
         session_id=iface.app_id
     )
@@ -325,6 +332,7 @@ def cap_everything(image_input, visual_chatgpt, text_refiner):
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         text_refiner=text_refiner,
         session_id=iface.app_id
     )
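
app.py now constructs shared_ocr_reader next to the shared captioner and SAM model and threads it through build_caption_anything_with_models in each callback. The saving comes from easyocr.Reader's one-time model load; a rough, illustrative timing sketch (the image path is hypothetical and the numbers depend on hardware):

    import time
    import easyocr

    t0 = time.time()
    reader = easyocr.Reader(["ch_tra", "en"])   # one-time model load, the slow part
    print("reader construction:", round(time.time() - t0, 2), "s")

    t0 = time.time()
    reader.readtext("example.jpg")              # hypothetical image; a warm reader is fast
    print("readtext:", round(time.time() - t0, 2), "s")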
    	
caption_anything/captioner/blip2.py  CHANGED

@@ -6,6 +6,7 @@ from transformers import AutoProcessor, Blip2ForConditionalGeneration
 
 from caption_anything.utils.utils import is_platform_win, load_image
 from .base_captioner import BaseCaptioner
+import time
 
 class BLIP2Captioner(BaseCaptioner):
     def __init__(self, device, dialogue: bool = False, enable_filter: bool = False):
@@ -33,8 +34,7 @@ class BLIP2Captioner(BaseCaptioner):
         if not self.dialogue:
             inputs = self.processor(image, text = args['text_prompt'], return_tensors="pt").to(self.device, self.torch_dtype)
             out = self.model.generate(**inputs, return_dict_in_generate=True, output_scores=True, max_new_tokens=50)
-
-            caption = [caption.strip() for caption in captions][0]
+            caption = self.processor.decode(out.sequences[0], skip_special_tokens=True).strip()
             if self.enable_filter and filter:
                 print('reference caption: {}, caption: {}'.format(args['reference_caption'], caption))
                 clip_score = self.filter_caption(image, caption, args['reference_caption'])
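
The blip2.py hunk also repairs caption decoding: the removed lines derived `caption` from a separate `captions` list, while the new line decodes the generate output directly. With return_dict_in_generate=True, generate returns token ids in out.sequences, which the processor must decode back into text. A self-contained sketch of that step (the checkpoint id and image path are assumptions for illustration):

    import torch
    from PIL import Image
    from transformers import AutoProcessor, Blip2ForConditionalGeneration

    name = "Salesforce/blip2-opt-2.7b"  # assumed checkpoint; the Space may pin another
    processor = AutoProcessor.from_pretrained(name)
    model = Blip2ForConditionalGeneration.from_pretrained(name, torch_dtype=torch.float16).to("cuda")

    image = Image.open("example.jpg")   # hypothetical input image
    inputs = processor(image, text="a photo of", return_tensors="pt").to("cuda", torch.float16)
    out = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, max_new_tokens=50)

    # out.sequences holds token ids, not text; decoding them yields the caption string.
    caption = processor.decode(out.sequences[0], skip_special_tokens=True).strip()
    print(caption)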
    	
caption_anything/model.py  CHANGED

@@ -8,6 +8,7 @@ import numpy as np
 from PIL import Image
 import easyocr
 import copy
+import time
 from caption_anything.captioner import build_captioner, BaseCaptioner
 from caption_anything.segmenter import build_segmenter, build_segmenter_densecap
 from caption_anything.text_refiner import build_text_refiner
@@ -16,14 +17,15 @@ from caption_anything.utils.utils import mask_painter_foreground_all, mask_paint
 from caption_anything.utils.densecap_painter import draw_bbox
 
 class CaptionAnything:
-    def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
+    def __init__(self, args, api_key="", captioner=None, segmenter=None, ocr_reader=None, text_refiner=None):
         self.args = args
         self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
         self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
         self.segmenter_densecap = build_segmenter_densecap(args.segmenter, args.device, args, model=self.segmenter.model)
+        self.ocr_lang = ["ch_tra", "en"]
+        self.ocr_reader = ocr_reader if ocr_reader is not None else easyocr.Reader(self.ocr_lang)
 
-
-        self.reader = easyocr.Reader(self.lang)
+
         self.text_refiner = None
         if not args.disable_gpt:
             if text_refiner is not None:
@@ -31,6 +33,7 @@ class CaptionAnything:
             elif api_key != "":
                 self.init_refiner(api_key)
         self.require_caption_prompt = args.captioner == 'blip2'
+        print('text_refiner init time: ', time.time() - t0)
 
     @property
     def image_embedding(self):
@@ -213,7 +216,7 @@ class CaptionAnything:
     def parse_ocr(self, image, thres=0.2):
         width, height = get_image_shape(image)
         image = load_image(image, return_type='numpy')
-        bounds = self.
+        bounds = self.ocr_reader.readtext(image)
        bounds = [bound for bound in bounds if bound[2] > thres]
         print('Process OCR Text:\n', bounds)
 
@@ -257,7 +260,7 @@
 if __name__ == "__main__":
     from caption_anything.utils.parser import parse_augment
     args = parse_augment()
-    image_path = '
+    image_path = 'result/wt/memes/87226084.jpg'
     image = Image.open(image_path)
     prompts = [
         {
