MReq committed on
Commit a04f1d4 · verified · Parent: 0d5a23f

Upload 2 files

Files changed (2)
  1. app.py +151 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,151 @@
+ import gradio as gr
+ from transformers import ViTImageProcessor, ViTForImageClassification
+ from transformers import AutoImageProcessor, SiglipForImageClassification
+ from PIL import Image
+ import torch
+ from fastapi import FastAPI, UploadFile, File
+ from fastapi.responses import JSONResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ import uvicorn
+ import io
+ import logging
+
+ # ----------------- LOGGER SETUP -----------------
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger("face-analysis")
+
+ # ----------------- LOAD MODELS -----------------
+ # Emotion model: ViT classifier over 7 emotion classes
+ emotion_processor = ViTImageProcessor.from_pretrained("abhilash88/face-emotion-detection")
+ emotion_model = ViTForImageClassification.from_pretrained("abhilash88/face-emotion-detection")
+
+ # Age model: SigLIP classifier over 8 age-group classes
+ age_model_name = "prithivMLmods/facial-age-detection"
+ age_model = SiglipForImageClassification.from_pretrained(age_model_name)
+ age_processor = AutoImageProcessor.from_pretrained(age_model_name)
+
+ # Emotion classes (order must match the emotion model's output indices)
+ emotions = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]
+
+ # Age-group labels, keyed by class index as a string
+ id2label = {
+     "0": "age 01-10",
+     "1": "age 11-20",
+     "2": "age 21-30",
+     "3": "age 31-40",
+     "4": "age 41-55",
+     "5": "age 56-65",
+     "6": "age 66-80",
+     "7": "age 80+"
+ }
+
+ # ----------------- PREDICT FUNCTIONS -----------------
+ def predict_emotion(image: Image.Image):
+     try:
+         inputs = emotion_processor(image.convert("RGB"), return_tensors="pt")
+         with torch.no_grad():
+             outputs = emotion_model(**inputs)
+         probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
+         idx = torch.argmax(probs).item()
+
+         result = {
+             "predicted_emotion": emotions[idx],
+             "confidence": round(probs[idx].item(), 4),
+             "all_confidences": {emotions[i]: float(probs[i]) for i in range(len(emotions))}
+         }
+
+         logger.info(f"Predicted Emotion: {result['predicted_emotion']} (Confidence: {result['confidence']})")
+         return result
+     except Exception as e:
+         logger.error(f"Emotion prediction error: {e}")
+         return {"error": str(e)}
+
+ def predict_age(image: Image.Image):
+     try:
+         inputs = age_processor(images=image.convert("RGB"), return_tensors="pt")
+         with torch.no_grad():
+             outputs = age_model(**inputs)
+         probs = torch.nn.functional.softmax(outputs.logits, dim=1).squeeze().tolist()
+         prediction = {id2label[str(i)]: round(probs[i], 3) for i in range(len(probs))}
+         idx = int(torch.argmax(torch.tensor(probs)))
+
+         result = {
+             "predicted_age": id2label[str(idx)],
+             "confidence": round(probs[idx], 4),
+             "all_confidences": prediction
+         }
+
+         logger.info(f"Predicted Age Group: {result['predicted_age']} (Confidence: {result['confidence']})")
+         return result
+     except Exception as e:
+         logger.error(f"Age prediction error: {e}")
+         return {"error": str(e)}
+
+ # ----------------- FASTAPI APP -----------------
+ app = FastAPI()
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ @app.get("/health")
+ async def health():
+     return {"status": "ok"}
+
+ @app.post("/predict")
+ async def predict(file: UploadFile = File(...)):
+     try:
+         contents = await file.read()
+         image = Image.open(io.BytesIO(contents))
+
+         emotion_result = predict_emotion(image)
+         age_result = predict_age(image)
+
+         logger.info(f"API Response -> Emotion: {emotion_result.get('predicted_emotion')} | Age: {age_result.get('predicted_age')}")
+
+         return JSONResponse(content={
+             "emotion": emotion_result,
+             "age": age_result
+         })
+     except Exception as e:
+         logger.error(f"API Error: {e}")
+         return JSONResponse(content={"error": str(e)}, status_code=500)
+
+ # ----------------- GRADIO DEMO -----------------
+ def gradio_wrapper(image):
+     emotion_result = predict_emotion(image)
+     age_result = predict_age(image)
+
+     if "error" in emotion_result or "error" in age_result:
+         return "Error", {}, "Error", {}
+
+     return (
+         f"{emotion_result['predicted_emotion']} ({emotion_result['confidence']:.2f})",
+         emotion_result["all_confidences"],
+         f"{age_result['predicted_age']} ({age_result['confidence']:.2f})",
+         age_result["all_confidences"]
+     )
+
+ demo = gr.Interface(
+     fn=gradio_wrapper,
+     inputs=gr.Image(type="pil"),
+     outputs=[
+         gr.Label(num_top_classes=1, label="Top Emotion"),
+         gr.Label(label="Emotion Probabilities"),
+         gr.Label(num_top_classes=1, label="Top Age Group"),
+         gr.Label(label="Age Probabilities"),
+     ],
+     title="Face Emotion + Age Detection",
+     description="Upload a face image and detect both emotion (Angry, Happy, etc.) and estimated age group (01–10, 11–20, ... 80+)."
+ )
+
+ # Mount the Gradio demo at /gradio on the same FastAPI app
+ app = gr.mount_gradio_app(app, demo, path="/gradio")
+
+ # ----------------- RUN -----------------
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=7860)
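
For quick verification, a minimal client-side sketch (not part of this commit): it assumes the `requests` package is installed, the server from app.py is running locally, and a local test image exists; "face.jpg" is a placeholder path.

import requests

BASE = "http://localhost:7860"  # host/port from uvicorn.run() in app.py

# Liveness probe against the /health route defined in app.py
print(requests.get(f"{BASE}/health").json())  # -> {"status": "ok"}

# POST an image to the combined emotion + age endpoint; the multipart
# field name "file" must match the UploadFile parameter in app.py
with open("face.jpg", "rb") as f:
    resp = requests.post(f"{BASE}/predict", files={"file": f})
print(resp.json())  # -> {"emotion": {...}, "age": {...}}
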
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ transformers
+ gradio
+ pillow
+ fastapi
+ uvicorn
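
Note: fastapi and uvicorn are imported directly by app.py, so they are pinned here explicitly rather than relied on as transitive dependencies of gradio. With the dependencies installed (pip install -r requirements.txt), running python app.py starts uvicorn on port 7860; the JSON API is served at /predict and /health, and the Gradio demo is mounted at /gradio.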