<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Image captioning

[[open-in-colab]]
Image captioning is the task of predicting a caption for a given image. A common real-world application is aiding visually impaired people, helping them navigate different situations. Image captioning therefore improves content accessibility by describing images to people.
This guide will show you how to:

* Fine-tune an image captioning model.
* Use the fine-tuned model for inference.
Before you begin, make sure you have all the necessary libraries installed:

```bash
pip install transformers datasets evaluate -q
pip install jiwer -q
```
We encourage you to log in to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to log in:

```python
from huggingface_hub import notebook_login

notebook_login()
```
## Load the Pokémon BLIP captions dataset

Use the 🤗 Datasets library to load a dataset that consists of image-caption pairs. To create your own image captioning dataset
in PyTorch, you can follow [this notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/GIT/Fine_tune_GIT_on_an_image_captioning_dataset.ipynb).

```python
from datasets import load_dataset

ds = load_dataset("lambdalabs/pokemon-blip-captions")
ds
```
```bash
DatasetDict({
    train: Dataset({
        features: ['image', 'text'],
        num_rows: 833
    })
})
```
The dataset has two features, `image` and `text`.
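A quick way to see what these features contain is to index a single example; `image` decodes to a PIL image and `text` is its caption string.

```python
sample = ds["train"][0]
print(type(sample["image"]))  # a PIL image object
print(sample["text"])  # the caption string for this image
```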
<Tip>

Many image captioning datasets contain multiple captions per image. In those cases, a common strategy is to randomly sample a caption amongst the available ones during training, as in the sketch below.

</Tip>
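The Pokémon dataset used here has a single caption per image, but as a minimal sketch: if your own dataset stored a hypothetical `captions` column holding a list of strings per image, you could sample one caption per example like this.

```python
import random


def pick_random_caption(example):
    # Hypothetical schema: `captions` is a list of caption strings for one image.
    example["text"] = random.choice(example["captions"])
    return example


# Re-running this map resamples the captions, e.g. once per epoch:
# multi_caption_ds = multi_caption_ds.map(pick_random_caption)
```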
Split the dataset's train split into a train and test set with the [`~datasets.Dataset.train_test_split`] method:

```python
ds = ds["train"].train_test_split(test_size=0.1)
train_ds = ds["train"]
test_ds = ds["test"]
```
Let's visualize a couple of samples from the training set.

```python
from textwrap import wrap

import matplotlib.pyplot as plt
import numpy as np


def plot_images(images, captions):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        caption = captions[i]
        caption = "\n".join(wrap(caption, 12))
        plt.title(caption)
        plt.imshow(images[i])
        plt.axis("off")


sample_images_to_visualize = [np.array(train_ds[i]["image"]) for i in range(5)]
sample_captions = [train_ds[i]["text"] for i in range(5)]
plot_images(sample_images_to_visualize, sample_captions)
```
<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sample_training_images_image_cap.png" alt="Sample training images"/>
</div>
## Preprocess the dataset

Since the dataset has two modalities (image and text), the preprocessing pipeline needs to handle both the images and the captions.

To do so, load the processor class associated with the model you are about to fine-tune.

```python
from transformers import AutoProcessor

checkpoint = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint)
```

The processor will internally preprocess the image (which includes resizing and pixel scaling) and tokenize the caption.
```python
def transforms(example_batch):
    images = [x for x in example_batch["image"]]
    captions = [x for x in example_batch["text"]]
    inputs = processor(images=images, text=captions, padding="max_length")
    inputs.update({"labels": inputs["input_ids"]})
    return inputs


train_ds.set_transform(transforms)
test_ds.set_transform(transforms)
```
With the dataset ready, you can now set up the model for fine-tuning.

## Load a base model

Load the ["microsoft/git-base"](https://huggingface.co/microsoft/git-base) checkpoint into an [`AutoModelForCausalLM`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM) object.

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(checkpoint)
```
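Optionally, you can count the trainable parameters to get a rough sense of the fine-tuning cost:

```python
# Sum the sizes of all parameters that will receive gradients during training.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{num_params / 1e6:.1f}M trainable parameters")
```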
## Evaluate

Image captioning models are typically evaluated with the [ROUGE Score](https://huggingface.co/spaces/evaluate-metric/rouge) or [Word Error Rate](https://huggingface.co/spaces/evaluate-metric/wer) (WER). This guide uses the Word Error Rate, computed with the 🤗 Evaluate library. For potential limitations and other gotchas of WER, refer to [this guide](https://huggingface.co/spaces/evaluate-metric/wer).
```python
from evaluate import load
import torch

wer = load("wer")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predicted = logits.argmax(-1)
    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)
    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)
    wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)
    return {"wer_score": wer_score}
```
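If you haven't used WER before, a quick standalone call shows how it behaves: one substituted word out of a five-word reference gives a score of 0.2 (lower is better).

```python
# Toy example: "pokemon" vs. "pikachu" is one substitution over five reference words.
toy_wer = wer.compute(
    predictions=["a drawing of a pokemon"],
    references=["a drawing of a pikachu"],
)
print(toy_wer)  # 0.2
```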
## Train!

Now, you are ready to start fine-tuning the model. You will use the 🤗 [`Trainer`] for this.

First, define the training arguments using [`TrainingArguments`].

```python
from transformers import TrainingArguments, Trainer

model_name = checkpoint.split("/")[1]

training_args = TrainingArguments(
    output_dir=f"{model_name}-pokemon",
    learning_rate=5e-5,
    num_train_epochs=50,
    fp16=True,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,
    save_total_limit=3,
    evaluation_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    logging_steps=50,
    remove_unused_columns=False,
    push_to_hub=True,
    label_names=["labels"],
    load_best_model_at_end=True,
)
```
Then pass them along with the datasets and the model to 🤗 Trainer.

```python
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)
```
To start training, simply call [`~Trainer.train`] on the [`Trainer`] object.

```python
trainer.train()
```

You should see the training loss drop smoothly as training progresses.

Once training is completed, share your model to the Hub with the [`~Trainer.push_to_hub`] method so everyone can use your model:

```python
trainer.push_to_hub()
```
## Inference

Take a sample image from `test_ds` to test the model.

```python
from PIL import Image
import requests

url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
image = Image.open(requests.get(url, stream=True).raw)
image
```

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/test_image_image_cap.png" alt="Test image"/>
</div>
Prepare the image for the model, making sure the model and the inputs end up on the same device.

```python
device = "cuda" if torch.cuda.is_available() else "cpu"

# Move the model to the same device as the inputs before generating.
model.to(device)

inputs = processor(images=image, return_tensors="pt").to(device)
pixel_values = inputs.pixel_values
```
Call [`generate`] and decode the predictions.

```python
generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)
```

```bash
a drawing of a pink and blue pokemon
```

Looks like the fine-tuned model generated a pretty good caption!
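As an alternative to the manual preprocessing and generation steps above, you can also run inference with the [`pipeline`] API, which wraps the same steps for you; the model id below is a placeholder for the checkpoint you pushed to the Hub.

```python
from transformers import pipeline

# "your-username/git-base-pokemon" is a placeholder; replace it with your own Hub repo id.
captioner = pipeline("image-to-text", model="your-username/git-base-pokemon")
captioner(image)
```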