import gradio as gr
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import pickle

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f'Running on device: {DEVICE.upper()}')

# Load the fine-tuned classifier weights once and reuse them below
checkpoint = torch.load('resnetinceptionv1_final.pth', map_location='cpu')
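# Face detector: with select_largest=False MTCNN keeps the highest-probability face,
# and post_process=False returns the crop with raw 0-255 pixel values (normalized later in predict()).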
mtcnn = MTCNN(
    select_largest=False,
    post_process=False,
    device=DEVICE
).to(DEVICE).eval()

# Deepfake classifier: InceptionResnetV1 with a single output logit (fake probability)
model = InceptionResnetV1(
    pretrained="vggface2",
    classify=True,
    num_classes=1,
    device=DEVICE
)
model.load_state_dict(checkpoint)
model.to(DEVICE)
model.eval()

print("MTCNN & classifier models loaded")
# Load the pickled list of example images
with open('file_examples.pkl', 'rb') as file:
    examples = pickle.load(file)

#EXAMPLES_FOLDER = 'examples'
#examples_names = os.listdir(EXAMPLES_FOLDER)
#examples = []
#for example_name in examples_names:
#    example_path = os.path.join(EXAMPLES_FOLDER, example_name)
#    label = example_name.split('_')[0]
#    example = {
#        'path': example_path,
#        'label': label
#    }
#    examples.append(example)
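# (Assumption: each entry in file_examples.pkl is a dict with 'path' and 'label' keys,
# mirroring the commented-out builder above; the sanity-check loop below relies on 'label'.)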
def predict(input_image: Image.Image):
    """Predict the label of the input_image"""
    face = mtcnn(input_image)
    if face is None:
        raise Exception('No face detected')
    face = face.unsqueeze(0)  # add the batch dimension
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)

    # convert the face into a numpy array to be able to plot it
    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()

    face = face.to(DEVICE)
    face = face.to(torch.float32)
    face = face / 255.0

    with torch.no_grad():
        output = torch.sigmoid(model(face).squeeze(0))
        prediction = "real" if output.item() < 0.5 else "fake"

        real_prediction = 1 - output.item()
        fake_prediction = output.item()

        confidences = {
            'real': real_prediction,
            'fake': fake_prediction
        }
    return confidences, face_image_to_plot
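# Quick sanity check: run the classifier a few times on a local frame and compare against the stored label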
for i in range(10):
    example = examples[8]
    #example_img = example['path']
    example_img = 'fake_frame_0.jpg'
    example_label = example['label']
    print(f"True label: {example_label}")

    example_img = Image.open(example_img)
    confidences, _ = predict(example_img)
    if confidences['real'] > 0.5:
        print("Predicted label: real")
    else:
        print("Predicted label: fake")
    print()
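# Gradio UI: upload an image, get real/fake confidences plus the detected face crop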
title = 'Fake or not Fake? that is the question'
description = 'Deep learning model to detect fake (generated) images'
article = 'Saturdays.AI project, DemoDay 11/06/2022'

# Note: gr.inputs / gr.outputs and theme='peach' follow the older Gradio API; newer releases use gr.Image / gr.Label
interface = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(label="Input Image", type="pil"),
    outputs=[
        gr.outputs.Label(label="Class"),
        gr.outputs.Image(label="Face")
    ],
    title=title, description=description, article=article,
    theme='peach',
    #examples=[examples[i]["path"] for i in range(8)]  # fake examples
    examples=[
        'fake_frame_0.jpg', 'fake_frame_1.jpg', 'fake_frame_2.jpg', 'fake_frame_3.jpg',
        'real_frame_0.jpg', 'real_frame_1.jpg', 'real_frame_2.jpg', 'real_frame_3.jpg'
    ]
).launch()