Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel | |
| from PIL import Image | |
| import gradio as gr | |
| # Model definition and setup | |
class VisionLanguageModel(nn.Module):
    """Late-fusion classifier over a ViT image encoder and a BERT text encoder.

    The pooled output of each encoder is concatenated along the feature axis
    and fed to a single linear layer that emits one logit per class.

    Args:
        vision_model_name: Checkpoint identifier handed to
            ``ViTModel.from_pretrained``.
        language_model_name: Checkpoint identifier handed to
            ``BertModel.from_pretrained``.
        num_classes: Number of output logits. The default of 2 matches the
            benign / malignant labels used by this application.
    """

    def __init__(
        self,
        vision_model_name: str = 'google/vit-base-patch16-224-in21k',
        language_model_name: str = 'bert-base-uncased',
        num_classes: int = 2,
    ):
        super().__init__()
        # The attribute names below determine the state_dict key prefixes, so
        # they must stay as they are for existing checkpoints to load.
        self.vision_model = ViTModel.from_pretrained(vision_model_name)
        self.language_model = BertModel.from_pretrained(language_model_name)

        # The head consumes both pooled vectors side by side.
        fused_width = (
            self.vision_model.config.hidden_size
            + self.language_model.config.hidden_size
        )
        self.classifier = nn.Linear(fused_width, num_classes)

    def forward(self, input_ids, attention_mask, pixel_values):
        """Compute class logits for a batch of (text, image) pairs.

        Args:
            input_ids: Token ids forwarded to the language encoder.
            attention_mask: Attention mask forwarded to the language encoder.
            pixel_values: Image tensor forwarded to the vision encoder.

        Returns:
            The output of the linear head applied to the concatenated pooled
            features, i.e. one row of ``num_classes`` logits per batch item.
        """
        vision_pooled_output = self.vision_model(
            pixel_values=pixel_values
        ).pooler_output
        language_pooled_output = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).pooler_output

        # Vision features first, then text: the order fixes which columns of
        # the classifier weight matrix belong to which modality.
        combined_features = torch.cat(
            (vision_pooled_output, language_pooled_output),
            dim=1,
        )
        return self.classifier(combined_features)
# Build the network, restore the fine-tuned weights onto the CPU, and switch
# to evaluation mode. The loaded state dict is passed straight through rather
# than bound to a name, so it is not kept alive after the weights are loaded.
model = VisionLanguageModel()
model.load_state_dict(
    torch.load(
        'best_model.pth',
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)
model.eval()

# Preprocessors matching the two pretrained encoders used by the model.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
feature_extractor = ViTImageProcessor.from_pretrained(
    'google/vit-base-patch16-224-in21k'
)
def predict(image, text_input):
    """Classify one lesion image together with its clinical note.

    Args:
        image: The lesion photo. The Gradio input in this file supplies a PIL
            image; PIL images that are not already RGB are converted first.
        text_input: Free-text clinical information. ``None`` is treated as an
            empty string.

    Returns:
        int: 1 for malignant, 0 for benign.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        # Fail early with a readable message instead of an opaque error from
        # inside the image processor.
        raise ValueError("An image is required to make a prediction.")

    # NOTE(review): the ViT checkpoint is presumed to expect 3-channel input,
    # so RGBA / grayscale / palette images are normalised to RGB here.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    pixel_values = feature_extractor(
        images=image, return_tensors="pt"
    ).pixel_values

    encoding = tokenizer(
        text_input or "",
        add_special_tokens=True,
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
            pixel_values=pixel_values
        )

    # Index of the larger logit for the single item in the batch.
    return logits.argmax(dim=1).item()  # 1 for Malignant, 0 for Benign
# Enhanced UI with black text
def _result_html(css_class, label, highlighted):
    """Return the HTML card for one class, marked ``correct`` when predicted."""
    suffix = " correct" if highlighted else ""
    return "<div class='{}{}'>{}</div>".format(css_class, suffix, label)


def display_prediction(image, text_input):
    """Run the model and render the two result cards.

    Args:
        image: Value of the image input (``None`` when nothing is uploaded).
        text_input: Value of the clinical-information textbox.

    Returns:
        tuple[str, str]: HTML for the benign card and the malignant card; the
        card matching the prediction carries the ``correct`` CSS class.

    Raises:
        gr.Error: If no image has been uploaded, so the user sees a readable
            message rather than a generic failure.
    """
    if image is None:
        raise gr.Error(
            "Please upload a skin lesion image before requesting a prediction."
        )
    prediction = predict(image, text_input)
    return (
        _result_html("benign", "Benign", prediction == 0),
        _result_html("malignant", "Malignant", prediction == 1),
    )


with gr.Blocks(css="""
body {
    color: black;
}
.benign, .malignant {
    background-color: white;
    border: 1px solid lightgray;
    padding: 10px;
    border-radius: 5px;
    color: black;
}
.benign.correct, .malignant.correct {
    background-color: lightgreen;
    color: black;
}
""") as demo:
    # Written as explicit line strings so the Markdown never picks up source
    # indentation (indented Markdown would render as a code block).
    gr.Markdown(
        "\n"
        "# 🩺 SKIN LESION CLASSIFICATION\n"
        "Upload an image of a skin lesion and provide clinical details to get a prediction of benign or malignant.\n"
    )
    with gr.Row():
        # Left column: user inputs.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Skin Lesion Image")
            text_input = gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)")
        # Right column: both class cards, neither highlighted initially.
        with gr.Column(scale=1):
            gr.Markdown("## PREDICTION RESULTS")
            benign_output = gr.HTML(_result_html("benign", "Benign", False))
            malignant_output = gr.HTML(_result_html("malignant", "Malignant", False))

    # Submit button and prediction outputs
    submit_btn = gr.Button("Get Prediction")
    submit_btn.click(
        display_prediction,
        inputs=[image_input, text_input],
        outputs=[benign_output, malignant_output],
    )

demo.launch()