Spaces:
Sleeping
Sleeping
| from typing import Generic, List, Optional, TypeVar | |
| from functools import partial | |
| from pydantic import BaseModel | |
| from sentence_transformers import SentenceTransformer | |
| from fastapi import FastAPI | |
| import numpy | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import ORJSONResponse | |
# Sentence-embedding model, instantiated once when this module is imported;
# `_encode` below reuses this single instance for every call.
MODEL = SentenceTransformer("all-mpnet-base-v2")
def cache(func):
    """Memoize a batch function on a per-sentence basis.

    ``func`` must accept a list of sentences and return one result per
    sentence, in the same order. The returned wrapper forwards only the
    sentences it has not seen before, stores their results in a dict that
    lives as long as the wrapper, and answers repeated sentences from that
    dict.

    The cache has no size limit or eviction: every distinct sentence ever
    passed in stays in memory for the lifetime of the wrapper.

    Args:
        func: Callable taking a list of strings and returning an iterable of
            results aligned with that list.

    Returns:
        A wrapper taking a list of sentences and returning a list with one
        result per input sentence, in input order (duplicates included).
    """
    from functools import wraps  # local import keeps this block self-contained

    inner_cache = {}

    @wraps(func)
    def inner(sentences: List[str]):
        if not sentences:
            return []
        # dict.fromkeys de-duplicates while preserving first-seen order, so a
        # sentence repeated within one call is sent to ``func`` only once.
        not_in_cache = [s for s in dict.fromkeys(sentences) if s not in inner_cache]
        if not_in_cache:
            # Hand ``func`` a copy so it cannot reorder or mutate the key list
            # that the results are zipped against.
            processed_sentences = func(list(not_in_cache))
            inner_cache.update(zip(not_in_cache, processed_sentences))
        return [inner_cache[s] for s in sentences]

    return inner
def _encode(sentences: List[str]):
    """Embed ``sentences`` with the module-level ``MODEL``.

    Calls ``MODEL.encode`` with normalization enabled, a batch size of 2 and
    the progress bar switched on. Each item of the result is then rounded to
    3 decimal places and converted with ``.tolist()``; rounding happens after
    normalization.

    NOTE(review): the ``cache`` decorator defined in this file is not applied
    here in the visible source -- confirm whether that is intentional.

    Args:
        sentences: The texts to embed.

    Returns:
        A list with one rounded, list-converted embedding per item yielded by
        ``MODEL.encode``.
    """
    vectors = MODEL.encode(
        sentences,
        normalize_embeddings=True,
        batch_size=2,
        show_progress_bar=True,
    )
    rounded = []
    for vector in vectors:
        rounded.append(numpy.around(vector, 3).tolist())
    return rounded
class EmbedReq(BaseModel):
    """Request body for the embedding endpoint: the sentences to embed."""

    # Texts to embed; `embed` passes this list straight to `_encode`.
    sentences: List[str]
# FastAPI application instance; the CORS middleware is attached to it further
# down in this module.
app = FastAPI()
def embed(embed: EmbedReq):
    """Return the embeddings for the sentences carried by the request.

    NOTE(review): no route decorator is visible on this function in the
    source as given, so it is not registered on ``app`` here -- confirm
    whether a decorator line was lost.

    Args:
        embed: Request payload whose ``sentences`` attribute is embedded.
            The parameter shares its name with this function.

    Returns:
        Whatever ``_encode`` returns for ``embed.sentences``.
    """
    sentences = embed.sentences
    return _encode(sentences)
# Register CORS handling on the app with wildcard settings: any origin, any
# HTTP method and any request header are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)