Spaces: Runtime error

from fastapi import FastAPI, File, UploadFile, HTTPException
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
import librosa
import torch
import numpy as np
import tempfile
import os

app = FastAPI(title="Speech Emotion Recognition API")

# Global variables for model caching
model = None
feature_extractor = None
id2label = None


def load_model():
    """Load the model once and reuse it on later calls (CPU optimization)."""
    global model, feature_extractor, id2label
    if model is not None:
        # Already loaded: return the cached objects instead of reloading
        return model, feature_extractor, id2label

    model_id = "firdhokk/speech-emotion-recognition-with-openai-whisper-large-v3"

    # Force CPU usage for the free tier
    torch.set_num_threads(2)  # Keep the thread count low on the shared free CPU

    model = AutoModelForAudioClassification.from_pretrained(
        model_id,
        torch_dtype=torch.float32,  # Use float32 for CPU
        device_map="cpu",
    )
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_id,
        do_normalize=True,
    )
    id2label = model.config.id2label
    return model, feature_extractor, id2label


def preprocess_audio(audio_path, feature_extractor, max_duration=30.0):
    """Load audio, resample it, and pad/truncate it to a fixed length."""
    audio_array, sampling_rate = librosa.load(
        audio_path,
        sr=feature_extractor.sampling_rate,
        duration=max_duration,  # Limit duration for CPU efficiency
    )

    # Truncate long clips and zero-pad short ones to exactly max_duration seconds
    max_length = int(feature_extractor.sampling_rate * max_duration)
    if len(audio_array) > max_length:
        audio_array = audio_array[:max_length]
    else:
        audio_array = np.pad(audio_array, (0, max_length - len(audio_array)))

    inputs = feature_extractor(
        audio_array,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    return inputs


@app.on_event("startup")
async def startup_event():
    """Load the model on startup so the first request is not slow."""
    load_model()


@app.post("/predict-emotion")
async def predict_emotion(file: UploadFile = File(...)):
    """Predict the emotion expressed in an uploaded audio file."""
    try:
        # Validate the file type by extension
        if not file.filename.lower().endswith(('.wav', '.mp3', '.m4a', '.flac')):
            raise HTTPException(status_code=400, detail="Unsupported audio format")

        # Save the upload to a temporary file, keeping its original extension
        suffix = os.path.splitext(file.filename)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            content = await file.read()
            tmp_file.write(content)
            tmp_file_path = tmp_file.name

        try:
            # Fetch the cached model (loaded at startup)
            model, feature_extractor, id2label = load_model()

            # Preprocess and predict
            inputs = preprocess_audio(tmp_file_path, feature_extractor)
            with torch.no_grad():
                outputs = model(**inputs)
                logits = outputs.logits

            predicted_id = torch.argmax(logits, dim=-1).item()
            predicted_label = id2label[predicted_id]

            # Confidence scores for every class
            probabilities = torch.softmax(logits, dim=-1)
            confidence = probabilities[0][predicted_id].item()

            return {
                "predicted_emotion": predicted_label,
                "confidence": round(confidence, 4),
                "all_emotions": {
                    id2label[i]: round(probabilities[0][i].item(), 4)
                    for i in range(len(id2label))
                },
            }
        finally:
            # Clean up the temporary file
            os.unlink(tmp_file_path)
    except HTTPException:
        raise  # Re-raise client errors (e.g. the 400 above) unchanged
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "model_loaded": model is not None}


@app.get("/")
async def root():
    """Root endpoint with API information."""
    return {
        "message": "Speech Emotion Recognition API",
        "model": "Whisper Large V3",
        "emotions": ["Angry", "Disgust", "Fearful", "Happy", "Neutral", "Sad", "Surprised"],
        "endpoints": {
            "predict": "/predict-emotion",
            "health": "/health",
        },
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
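
For reference, a minimal client sketch for exercising the API once it is running, either on the Space or via a local uvicorn process. The base URL http://localhost:8000 and the file name sample.wav are placeholders, and the third-party requests library is assumed to be installed; substitute the deployed Space URL and a real audio file.

import requests  # assumed installed: pip install requests

API_URL = "http://localhost:8000"  # placeholder: replace with the Space URL

# Confirm the service is up and the model loaded at startup
print(requests.get(f"{API_URL}/health").json())

# Upload an audio file to the prediction endpoint; the form field must be
# named "file" to match the UploadFile parameter above
with open("sample.wav", "rb") as f:  # placeholder audio file
    response = requests.post(
        f"{API_URL}/predict-emotion",
        files={"file": ("sample.wav", f, "audio/wav")},
    )
print(response.json())

The same request can be made from the shell with curl -F "file=@sample.wav" http://localhost:8000/predict-emotion.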