Spaces:
Paused
Paused
| """PaliGemma demo gradio app.""" | |
| import datetime | |
| import functools | |
| import glob | |
| import json | |
| import logging | |
| import os | |
| import time | |
| import gradio as gr | |
| import jax | |
| import PIL.Image | |
| import gradio_helpers | |
| import models | |
| import paligemma_parse | |
# Markdown header for the demo page; create_app() passes it to gr.Markdown.
# The leading-pipe lines form a single row of links to external resources.
INTRO_TEXT = """🤲 PaliGemma demo\n\n
| [Paper](https://arxiv.org/abs/2407.07726)
| [GitHub](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md)
| [HF blog post](https://huggingface.co/blog/paligemma)
| [Google blog post](https://developers.googleblog.com/en/gemma-family-and-toolkit-expansion-io-2024)
| [Vertex AI Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/363)
| [Demo](https://huggingface.co/spaces/google/paligemma)
|\n\n
[PaliGemma](https://ai.google.dev/gemma/docs/paligemma) is an open vision-language model by Google,
inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and
built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343)
vision model and the [Gemma](https://arxiv.org/abs/2403.08295) language model. PaliGemma is designed as a versatile
model for transfer to a wide range of vision-language tasks such as image and short video caption, visual question
answering, text reading, object detection and object segmentation.
\n\n
This space includes models fine-tuned on a mix of downstream tasks.
See the [blog post](https://huggingface.co/blog/paligemma) and
[README](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md)
for detailed information how to use and fine-tune PaliGemma models.
\n\n
**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications.
"""
def make_image(value, visible):
    """Builds the plain image component, exchanging images as file paths."""
    return gr.Image(value, label='Image', type='filepath', visible=visible)


# Component factories with the label pre-bound; callers supply the value and
# any further keyword arguments (visibility, color map, ...).
make_annotated_image = functools.partial(gr.AnnotatedImage, label='Image')
make_highlighted_text = functools.partial(gr.HighlightedText, label='Output')

# Palette cycled through when assigning a color to each detected label.
# https://coolors.co/4285f4-db4437-f4b400-0f9d58-e48ef1
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
def compute(image, prompt, model_name, sampler):
    """Runs model inference.

    Args:
        image: Input image, either an already opened PIL image or a path to an
            image file (paths are opened with `PIL.Image.open`).
        prompt: Text prompt handed to the model.
        model_name: Name of the model to run. May be empty when mocking.
        sampler: Decoding strategy name, forwarded to `models.generate`.

    Returns:
        A 3-tuple of gradio components: the highlighted text output, the plain
        image (visible only when there is nothing to annotate) and the
        annotated image (visible only when there are annotations).

    Raises:
        gr.Error: If no image was provided, or if no model name is set while
            not running in mock mode.
    """
    if image is None:
        raise gr.Error('Image required')

    logging.info('prompt="%s"', prompt)

    if isinstance(image, str):
        image = PIL.Image.open(image)

    if gradio_helpers.should_mock():
        # Mock mode: skip the model and return a canned response after a delay.
        logging.warning('Mocking response')
        time.sleep(2.)
        output = paligemma_parse.EXAMPLE_STRING
    else:
        if not model_name:
            raise gr.Error('Models not loaded yet')
        output = models.generate(model_name, sampler, image, prompt)
        logging.info('output="%s"', output)

    width, height = image.size
    objs = paligemma_parse.extract_objs(output, width, height, unique_labels=True)

    # Deduplicate labels while keeping first-seen order. Iterating a set here
    # would make the label -> color assignment vary between runs.
    labels = list(dict.fromkeys(
        obj['name'] for obj in objs if obj.get('name')))
    color_map = {
        label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}

    highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]

    # Prefer the mask, fall back to the box, and skip objects that have
    # neither, so that a missing key or a None region never reaches the UI.
    annotations = []
    for obj in objs:
        region = obj.get('mask')
        if region is None:
            region = obj.get('xyxy')
        if region is None:
            continue
        annotations.append((region, obj.get('name') or ''))
    annotated_image = (image, annotations)
    has_annotations = bool(annotations)

    return (
        make_highlighted_text(
            highlighted_text, visible=True, color_map=color_map),
        make_image(image, visible=not has_annotations),
        make_annotated_image(
            annotated_image, visible=has_annotations, width=width,
            height=height, color_map=color_map),
    )
def warmup(model_name):
    """Runs one throwaway inference on a 1x1 RGB image with greedy decoding."""
    dummy_image = PIL.Image.new('RGB', [1, 1])
    # The result is discarded; only the call itself matters here.
    compute(dummy_image, '', model_name, 'greedy')
def reset():
    """Returns cleared state: empty prompt, hidden outputs, blank input image."""
    cleared_prompt = ''
    hidden_text = make_highlighted_text('', visible=False)
    blank_image = make_image(None, visible=True)
    hidden_annotated = make_annotated_image(None, visible=False)
    return cleared_prompt, hidden_text, blank_image, hidden_annotated
def create_app():
    """Creates demo UI.

    Lays out the input image, prompt, model/decoding selectors and the outputs
    inside a `gr.Blocks`, wires the Run / Clear / example events, and registers
    page-load callbacks that fill in the status line, the model dropdown and
    the GPU kind.

    Returns:
        The assembled `gr.Blocks` app. It is neither queued nor launched here.
    """

    def make_model(choices):
        # Hidden while there are no choices; '' is the value in that case.
        return gr.Dropdown(
            value=(choices + [''])[0],
            choices=choices,
            label='Model',
            visible=bool(choices),
        )

    def make_prompt(value, visible=True):
        return gr.Textbox(value, label='Prompt', visible=visible)

    with gr.Blocks() as demo:

        ##### Main UI structure.

        gr.Markdown(INTRO_TEXT)
        with gr.Row():
            image = make_image(None, visible=True)  # input
            annotated_image = make_annotated_image(None, visible=False)  # output
            with gr.Column():
                with gr.Row():
                    prompt = make_prompt('', visible=True)
                model_info = gr.Markdown(label='Model Info')
                with gr.Row():
                    model = make_model([])
                    samplers = [
                        'greedy', 'nucleus(0.1)', 'nucleus(0.3)',
                        'temperature(0.5)']
                    sampler = gr.Dropdown(
                        value=samplers[0], choices=samplers, label='Decoding'
                    )
                with gr.Row():
                    run = gr.Button('Run', variant='primary')
                    clear = gr.Button('Clear')
                highlighted_text = make_highlighted_text('', visible=False)

        ##### UI logic.

        def update_ui(model, prompt):
            prompt = make_prompt(prompt, visible=True)
            model_info = f'Model `{model}` – {models.MODELS_INFO.get(model, "No info.")}'
            return [prompt, model_info]

        gr.on(
            [model.change],
            update_ui,
            [model, prompt],
            [prompt, model_info],
        )

        gr.on(
            [run.click, prompt.submit],
            compute,
            [image, prompt, model, sampler],
            [highlighted_text, image, annotated_image],
        )
        clear.click(
            reset, None, [prompt, highlighted_text, image, annotated_image]
        )

        ##### Examples.

        gr.set_static_paths(['examples/'])
        all_examples = []
        # Sorted so the example order does not depend on directory listing
        # order; each file is closed as soon as it has been parsed.
        for example_path in sorted(glob.glob('examples/*.json')):
            with open(example_path, encoding='utf-8') as example_file:
                all_examples.append(json.load(example_file))
        logging.info('loaded %d examples', len(all_examples))

        example_image = gr.Image(
            label='Image', visible=False)  # proxy, never visible
        example_model = gr.Text(
            label='Model', visible=False)  # proxy, never visible
        example_prompt = gr.Text(
            label='Prompt', visible=False)  # proxy, never visible
        example_license = gr.Markdown(
            label='Image License', visible=False)  # placeholder, never visible
        gr.Examples(
            examples=[
                [
                    f'examples/{ex["name"]}.jpg',
                    ex['prompt'],
                    ex['model'],
                    ex['license'],
                ]
                for ex in all_examples
                # Only offer examples whose model is registered.
                if ex['model'] in models.MODELS
            ],
            inputs=[example_image, example_prompt, example_model, example_license],
        )

        ##### Examples UI logic.

        # Selecting an example shows its image and hides any previous outputs.
        example_image.change(
            lambda image_path: (
                make_image(image_path, visible=True),
                make_annotated_image(None, visible=False),
                make_highlighted_text('', visible=False),
            ),
            example_image,
            [image, annotated_image, highlighted_text],
        )

        def example_model_changed(model):
            if model not in gradio_helpers.get_paths():
                raise gr.Error(f'Model "{model}" not loaded!')
            return model

        example_model.change(example_model_changed, example_model, model)
        example_prompt.change(make_prompt, example_prompt, prompt)

        ##### Status.

        status = gr.Markdown(f'Startup: {datetime.datetime.now()}')
        gpu_kind = gr.Markdown('GPU=?')
        # On page load: refresh the status text and repopulate the model
        # dropdown from whatever gradio_helpers.get_paths() currently reports.
        demo.load(
            lambda: [
                gradio_helpers.get_status(),
                make_model(list(gradio_helpers.get_paths())),
            ],
            None,
            [status, model],
        )

        def get_gpu_kind():
            device = jax.devices()[0]
            if not gradio_helpers.should_mock() and device.platform != 'gpu':
                raise gr.Error('GPU not visible to JAX!')
            return f'GPU={device.device_kind}'

        demo.load(get_gpu_kind, None, gpu_kind)

    return demo
if __name__ == '__main__':
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )

    # Startup diagnostics: visible JAX devices, then the process environment.
    logging.info('JAX devices: %s', jax.devices())
    # NOTE(review): this logs every environment variable with its value;
    # confirm that no secrets end up in the logs.
    for env_name, env_value in os.environ.items():
        logging.info('environ["%s"] = %r', env_name, env_value)

    gradio_helpers.set_warmup_function(warmup)
    # Register a download for every configured model.
    for model_name, download_spec in models.MODELS.items():
        repo, filename, revision = download_spec
        gradio_helpers.register_download(model_name, repo, filename, revision)

    demo_app = create_app()
    demo_app.queue().launch()