import functools

import gradio as gr
import tensorflow as tf

import gcvit
from gcvit.utils import get_gradcam_model, get_gradcam_prediction
# GCViT constructors (attributes of the ``gcvit`` package) this demo may build.
# Used as an allowlist: the model name arrives from the client, and older
# Gradio versions do not enforce Radio choices on the server side.
MODEL_NAMES = (
    "GCViTXXTiny",
    "GCViTXTiny",
    "GCViTTiny",
    "GCViTSmall",
    "GCViTBase",
    "GCViTLarge",
)


# Cache one Grad-CAM model per variant so pretrained weights are loaded once
# per process instead of on every request. Keeping every variant resident
# trades memory for latency; lower ``maxsize`` if memory is tight.
@functools.lru_cache(maxsize=len(MODEL_NAMES))
def _load_gradcam_model(model_name):
    """Build the Grad-CAM model for the pretrained GCViT variant ``model_name``."""
    model = getattr(gcvit, model_name)(pretrain=True)
    return get_gradcam_model(model)


def predict_fn(image, model_name):
    """Classify ``image`` with a pretrained GCViT variant and render Grad-CAM.

    This is the prediction function invoked by Gradio with the values of the
    two input components.

    Args:
        image: The uploaded image as delivered by the Gradio image input
            (presumably a numpy array -- confirm against the component type).
        model_name: Name of a GCViT constructor in ``gcvit``; must be one of
            ``MODEL_NAMES``.

    Returns:
        A two-element list ``[preds, overlay]``. ``preds`` is a dict mapping
        class label to score (the format ``gr.Label`` expects), and
        ``overlay`` is the Grad-CAM overlay returned by
        ``get_gradcam_prediction``.

    Raises:
        ValueError: If ``model_name`` is not in ``MODEL_NAMES``.
    """
    # Reject anything outside the allowlist before it reaches getattr().
    if model_name not in MODEL_NAMES:
        raise ValueError(
            f"Unknown model name {model_name!r}; expected one of {MODEL_NAMES}"
        )
    gradcam_model = _load_gradcam_model(model_name)
    preds, overlay = get_gradcam_prediction(
        image, gradcam_model, cmap="jet", alpha=0.4, pred_index=None
    )
    # Each entry is presumably a (class_id, class_name, score) tuple in the
    # style of Keras ``decode_predictions`` -- confirm against gcvit.utils.
    # Keep label -> score, cast to a plain float for JSON serialization.
    preds = {x[1]: float(x[2]) for x in preds}
    return [preds, overlay]
# Gradio UI: an image and a GCViT variant go in; the per-class scores and a
# Grad-CAM overlay come out. Components come from the top-level ``gr``
# namespace (as ``gr.Radio`` already does) rather than the deprecated
# ``gr.inputs`` / ``gr.outputs`` modules.
demo = gr.Interface(
    fn=predict_fn,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Radio(
            ["GCViTXXTiny", "GCViTXTiny", "GCViTTiny",
             "GCViTSmall", "GCViTBase", "GCViTLarge"],
            value="GCViTXXTiny",
            label="Model Name",
        ),
    ],
    outputs=[
        gr.Label(label="Prediction"),
        # The Grad-CAM overlay is an output, so use the plain Image component
        # instead of an input-namespace class.
        gr.Image(label="GradCAM"),
    ],
    title="Global Context Vision Transformer (GCViT) Demo",
    description="Image Classification with GCViT Model using ImageNet Pretrain Weights.",
    examples=[
        ["example/hot_air_ballon.jpg", "GCViTXXTiny"],
        ["example/chelsea.png", "GCViTXXTiny"],
        ["example/penguin.JPG", "GCViTXXTiny"],
        ["example/bus.jpg", "GCViTXXTiny"],
    ],
)

# Only start the server when run as a script, so importing this module (e.g.
# by Gradio's reload mode, which looks up ``demo``) has no side effects.
if __name__ == "__main__":
    demo.launch()