# Spaces: Running
| import gradio as gr | |
| from transformers import ( | |
| PaliGemmaProcessor, | |
| PaliGemmaForConditionalGeneration, | |
| ) | |
| from transformers.image_utils import load_image | |
| import torch | |
| import os | |
| import spaces # Import the spaces module | |
| import requests | |
| from io import BytesIO | |
| from PIL import Image | |
# Holds the (processor, model) pair after the first successful load so the
# multi-gigabyte checkpoint is not re-read on every request.
_LOADED_PALIGEMMA = {}


def load_model():
    """Load the PaliGemma2 processor and model, reusing them across calls.

    The Hugging Face token is read from the ``HUGGINGFACEHUB_API_TOKEN``
    environment variable on every call. The processor and model themselves
    are created only on the first successful call and cached afterwards.

    Returns:
        A ``(processor, model)`` tuple. The model is in eval mode, loaded in
        bfloat16, and moved to CUDA when available, otherwise to CPU.

    Raises:
        ValueError: If the token environment variable is unset or empty.
    """
    token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        raise ValueError(
            "Hugging Face API token not found. Please set it in the environment variables."
        )

    if "pair" not in _LOADED_PALIGEMMA:
        model_id = "google/paligemma2-10b-pt-448"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # `token=` replaces the deprecated `use_auth_token=` keyword.
        processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, token=token
        )
        model = model.to(device).eval()
        _LOADED_PALIGEMMA["pair"] = (processor, model)

    return _LOADED_PALIGEMMA["pair"]
# NOTE(review): an earlier comment here mentioned a timeout increased to 120
# seconds, but nothing in this function configures one -- confirm whether a
# GPU/timeout decorator was dropped.
def process_image_and_text(image_pil, num_beams, temperature, seed):
    """Generate text for an uploaded image with PaliGemma2.

    Args:
        image_pil: The image supplied by the Gradio image input, or ``None``
            when the user submitted without uploading anything.
        num_beams: Number of beams passed to ``model.generate``.
        temperature: Sampling temperature passed to ``model.generate``.
        seed: Seed for ``torch.manual_seed``; ``None`` falls back to 0.

    Returns:
        The decoded text of the newly generated tokens (prompt tokens are
        sliced off before decoding).

    Raises:
        gr.Error: If no image was provided, or if loading the model,
            preprocessing, or generation fails.
    """
    # Checked outside the try block so the user sees this message verbatim
    # instead of a wrapped "GPU task failed" error.
    if image_pil is None:
        raise gr.Error("Please upload an image before submitting.")

    try:
        processor, model = load_model()
        # Use the device the model actually lives on.
        device = model.device

        image = load_image(image_pil)

        # The prompt starts with the <image> placeholder token; no further
        # task text is supplied.
        text_input = "<image> "
        model_inputs = processor(text=text_input, images=image, return_tensors="pt").to(
            device, dtype=torch.bfloat16
        )
        input_len = model_inputs["input_ids"].shape[-1]

        # UI widgets may deliver floats (or None for a cleared number box);
        # coerce to the types torch and generate() expect.
        torch.manual_seed(int(seed) if seed is not None else 0)
        with torch.inference_mode():
            generation = model.generate(
                **model_inputs,
                max_new_tokens=200,
                do_sample=True,
                num_beams=int(num_beams),
                temperature=float(temperature),
            )

        # Keep only the tokens produced after the prompt.
        new_tokens = generation[0][input_len:]
        return processor.decode(new_tokens, skip_special_tokens=True)
    except Exception as e:
        print(f"Error during GPU task: {e}")
        raise gr.Error(f"GPU task failed: {e}") from e
if __name__ == "__main__":
    # Input widgets, listed in the positional order that
    # process_image_and_text takes its arguments.
    image_input = gr.Image(type="pil", label="Upload an image")
    beams_input = gr.Slider(
        minimum=1, maximum=10, step=1, value=10, label="Number of Beams"
    )
    temperature_input = gr.Slider(
        minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature"
    )
    seed_input = gr.Number(label="Random Seed", value=0, precision=0)

    # Single text box that shows the generated output.
    text_output = gr.Textbox(label="Generated Text")

    demo = gr.Interface(
        fn=process_image_and_text,
        inputs=[image_input, beams_input, temperature_input, seed_input],
        outputs=text_output,
        title="PaliGemma2 Image to Text",
        description="Upload an image and the model will generate text.",
    )
    demo.launch()