MReq commited on
Commit
dc0c3ff
·
verified ·
1 Parent(s): c5f3db9

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -151
app.py DELETED
@@ -1,151 +0,0 @@
1
- import gradio as gr
2
- from transformers import ViTImageProcessor, ViTForImageClassification
3
- from transformers import AutoImageProcessor, SiglipForImageClassification
4
- from PIL import Image
5
- import torch
6
- from fastapi import FastAPI, UploadFile, File
7
- from fastapi.responses import JSONResponse
8
- from fastapi.middleware.cors import CORSMiddleware
9
- import uvicorn
10
- import io
11
- import logging
12
-
13
- # ----------------- LOGGER SETUP -----------------
14
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
15
- logger = logging.getLogger("face-analysis")
16
-
17
- # ----------------- LOAD MODELS -----------------
18
- # Emotion model
19
- emotion_processor = ViTImageProcessor.from_pretrained("abhilash88/face-emotion-detection")
20
- emotion_model = ViTForImageClassification.from_pretrained("abhilash88/face-emotion-detection")
21
-
22
- # Age model
23
- age_model_name = "prithivMLmods/facial-age-detection"
24
- age_model = SiglipForImageClassification.from_pretrained(age_model_name)
25
- age_processor = AutoImageProcessor.from_pretrained(age_model_name)
26
-
27
- # Emotion classes
28
- emotions = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]
29
-
30
- # Age labels
31
- id2label = {
32
- "0": "age 01-10",
33
- "1": "age 11-20",
34
- "2": "age 21-30",
35
- "3": "age 31-40",
36
- "4": "age 41-55",
37
- "5": "age 56-65",
38
- "6": "age 66-80",
39
- "7": "age 80+"
40
- }
41
-
42
- # ----------------- PREDICT FUNCTIONS -----------------
43
- def predict_emotion(image: Image.Image):
44
- try:
45
- inputs = emotion_processor(image.convert("RGB"), return_tensors="pt")
46
- with torch.no_grad():
47
- outputs = emotion_model(**inputs)
48
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
49
- idx = torch.argmax(probs).item()
50
-
51
- result = {
52
- "predicted_emotion": emotions[idx],
53
- "confidence": round(probs[idx].item(), 4),
54
- "all_confidences": {emotions[i]: float(probs[i]) for i in range(len(emotions))}
55
- }
56
-
57
- logger.info(f"Predicted Emotion: {result['predicted_emotion']} (Confidence: {result['confidence']})")
58
- return result
59
- except Exception as e:
60
- logger.error(f"Emotion prediction error: {e}")
61
- return {"error": str(e)}
62
-
63
- def predict_age(image: Image.Image):
64
- try:
65
- inputs = age_processor(images=image.convert("RGB"), return_tensors="pt")
66
- with torch.no_grad():
67
- outputs = age_model(**inputs)
68
- probs = torch.nn.functional.softmax(outputs.logits, dim=1).squeeze().tolist()
69
- prediction = {id2label[str(i)]: round(probs[i], 3) for i in range(len(probs))}
70
- idx = int(torch.argmax(torch.tensor(probs)))
71
-
72
- result = {
73
- "predicted_age": id2label[str(idx)],
74
- "confidence": round(probs[idx], 4),
75
- "all_confidences": prediction
76
- }
77
-
78
- logger.info(f"Predicted Age Group: {result['predicted_age']} (Confidence: {result['confidence']})")
79
- return result
80
- except Exception as e:
81
- logger.error(f"Age prediction error: {e}")
82
- return {"error": str(e)}
83
-
84
- # ----------------- FASTAPI APP -----------------
85
- app = FastAPI()
86
-
87
- app.add_middleware(
88
- CORSMiddleware,
89
- allow_origins=["*"],
90
- allow_credentials=True,
91
- allow_methods=["*"],
92
- allow_headers=["*"],
93
- )
94
-
95
- @app.get("/health")
96
- async def health():
97
- return {"status": "ok"}
98
-
99
- @app.post("/predict")
100
- async def predict(file: UploadFile = File(...)):
101
- try:
102
- contents = await file.read()
103
- image = Image.open(io.BytesIO(contents))
104
-
105
- emotion_result = predict_emotion(image)
106
- age_result = predict_age(image)
107
-
108
- logger.info(f"API Response -> Emotion: {emotion_result.get('predicted_emotion')} | Age: {age_result.get('predicted_age')}")
109
-
110
- return JSONResponse(content={
111
- "emotion": emotion_result,
112
- "age": age_result
113
- })
114
- except Exception as e:
115
- logger.error(f"API Error: {e}")
116
- return JSONResponse(content={"error": str(e)}, status_code=500)
117
-
118
- # ----------------- GRADIO DEMO -----------------
119
- def gradio_wrapper(image):
120
- emotion_result = predict_emotion(image)
121
- age_result = predict_age(image)
122
-
123
- if "error" in emotion_result or "error" in age_result:
124
- return "Error", {}, "Error", {}
125
-
126
- return (
127
- f"{emotion_result['predicted_emotion']} ({emotion_result['confidence']:.2f})",
128
- emotion_result["all_confidences"],
129
- f"{age_result['predicted_age']} ({age_result['confidence']:.2f})",
130
- age_result["all_confidences"]
131
- )
132
-
133
- demo = gr.Interface(
134
- fn=gradio_wrapper,
135
- inputs=gr.Image(type="pil"),
136
- outputs=[
137
- gr.Label(num_top_classes=1, label="Top Emotion"),
138
- gr.Label(label="Emotion Probabilities"),
139
- gr.Label(num_top_classes=1, label="Top Age Group"),
140
- gr.Label(label="Age Probabilities"),
141
- ],
142
- title="Face Emotion + Age Detection",
143
- description="Upload a face image and detect both emotion (Angry, Happy, etc.) and estimated age group (01–10, 11–20, ... 80+)."
144
- )
145
-
146
- # Mount Gradio at /gradio
147
- app = gr.mount_gradio_app(app, demo, path="/gradio")
148
-
149
- # ----------------- RUN -----------------
150
- if __name__ == "__main__":
151
- uvicorn.run(app, host="0.0.0.0", port=7860)