Spaces:
Runtime error
Runtime error
| import os | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import warnings | |
| from huggingface_hub import spaces | |
# Global runtime configuration.
# NOTE(review): transformers may read TRANSFORMERS_CACHE when it is first
# imported; because it is imported earlier in this file, this assignment may not
# change where downloaded models are cached on its own — confirm.
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
# Deliberately silence every Python warning (including deprecation notices) so
# library chatter does not flood the logs.
warnings.simplefilter("ignore")
# Pick the compute device used for inference.
def init_gpu():
    """Return the ``torch.device`` that inference should run on.

    Despite its name, this falls back to the CPU whenever
    ``torch.cuda.is_available()`` reports no usable GPU, so the app can still
    start on CPU-only hardware.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device("cuda")
# Load the formality classifier and its tokenizer once, at import time, so that
# every request reuses the same weights.
MODEL_NAME = "s-nlp/roberta-base-formality-ranker"
# Pass the cache directory explicitly. transformers / huggingface_hub resolve
# their default cache location when first imported, and TRANSFORMERS_CACHE is
# assigned after that import in this file, so relying on the variable alone may
# not take effect. When the variable is unset this is None, which keeps the
# library's default cache.
_CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=_CACHE_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, cache_dir=_CACHE_DIR)
# Move the model to the GPU when one is available (init_gpu falls back to CPU).
device = init_gpu()
model = model.to(device)
app = FastAPI(title="Formality Classifier API")
# Request body accepted by predict_formality: one piece of text to score.
class TextInput(BaseModel):
    # Raw text to classify; required (no default value is declared).
    text: str
def calculate_formality_percentages(score):
    """Split a formality probability into whole-number formal/informal percentages.

    Args:
        score: Probability that the text is formal, expected in ``[0, 1]``
            (predict_formality passes a softmax output). Values outside that
            range are clamped.

    Returns:
        ``(formal_percent, informal_percent)`` as ints that always sum to 100.
    """
    # Round instead of truncating: ``int(0.29 * 100)`` is 28 because the float
    # product is 28.999999999999996, which under-reports the formal share and
    # disagrees with the rounded score shown alongside it.
    formal_percent = int(round(score * 100))
    # Clamp so out-of-range inputs cannot yield negative or >100 percentages.
    formal_percent = min(100, max(0, formal_percent))
    informal_percent = 100 - formal_percent
    return formal_percent, informal_percent
# Register the route: without this decorator FastAPI never exposes the handler.
@app.get("/")
async def home():
    """Liveness endpoint: report that the API is up and point callers at /predict."""
    return {"message": "Formality Classifier API is running! Use /predict to classify text."}
# Register the route: without this decorator the advertised /predict endpoint
# does not exist.
@app.post("/predict")
async def predict_formality(input_data: TextInput):
    """Score how formal ``input_data.text`` is.

    Returns:
        dict with ``formality_score`` (classifier probability rounded to 3
        decimals), ``formal_percent`` / ``informal_percent`` (ints summing to
        100) and a human-readable ``classification`` sentence.

    Raises:
        HTTPException: status 500 carrying the underlying error message when
            tokenization or inference fails.
    """
    # NOTE(review): the model call below is synchronous, so it blocks the event
    # loop while inference runs; consider a plain ``def`` endpoint (executed in
    # FastAPI's threadpool) if concurrent requests matter.
    try:
        # Tokenize the single input string into a batch of one.
        encoding = tokenizer(input_data.text, return_tensors="pt", truncation=True, padding=True)
        # Inputs must live on the same device as the model.
        encoding = {name: tensor.to(device) for name, tensor in encoding.items()}

        with torch.no_grad():
            logits = model(**encoding).logits
        # Probability of class index 1 for the single row; .item() requires a
        # batch of exactly one.
        # NOTE(review): assumes label 1 means "formal" — confirm against
        # model.config.id2label.
        score = logits.softmax(dim=1)[:, 1].item()

        formal_percent, informal_percent = calculate_formality_percentages(score)
        return {
            "formality_score": round(score, 3),
            "formal_percent": formal_percent,
            "informal_percent": informal_percent,
            "classification": f"Your speech is {formal_percent}% formal and {informal_percent}% informal.",
        }
    except Exception as e:
        # Surface the failure as an HTTP 500 while keeping the original
        # exception as the explicit cause for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
# Local entry point: serve the API with uvicorn on all interfaces, port 7860
# (presumably chosen to match the default app port of Hugging Face Spaces).
if __name__ == "__main__":
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860)