| import argparse | |
| import base64 | |
| import io | |
| import os | |
| import pickle | |
| import requests | |
| import torch | |
| import uvicorn | |
| from fastapi import FastAPI | |
| from fastapi.responses import JSONResponse | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline | |
# Hugging Face Hub identifier shared by the tokenizer and the model below.
_MODEL_ID = "xTRam1/safe-guard-classifier"

# Use CUDA when torch reports it is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"USING DEVICE: {DEVICE}")

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_ID)

# Text-classification pipeline built from the objects loaded above.
# truncation/max_length ask the pipeline to cut inputs at 512 tokens
# instead of failing on long text.
classifier = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    truncation=True,
    max_length=512,
    device=torch.device(DEVICE),
)

app = FastAPI()
# NOTE(review): route path chosen to match the handler name; adjust if
# clients expect a different path.
@app.post("/generate")
async def generate(request: dict):
    """Classify the ``"text"`` field of a JSON request body.

    Args:
        request: Parsed JSON body. Its ``"text"`` entry is passed unchanged
            to the module-level ``classifier`` pipeline.

    Returns:
        A JSONResponse echoing the input under ``"text"`` and the pipeline
        output under ``"result"``. If ``"text"`` is missing or null, returns
        a 400 response with an ``"error"`` message instead of letting a
        KeyError surface as a 500.
    """
    text = request.get("text")
    if text is None:
        return JSONResponse(
            status_code=400,
            content={"error": "request body must include a non-null 'text' field"},
        )
    print("INPUT:", text)
    # NOTE(review): this call is synchronous and blocks the event loop for the
    # duration of inference. Offloading it to a thread pool would permit
    # concurrent calls on the shared tokenizer/model, whose thread-safety is
    # not established here; evaluate before changing.
    result = classifier(text)
    print("RESULT:", result)
    return JSONResponse(content={"text": text, "result": result})
if __name__ == "__main__":
    # Command-line entry point: parse bind options and serve ``app`` with uvicorn.
    parser = argparse.ArgumentParser(
        description="Serve the safe-guard text classifier over HTTP."
    )
    parser.add_argument(
        "--host",
        default="127.0.0.1",
        help="interface to bind (default: %(default)s)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="TCP port to listen on (default: %(default)s)",
    )
    args = parser.parse_args()
    uvicorn.run(app, host=args.host, port=args.port)