import io
import logging

import gradio as gr
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import AutoImageProcessor, SiglipForImageClassification
from transformers import ViTForImageClassification, ViTImageProcessor

# ----------------- LOGGER SETUP -----------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("face-analysis")

# ----------------- LOAD MODELS -----------------
# Both models are loaded once at import time and shared by every request.

# Emotion model
emotion_processor = ViTImageProcessor.from_pretrained("abhilash88/face-emotion-detection")
emotion_model = ViTForImageClassification.from_pretrained("abhilash88/face-emotion-detection")

# Age model
age_model_name = "prithivMLmods/facial-age-detection"
age_model = SiglipForImageClassification.from_pretrained(age_model_name)
age_processor = AutoImageProcessor.from_pretrained(age_model_name)

# Emotion classes, indexed by the emotion model's output position.
# NOTE(review): assumed to match the checkpoint's own label order — confirm
# against the model config.
emotions = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]

# Age labels, keyed by the stringified output index of the age model.
id2label = {
    "0": "age 01-10",
    "1": "age 11-20",
    "2": "age 21-30",
    "3": "age 31-40",
    "4": "age 41-55",
    "5": "age 56-65",
    "6": "age 66-80",
    "7": "age 80+",
}


# ----------------- PREDICT FUNCTIONS -----------------
def predict_emotion(image: Image.Image) -> dict:
    """Classify the emotion shown in *image*.

    Args:
        image: PIL image; it is converted to RGB before preprocessing.

    Returns:
        On success, a dict with ``predicted_emotion`` (label of the most
        probable class), ``confidence`` (its probability rounded to four
        decimals) and ``all_confidences`` (label -> probability).
        On any failure, ``{"error": <message>}``; exceptions are logged and
        never propagated to the caller.
    """
    try:
        inputs = emotion_processor(image.convert("RGB"), return_tensors="pt")
        with torch.no_grad():
            outputs = emotion_model(**inputs)

        # Softmax over the class axis; [0] selects the single image in the batch.
        probs_tensor = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
        idx = int(torch.argmax(probs_tensor))
        probs = probs_tensor.tolist()

        result = {
            "predicted_emotion": emotions[idx],
            "confidence": round(probs[idx], 4),
            "all_confidences": dict(zip(emotions, probs)),
        }
        logger.info(
            "Predicted Emotion: %s (Confidence: %s)",
            result["predicted_emotion"],
            result["confidence"],
        )
        return result
    except Exception as e:
        # Best-effort contract: report the failure in-band, but keep the
        # traceback in the log so the root cause is not lost.
        logger.exception("Emotion prediction error: %s", e)
        return {"error": str(e)}


def predict_age(image: Image.Image) -> dict:
    """Classify the age group of the face in *image*.

    Args:
        image: PIL image; it is converted to RGB before preprocessing.

    Returns:
        On success, a dict with ``predicted_age`` (label of the most probable
        group), ``confidence`` (its probability rounded to four decimals) and
        ``all_confidences`` (label -> probability rounded to three decimals).
        On any failure, ``{"error": <message>}``; exceptions are logged and
        never propagated to the caller.
    """
    try:
        inputs = age_processor(images=image.convert("RGB"), return_tensors="pt")
        with torch.no_grad():
            outputs = age_model(**inputs)

        probs_tensor = torch.nn.functional.softmax(outputs.logits, dim=1).squeeze()
        # Take the argmax on the tensor directly instead of converting to a
        # list and rebuilding a tensor from it.
        idx = int(torch.argmax(probs_tensor))
        probs = probs_tensor.tolist()

        prediction = {id2label[str(i)]: round(p, 3) for i, p in enumerate(probs)}
        result = {
            "predicted_age": id2label[str(idx)],
            "confidence": round(probs[idx], 4),
            "all_confidences": prediction,
        }
        logger.info(
            "Predicted Age Group: %s (Confidence: %s)",
            result["predicted_age"],
            result["confidence"],
        )
        return result
    except Exception as e:
        # Best-effort contract: report the failure in-band, but keep the
        # traceback in the log.
        logger.exception("Age prediction error: %s", e)
        return {"error": str(e)}


# ----------------- FASTAPI APP -----------------
app = FastAPI()

# NOTE(review): wildcard origins together with allow_credentials=True is a
# combination the FastAPI CORS documentation advises against. Left unchanged to
# avoid breaking existing clients; list explicit origins if credentialed
# requests are really needed, otherwise drop allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
async def health():
    """Liveness probe; always answers ``{"status": "ok"}``."""
    return {"status": "ok"}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run emotion and age prediction on an uploaded image.

    Returns:
        200 with ``{"emotion": ..., "age": ...}``, where each value is the
        corresponding predictor's result dict (which may itself be an
        ``{"error": ...}`` dict if that predictor failed);
        400 with ``{"error": ...}`` when the upload cannot be decoded as an image;
        500 with ``{"error": ...}`` on any other unexpected failure.
    """
    # Function-scope import so this handler does not depend on edits to the
    # file's top-level import block.
    from fastapi.concurrency import run_in_threadpool

    try:
        contents = await file.read()

        try:
            image = Image.open(io.BytesIO(contents))
            # Image.open is lazy; load() forces decoding so that corrupt or
            # truncated uploads are rejected here rather than inside each
            # predictor.
            image.load()
        except OSError as e:
            # A bad upload is a client error, not a server fault.
            logger.warning("Rejected upload that is not a readable image: %s", e)
            return JSONResponse(content={"error": str(e)}, status_code=400)

        # Model inference is blocking and CPU-bound; running it in the thread
        # pool keeps the event loop free for other requests such as /health.
        emotion_result = await run_in_threadpool(predict_emotion, image)
        age_result = await run_in_threadpool(predict_age, image)

        logger.info(
            "API Response -> Emotion: %s | Age: %s",
            emotion_result.get("predicted_emotion"),
            age_result.get("predicted_age"),
        )
        return JSONResponse(content={"emotion": emotion_result, "age": age_result})
    except Exception as e:
        logger.exception("API Error: %s", e)
        return JSONResponse(content={"error": str(e)}, status_code=500)


# ----------------- GRADIO DEMO -----------------
def gradio_wrapper(image):
    """Adapt both predictors to the four outputs of the Gradio interface.

    Returns a 4-tuple: top-emotion text, emotion probabilities, top-age text,
    age probabilities. If no image is supplied or either predictor reports an
    error, returns ``("Error", {}, "Error", {})``.
    """
    if image is None:
        # Nothing to analyse; skip both models instead of letting each one fail.
        return "Error", {}, "Error", {}

    emotion_result = predict_emotion(image)
    age_result = predict_age(image)

    if "error" in emotion_result or "error" in age_result:
        return "Error", {}, "Error", {}

    return (
        f"{emotion_result['predicted_emotion']} ({emotion_result['confidence']:.2f})",
        emotion_result["all_confidences"],
        f"{age_result['predicted_age']} ({age_result['confidence']:.2f})",
        age_result["all_confidences"],
    )


demo = gr.Interface(
    fn=gradio_wrapper,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=1, label="Top Emotion"),
        gr.Label(label="Emotion Probabilities"),
        gr.Label(num_top_classes=1, label="Top Age Group"),
        gr.Label(label="Age Probabilities"),
    ],
    title="Face Emotion + Age Detection",
    description="Upload a face image and detect both emotion (Angry, Happy, etc.) and estimated age group (01–10, 11–20, ... 80+).",
)

# Mount Gradio at /gradio
app = gr.mount_gradio_app(app, demo, path="/gradio")

# ----------------- RUN -----------------
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)