"""Minimal HTTP service exposing SentenceTransformer sentence embeddings.

Configuration (environment variables):
    MODEL_NAME         -- required; model id or local path given to SentenceTransformer.
    TRUST_REMOTE_CODE  -- optional, default "true"; whether the model repository's
                          custom Python code may be executed when loading the model.

Endpoint:
    POST /embed  with body {"inputs": [str, ...]}  ->  list of embedding vectors
                 (one list of floats per input, in input order).
"""

import logging
import os
import threading

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


class EmbeddingRequest(BaseModel):
    """Request body for ``POST /embed``: the texts to embed."""

    inputs: list[str]


app = FastAPI()

# Prefer the GPU when torch can see one; otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = os.getenv("MODEL_NAME")
if not model_name:
    # Passing None/"" to SentenceTransformer would not load any real model, so every
    # request would fail later; report the misconfiguration at startup instead.
    raise RuntimeError("MODEL_NAME environment variable must be set to a model id or path")

# SECURITY: trust_remote_code executes Python code shipped with the model repository.
# It stays enabled by default for backward compatibility; set TRUST_REMOTE_CODE=false
# for models that do not need custom code.
trust_remote_code = os.getenv("TRUST_REMOTE_CODE", "true").strip().lower() in {
    "1",
    "true",
    "yes",
    "on",
}

model = SentenceTransformer(model_name, device=device, trust_remote_code=trust_remote_code)

# The endpoint below is a sync function, so FastAPI runs it in a worker thread and
# several requests may arrive concurrently. Serialize access to the shared model
# (and its tokenizer) so only one encode runs at a time, as in the original
# single-threaded behaviour, without blocking the event loop.
_encode_lock = threading.Lock()


@app.post("/embed")
def get_embeddings(request: EmbeddingRequest):
    """Return one embedding vector per input text.

    Declared as a plain ``def`` (not ``async def``) on purpose: ``model.encode`` is
    blocking CPU/GPU work, and FastAPI runs sync path operations in a threadpool,
    which keeps the event loop free to serve other requests meanwhile.

    Args:
        request: Parsed body containing the list of texts under ``inputs``.

    Returns:
        A list of float lists, one per input, in input order; ``[]`` for empty input.

    Raises:
        HTTPException: status 500 with the underlying error message if encoding fails.
    """
    if not request.inputs:
        # Nothing to encode; avoid taking the lock or touching the model.
        return []
    try:
        with _encode_lock:
            embeddings = model.encode(request.inputs, convert_to_numpy=True)
        return embeddings.tolist()
    except Exception as e:
        # Top-level boundary: log the full traceback server-side, then map to a 500.
        logger.exception("Embedding %d input(s) failed", len(request.inputs))
        raise HTTPException(status_code=500, detail=str(e)) from e