Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from PIL import Image, ImageDraw | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch | |
| from transformers import Pix2StructProcessor, Pix2StructVisionModel | |
| from utils import download_default_font, render_header | |
class Pix2StructForRegression(nn.Module):
    """Pix2Struct vision encoder topped with a small regression head.

    The head consumes the encoder's hidden state at sequence position 0 and
    produces two sigmoid-squashed values per input.
    """

    def __init__(self, sourcemodel_path, device):
        """Load the encoder and create the head layers.

        Args:
            sourcemodel_path: Passed to
                ``Pix2StructVisionModel.from_pretrained``.
            device: Kept on the instance; used as ``map_location`` when a
                checkpoint is loaded via ``load_state_dict_file``.
        """
        super().__init__()
        self.model = Pix2StructVisionModel.from_pretrained(sourcemodel_path)
        # Attribute names determine the state-dict keys, so they must stay
        # exactly as they are for saved checkpoints to keep loading.
        self.regression_layer1 = nn.Linear(768, 1536)
        self.dropout1 = nn.Dropout(0.1)
        self.regression_layer2 = nn.Linear(1536, 768)
        self.dropout2 = nn.Dropout(0.1)
        self.regression_layer3 = nn.Linear(768, 2)
        self.device = device

    def forward(self, *args, **kwargs):
        """Run the encoder and the regression head.

        All arguments are forwarded unchanged to the wrapped encoder.

        Returns:
            The sigmoid of the last linear layer's output (2 features per
            row of the first-token hidden state).
        """
        encoder_out = self.model(*args, **kwargs)
        # Only the hidden state at sequence index 0 feeds the head.
        first_token = encoder_out.last_hidden_state[:, 0, :]
        # NOTE: dropout1 / dropout2 are registered but not applied here.
        hidden = F.relu(self.regression_layer1(first_token))
        hidden = F.relu(self.regression_layer2(hidden))
        return torch.sigmoid(self.regression_layer3(hidden))

    def load_state_dict_file(self, checkpoint_path, strict=True):
        """Read a state dict from ``checkpoint_path`` and load it into self.

        Args:
            checkpoint_path: File handed to ``torch.load``.
            strict: Forwarded to ``load_state_dict``.
        """
        weights = torch.load(checkpoint_path, map_location=self.device)
        self.load_state_dict(weights, strict=strict)
class Inference:
    """Load the regression model once and serve point predictions on images.

    ``process_image_and_draw_circle`` is the entry point: it prepares the
    image together with the prompt, runs the model, and marks the predicted
    point on the image.
    """

    def __init__(self) -> None:
        """Select the device and load the model and processor."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.processor = self.load_model_and_processor(
            "google/matcha-base", "model/pta-text-v0.1.pt"
        )

    def load_model_and_processor(self, model_name, checkpoint_path):
        """Build the model, load its checkpoint, and create the processor.

        Args:
            model_name: Pretrained identifier used for both the encoder and
                the ``Pix2StructProcessor``.
            checkpoint_path: State-dict file loaded into the model.

        Returns:
            ``(model, processor)`` with the model in eval mode on
            ``self.device``.
        """
        model = Pix2StructForRegression(sourcemodel_path=model_name, device=self.device)
        model.load_state_dict_file(checkpoint_path=checkpoint_path)
        model.eval()
        model = model.to(self.device)
        processor = Pix2StructProcessor.from_pretrained(model_name, is_vqa=False)
        return model, processor

    def prepare_image(self, image, prompt, processor):
        """Resize the image, render the prompt via ``render_header``, encode.

        Args:
            image: PIL image; it is resized to 1920x1080 (the caller's
                object is not modified, ``resize`` returns a new image).
            prompt: Text passed to ``render_header`` as ``header``.
            processor: Callable processor producing the model inputs.

        Returns:
            ``(encoding, render_variables)`` where ``render_variables`` is
            the third value returned by ``render_header``.
        """
        image = image.resize((1920, 1080))
        font_path = download_default_font()
        # The second return value of render_header is not needed here.
        rendered_image, _, render_variables = render_header(
            image=image,
            header=prompt,
            bbox={"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0},
            font_path=font_path,
        )
        encoding = processor(
            images=rendered_image,
            max_patches=2048,
            add_special_tokens=True,
            return_tensors="pt",
        )
        return encoding, render_variables

    def predict_coordinates(self, encoding, model, render_variables):
        """Run the model and rescale its second output.

        Args:
            encoding: Mapping providing ``"flattened_patches"`` and
                ``"attention_mask"``.
            model: Callable accepting those two keyword arguments and
                returning a 2-D tensor with at least two columns.
            render_variables: Mapping with ``"height"``,
                ``"header_height"`` and ``"total_height"``.

        Returns:
            The model output after ``squeeze().tolist()``; for a single
            input this is a two-element list.
        """
        with torch.no_grad():
            pred_regression_outs = model(
                flattened_patches=encoding["flattened_patches"],
                attention_mask=encoding["attention_mask"],
            )
            height = render_variables["height"]
            header_height = render_variables["header_height"]
            total_height = render_variables["total_height"]
            # Column 1 is scaled by the total height, shifted by the header
            # height, then expressed as a fraction of ``height``.
            pred_regression_outs[:, 1] = (
                (pred_regression_outs[:, 1] * total_height) - header_height
            ) / height
            pred_coordinates = pred_regression_outs.squeeze().tolist()
        return pred_coordinates

    def draw_circle_on_image(self, image, coordinates):
        """Draw a filled red dot on ``image`` in place and return it.

        Args:
            image: PIL image to draw on (modified in place).
            coordinates: Sequence whose first two items are multiplied by
                the image width and height to get the dot centre.
        """
        x, y = coordinates[0] * image.width, coordinates[1] * image.height
        draw = ImageDraw.Draw(image)
        radius = 5
        draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill="red")
        return image

    def process_image_and_draw_circle(self, image, prompt):
        """Predict a point for ``prompt`` on ``image`` and mark it.

        Args:
            image: PIL image supplied by the UI, or ``None`` when nothing
                was uploaded.
            prompt: Text prompt for the model.

        Returns:
            The input image with the predicted point drawn on it.

        Raises:
            ValueError: If ``image`` is ``None``.
        """
        # Without this guard a missing upload fails later with an opaque
        # AttributeError on ``None.resize``.
        if image is None:
            raise ValueError("No image provided; upload an image before submitting.")
        encoding, render_variables = self.prepare_image(image, prompt, self.processor)
        pred_coordinates = self.predict_coordinates(
            encoding.to(self.device), self.model, render_variables
        )
        result_image = self.draw_circle_on_image(image, pred_coordinates)
        return result_image
def main():
    """Create the inference backend and serve the Gradio demo on port 7860."""
    inference = Inference()
    # Gradio Interface
    iface = gr.Interface(
        fn=inference.process_image_and_draw_circle,
        inputs=[
            gr.Image(type="pil", label="Upload Image"),
            gr.Textbox(label="Prompt", placeholder="Enter prompt here..."),
        ],
        outputs=gr.Image(type="pil"),
        title="Pix2Struct Image Processing",
        description="Upload an image and enter a prompt to see the model's prediction.",
    )
    # ``launch`` accepts ``server_port``; passing ``port`` raises TypeError
    # at start-up, before the app can serve anything.
    iface.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()