Create README.md
Browse files
README.md
CHANGED
|
@@ -1,13 +1,146 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: 'GCViT: Global Context Vision Transformer'
|
| 3 |
+
colorFrom: indigo
|
| 4 |
+
---
|
| 5 |
+
<h1 align="center">
|
| 6 |
+
<p><a href='https://arxiv.org/pdf/2206.09959v1.pdf'>GCViT: Global Context Vision Transformer</a></p>
|
| 7 |
+
</h1>
|
| 8 |
+
<div align=center><img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/lvg_arch.PNG" width=800></div>
|
| 9 |
+
<p align="center">
|
| 10 |
+
<a href="https://github.com/awsaf49/gcvit-tf/blob/main/LICENSE.md">
|
| 11 |
+
<img src="https://img.shields.io/badge/License-MIT-yellow.svg">
|
| 12 |
+
</a>
|
| 13 |
+
<img alt="python" src="https://img.shields.io/badge/python-%3E%3D3.6-blue?logo=python">
|
| 14 |
+
<img alt="tensorflow" src="https://img.shields.io/badge/tensorflow-%3E%3D2.4.1-orange?logo=tensorflow">
|
| 15 |
+
<div align=center><p>
|
| 16 |
+
<a target="_blank" href="https://huggingface.co/spaces/awsaf49/gcvit-tf"><img src="https://img.shields.io/badge/π€%20Hugging%20Face-Spaces-yellow.svg"></a>
|
| 17 |
+
<a href="https://colab.research.google.com/github/awsaf49/gcvit-tf/blob/main/notebooks/GCViT_Flower_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
|
| 18 |
+
<a href="https://www.kaggle.com/awsaf49/flower-classification-gcvit-global-context-vit"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open In Kaggle"></a>
|
| 19 |
+
</p></div>
|
| 20 |
+
<h2 align="center">
|
| 21 |
+
<p>Tensorflow 2.0 Implementation of GCViT</p>
|
| 22 |
+
</h2>
|
| 23 |
+
</p>
|
| 24 |
+
<p align="center">
|
| 25 |
+
This library implements <b>GCViT</b> using Tensorflow 2.0 specifically in <code>tf.keras.Model</code> manner to get PyTorch flavor.
|
| 26 |
+
</p>
|
| 27 |
+
|
| 28 |
+
## Update
|
| 29 |
+
* **15 Jan 2023** : `GCViTLarge` model added with ckpt.
|
| 30 |
+
* **3 Sept 2022** : Annotated [kaggle-notebook](https://www.kaggle.com/code/awsaf49/gcvit-global-context-vision-transformer) based on this project won [Kaggle ML Research Spotlight: August 2022](https://www.kaggle.com/discussions/general/349817).
|
| 31 |
+
* **19 Aug 2022** : This project got acknowledged by [Official](https://github.com/NVlabs/GCVit) repo [here](https://github.com/NVlabs/GCVit#third-party-implementations-and-resources)
|
| 32 |
+
|
| 33 |
+
## Model
|
| 34 |
+
* Architecture:
|
| 35 |
+
|
| 36 |
+
<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/arch.PNG">
|
| 37 |
+
|
| 38 |
+
* Local Vs Global Attention:
|
| 39 |
+
|
| 40 |
+
<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/lvg_msa.PNG">
|
| 41 |
+
|
| 42 |
+
## Result
|
| 43 |
+
<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/result.PNG" width=900>
|
| 44 |
+
|
| 45 |
+
Official codebase had some issue which has been fixed recently (12 August 2022). Here's the result of ported weights on **ImageNetV2-Test** data,
|
| 46 |
+
|
| 47 |
+
| Model | Acc@1 | Acc@5 | #Params |
|
| 48 |
+
|--------------|-------|-------|---------|
|
| 49 |
+
| GCViT-XXTiny | 0.663 | 0.873 | 12M |
|
| 50 |
+
| GCViT-XTiny | 0.685 | 0.885 | 20M |
|
| 51 |
+
| GCViT-Tiny | 0.708 | 0.899 | 28M |
|
| 52 |
+
| GCViT-Small | 0.720 | 0.901 | 51M |
|
| 53 |
+
| GCViT-Base | 0.731 | 0.907 | 90M |
|
| 54 |
+
| GCViT-Large | 0.734 | 0.913 | 202M |
|
| 55 |
+
|
| 56 |
+
## Installation
|
| 57 |
+
```bash
|
| 58 |
+
pip install -U gcvit
|
| 59 |
+
# or
|
| 60 |
+
# pip install -U git+https://github.com/awsaf49/gcvit-tf
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## Usage
|
| 64 |
+
Load model using following codes,
|
| 65 |
+
```py
|
| 66 |
+
from gcvit import GCViTTiny
|
| 67 |
+
model = GCViTTiny(pretrain=True)
|
| 68 |
+
```
|
| 69 |
+
Simple code to check model's prediction,
|
| 70 |
+
```py
|
| 71 |
+
from skimage.data import chelsea
|
| 72 |
+
img = tf.keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='torch') # Chelsea the cat
|
| 73 |
+
img = tf.image.resize(img, (224, 224))[None,] # resize & create batch
|
| 74 |
+
pred = model(img).numpy()
|
| 75 |
+
print(tf.keras.applications.imagenet_utils.decode_predictions(pred)[0])
|
| 76 |
+
```
|
| 77 |
+
Prediction:
|
| 78 |
+
```py
|
| 79 |
+
[('n02124075', 'Egyptian_cat', 0.9194835),
|
| 80 |
+
('n02123045', 'tabby', 0.009686623),
|
| 81 |
+
('n02123159', 'tiger_cat', 0.0061576385),
|
| 82 |
+
('n02127052', 'lynx', 0.0011503297),
|
| 83 |
+
('n02883205', 'bow_tie', 0.00042479983)]
|
| 84 |
+
```
|
| 85 |
+
For feature extraction:
|
| 86 |
+
```py
|
| 87 |
+
model = GCViTTiny(pretrain=True) # when pretrain=True, num_classes must be 1000
|
| 88 |
+
model.reset_classifier(num_classes=0, head_act=None)
|
| 89 |
+
feature = model(img)
|
| 90 |
+
print(feature.shape)
|
| 91 |
+
```
|
| 92 |
+
Feature:
|
| 93 |
+
```py
|
| 94 |
+
(None, 512)
|
| 95 |
+
```
|
| 96 |
+
For feature map:
|
| 97 |
+
```py
|
| 98 |
+
model = GCViTTiny(pretrain=True) # when pretrain=True, num_classes must be 1000
|
| 99 |
+
feature = model.forward_features(img)
|
| 100 |
+
print(feature.shape)
|
| 101 |
+
```
|
| 102 |
+
Feature map:
|
| 103 |
+
```py
|
| 104 |
+
(None, 7, 7, 512)
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
## Live-Demo
|
| 108 |
+
* For live demo on Image Classification & Grad-CAM, with **ImageNet** weights, click <a target="_blank" href="https://huggingface.co/spaces/awsaf49/gcvit-tf"><img src="https://img.shields.io/badge/Try%20on-Gradio-orange"></a> powered by π€ Space and Gradio. here's an example,
|
| 109 |
+
|
| 110 |
+
<a href="https://huggingface.co/spaces/awsaf49/gcvit-tf"><img src="image/gradio_demo.JPG" height=500></a>
|
| 111 |
+
|
| 112 |
+
## Example
|
| 113 |
+
For working training example checkout these notebooks on **Google Colab** <a href="https://colab.research.google.com/github/awsaf49/gcvit-tf/blob/main/notebooks/GCViT_Flower_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> & **Kaggle** <a href="https://www.kaggle.com/awsaf49/flower-classification-gcvit-global-context-vit"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open In Kaggle"></a>.
|
| 114 |
+
|
| 115 |
+
Here is grad-cam result after training on Flower Classification Dataset,
|
| 116 |
+
|
| 117 |
+
<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/flower_gradcam.PNG" height=500>
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
## To Do
|
| 122 |
+
- [ ] Segmentation Pipeline
|
| 123 |
+
- [x] New updated weights have been added.
|
| 124 |
+
- [x] Working training example in Colab & Kaggle.
|
| 125 |
+
- [x] GradCAM showcase.
|
| 126 |
+
- [x] Gradio Demo.
|
| 127 |
+
- [x] Build model with `tf.keras.Model`.
|
| 128 |
+
- [x] Port weights from official repo.
|
| 129 |
+
- [x] Support for `TPU`.
|
| 130 |
+
|
| 131 |
+
## Acknowledgement
|
| 132 |
+
* [GCVit](https://github.com/NVlabs/GCVit) (Official)
|
| 133 |
+
* [Swin-Transformer-TF](https://github.com/rishigami/Swin-Transformer-TF)
|
| 134 |
+
* [tfgcvit](https://github.com/shkarupa-alex/tfgcvit/tree/develop/tfgcvit)
|
| 135 |
+
* [keras_cv_attention_models](https://github.com/leondgarse/keras_cv_attention_model)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
## Citation
|
| 139 |
+
```bibtex
|
| 140 |
+
@article{hatamizadeh2022global,
|
| 141 |
+
title={Global Context Vision Transformers},
|
| 142 |
+
author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
|
| 143 |
+
journal={arXiv preprint arXiv:2206.09959},
|
| 144 |
+
year={2022}
|
| 145 |
+
}
|
| 146 |
+
```
|