Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import torch.nn as nn | |
| import pandas as pd | |
| from PIL import Image | |
| from torchvision import transforms | |
| from transformers import BertTokenizer, AutoModel | |
| from torch.utils.data import Dataset, DataLoader, random_split | |
| from sklearn.model_selection import train_test_split | |
| from typing import List | |
| from dataclasses import dataclass | |
| import gradio as gr | |
| import torch, re | |
| import numpy as np | |
| from transformers import WhisperProcessor, WhisperForConditionalGeneration, ViTImageProcessor, BertTokenizer, BlipProcessor, BlipForQuestionAnswering, AutoProcessor, AutoModelForCausalLM, DonutProcessor, VisionEncoderDecoderModel, Pix2StructProcessor, Pix2StructForConditionalGeneration, AutoModelForSeq2SeqLM | |
| import librosa | |
| from PIL import Image | |
| from torch.nn.utils import rnn | |
| from gtts import gTTS | |
# Prefer CUDA whenever PyTorch can see a GPU; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class LabelClassifier(nn.Module):
    """Six-way classifier over a (question text, image) pair.

    The pooled BERT ('bert-base-uncased') text embedding and the pooled Swin
    ('microsoft/swin-tiny-patch4-window7-224') image embedding are concatenated,
    passed through a 128-unit ReLU + Dropout(0.5) fusion layer, and projected
    to 6 logits. The attribute names below (text_encoder, image_encoder, fusion,
    classifier) are the state-dict keys, so they must match the saved checkpoint.
    """

    def __init__(self):
        super().__init__()
        self.text_encoder = AutoModel.from_pretrained('bert-base-uncased')
        self.image_encoder = AutoModel.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
        self.intermediate_dim = 128
        # Width of the concatenated [text_pooled, image_pooled] feature vector.
        joint_dim = (
            self.text_encoder.config.hidden_size
            + self.image_encoder.config.hidden_size
        )
        self.fusion = nn.Sequential(
            nn.Linear(joint_dim, self.intermediate_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        # Maps the fused 128-d representation to the 6 category logits.
        self.classifier = nn.Linear(self.intermediate_dim, 6)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self,
                input_ids: torch.LongTensor,
                pixel_values: torch.FloatTensor,
                attention_mask: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None,
                labels: torch.LongTensor = None):
        """Score a batch of question/image pairs.

        Returns a dict with "logits" and, when ``labels`` is given, also the
        cross-entropy "loss". Inputs are assumed batch-first (text features
        and image features are concatenated along dim 1) -- TODO confirm.
        """
        text_pooled = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )['pooler_output']
        image_pooled = self.image_encoder(pixel_values=pixel_values)['pooler_output']
        joint = torch.cat((text_pooled, image_pooled), dim=1)
        logits = self.classifier(self.fusion(joint))
        result = {"logits": logits}
        if labels is not None:
            result["loss"] = self.criterion(logits, labels)
        return result
# Category classifier used by predict_category; its 6 outputs select the
# answering back-end in predict_answer.
model = LabelClassifier().to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only host.
model.load_state_dict(torch.load('classifier.pth', map_location=torch.device('cpu')))
# Inference only: eval() disables the Dropout(0.5) in the fusion head, which
# would otherwise randomly zero features and make the predicted category
# nondeterministic from one request to the next.
model.eval()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
processor = ViTImageProcessor.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
| # Load the Whisper model in Hugging Face format: | |
| # processor2 = WhisperProcessor.from_pretrained("openai/whisper-medium.en") | |
| # model2 = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium.en") | |
def m1(que, image):
    """Answer a visual question with BLIP VQA (Salesforce/blip-vqa-capfilt-large).

    The processor and model are loaded from the hub/cache on every call.
    Returns the decoded answer string for the first generated sequence.
    """
    checkpoint = "Salesforce/blip-vqa-capfilt-large"
    blip_processor = BlipProcessor.from_pretrained(checkpoint)
    blip_model = BlipForQuestionAnswering.from_pretrained(checkpoint)
    encoded = blip_processor(image, que, return_tensors="pt")
    answer_ids = blip_model.generate(**encoded)
    return blip_processor.decode(answer_ids[0], skip_special_tokens=True)
def m2(que, image):
    """Answer a question about text in the image with GIT (microsoft/git-large-textvqa).

    The question is fed as a decoder prompt prefixed with [CLS]. generate()
    returns the prompt followed by the newly generated tokens, so the prompt
    is removed by token count before decoding. Splitting the decoded text on
    '?' would leak the question into the answer whenever the question has no
    '?', and would cut the question short when it contains more than one.

    Returns the decoded, stripped answer string (possibly empty).
    """
    processor3 = AutoProcessor.from_pretrained("microsoft/git-large-textvqa")
    model3 = AutoModelForCausalLM.from_pretrained("microsoft/git-large-textvqa")
    pixel_values = processor3(images=image, return_tensors="pt").pixel_values
    question_ids = processor3(text=que, add_special_tokens=False).input_ids
    input_ids = torch.tensor([processor3.tokenizer.cls_token_id] + question_ids).unsqueeze(0)
    generated_ids = model3.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
    # Keep only the tokens produced after the prompt.
    answer_ids = generated_ids[:, input_ids.shape[1]:]
    return processor3.batch_decode(answer_ids, skip_special_tokens=True)[0].strip()
def m3(que, image):
    """Answer a document question with Pix2Struct (google/pix2struct-docvqa-large).

    The model and processor are loaded on every call. Returns the decoded
    answer for the first generated sequence.
    """
    repo_id = "google/pix2struct-docvqa-large"
    doc_model = Pix2StructForConditionalGeneration.from_pretrained(repo_id)
    doc_processor = Pix2StructProcessor.from_pretrained(repo_id)
    batch = doc_processor(images=image, text=que, return_tensors="pt")
    output_ids = doc_model.generate(**batch)
    return doc_processor.decode(output_ids[0], skip_special_tokens=True)
def m4(que, image):
    """Answer a plot question with MatCha (google/matcha-plotqa-v1).

    The processor and model are loaded on every call. Generation is capped at
    512 new tokens. Returns the decoded answer for the first sequence.
    """
    repo_id = 'google/matcha-plotqa-v1'
    plot_processor = Pix2StructProcessor.from_pretrained(repo_id)
    plot_model = Pix2StructForConditionalGeneration.from_pretrained(repo_id)
    model_inputs = plot_processor(images=image, text=que, return_tensors="pt")
    token_ids = plot_model.generate(**model_inputs, max_new_tokens=512)
    return plot_processor.decode(token_ids[0], skip_special_tokens=True)
def m5(que, image):
    """Answer a question with Pix2Struct fine-tuned on OCR-VQA (google/pix2struct-ocrvqa-large).

    The model and processor are loaded on every call. Returns the decoded
    answer for the first generated sequence.
    """
    repo_id = "google/pix2struct-ocrvqa-large"
    ocr_model = Pix2StructForConditionalGeneration.from_pretrained(repo_id)
    ocr_processor = Pix2StructProcessor.from_pretrained(repo_id)
    batch = ocr_processor(images=image, text=que, return_tensors="pt")
    output_ids = ocr_model.generate(**batch)
    return ocr_processor.decode(output_ids[0], skip_special_tokens=True)
def m6(que, image):
    """Answer with the same MatCha PlotQA checkpoint that m4 uses.

    predict_answer routes every category outside 0-4 here. The original body
    was a line-for-line copy of m4 (same checkpoint, same generate arguments),
    so it delegates instead of duplicating the code.
    """
    # NOTE: a dedicated checkpoint (e.g. google/pix2struct-infographics-vqa-large)
    # can be swapped in here if this category needs different handling.
    return m4(que, image)
def predict_answer(category, que, image):
    """Send the question to the answering function selected by ``category``.

    0 -> m1, 1 -> m2, 2 -> m3, 3 -> m4, 4 -> m5; any other value falls
    through to m6. Returns whatever the chosen function returns.
    """
    routes = ((0, m1), (1, m2), (2, m3), (3, m4), (4, m5))
    for index, handler in routes:
        if category == index:
            return handler(que, image)
    return m6(que, image)
def _to_mono_float32(audio_data):
    """Return ``audio_data`` as a 1-D float32 waveform scaled to roughly [-1, 1].

    Signed integer PCM is divided by its full-scale magnitude (32768 for int16).
    Unsigned PCM is re-centred around zero first. Floating-point input is only
    cast. A 2-D buffer is averaged over axis 1 to make it mono; this assumes
    the (samples, channels) layout that Gradio's numpy audio uses. Confirm the
    layout for any other caller.
    """
    data = np.asarray(audio_data)
    if np.issubdtype(data.dtype, np.signedinteger):
        full_scale = -float(np.iinfo(data.dtype).min)
        data = data.astype(np.float32) / full_scale
    elif np.issubdtype(data.dtype, np.unsignedinteger):
        # Unsigned PCM (e.g. 8-bit WAV) stores silence at mid-scale.
        mid = (float(np.iinfo(data.dtype).max) + 1.0) / 2.0
        data = (data.astype(np.float32) - mid) / mid
    else:
        data = data.astype(np.float32)
    if data.ndim == 2:
        data = data.mean(axis=1)
    return data.astype(np.float32, copy=False)


def transcribe_audio(audio):
    """Transcribe a recorded question with Whisper (openai/whisper-large-v3).

    Args:
        audio: ``(sampling_rate, samples)`` pair, e.g. from a Gradio microphone.

    Returns:
        The decoded transcription of the first (only) sequence.
    """
    # NOTE(review): the Whisper checkpoint is reloaded on every call; caching it
    # would cut latency but keep it resident in memory.
    processor2 = WhisperProcessor.from_pretrained("openai/whisper-large-v3",language='en')
    model2 = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
    sampling_rate = audio[0]
    audio_data = audio[1]
    # Whisper's feature extractor expects a float waveform in [-1, 1].
    # Casting raw int16 samples (+/-32768) without rescaling skews the log-mel
    # features, and a stereo buffer must be downmixed before resampling.
    waveform = _to_mono_float32(audio_data)
    resampled_audio_data = librosa.resample(waveform, orig_sr=sampling_rate, target_sr=16000)
    input_features = processor2(
        resampled_audio_data, sampling_rate=16000, return_tensors="pt"
    ).input_features
    predicted_ids = model2.generate(input_features)
    return processor2.batch_decode(predicted_ids, skip_special_tokens=True)[0]
def predict_category(que, input_image):
    """Predict which answering back-end should handle a question/image pair.

    The question is tokenized with the BERT tokenizer (truncated to 24 tokens).
    The image is preprocessed with the Swin image processor, and both are
    scored by the module-level ``model``.

    Returns:
        int: argmax index over the classifier's 6 logits (0-5) for the single
        input pair.
    """
    encoded_text = tokenizer(
        text=que,
        padding='longest',
        max_length=24,
        truncation=True,
        return_tensors='pt',
        return_token_type_ids=True,
        return_attention_mask=True,
    )
    pixel_values = processor(input_image, return_tensors='pt')['pixel_values'].to(device)
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        output = model(
            input_ids=encoded_text['input_ids'].to(device),
            token_type_ids=encoded_text['token_type_ids'].to(device),
            attention_mask=encoded_text['attention_mask'].to(device),
            pixel_values=pixel_values,
        )
    return output["logits"].argmax(dim=-1)[0].item()
def combine(audio, input_image, text_question=""):
    """Gradio callback: answer a spoken or typed question about an image.

    Args:
        audio: microphone recording; when truthy it is transcribed and takes
            priority over ``text_question``.
        input_image: uploaded image as an array, or None when nothing was uploaded.
        text_question: typed question, used when no audio was recorded.

    Returns:
        (question, answer, path to the spoken answer or None, category index).

    Raises:
        gr.Error: when no image is uploaded or the question is empty, so the
            user sees a message instead of a stack trace.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    que = transcribe_audio(audio) if audio else (text_question or "")
    if not que.strip():
        raise gr.Error("Please record or type a question.")
    image = Image.fromarray(input_image).convert('RGB')
    category = predict_category(que, image)
    answer = predict_answer(category, que, image)
    audio_path = None
    # gTTS raises on empty text, so only synthesize speech when there is an answer.
    if answer.strip():
        gTTS(answer).save('answer.mp3')
        audio_path = 'answer.mp3'
    return que, answer, audio_path, category
# Gradio UI: record (or type) a question about an uploaded image. Outputs are
# the question text, the answer, a spoken version of it, and the category.
model_interface = gr.Interface(fn=combine,
                               inputs=[gr.Microphone(label="Ask your question"),
                                       gr.Image(label="Upload the image"),
                                       gr.Textbox(label="Text Question")],
                               outputs=[gr.Text(label="Transcribed Question"),
                                        gr.Text(label="Answer"),
                                        gr.Audio(label="Audio Answer"),
                                        gr.Text(label="Category")])

if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # does not launch (and, with debug=True, block on) the web UI.
    model_interface.launch(debug=True)