Spaces:
Runtime error
Runtime error
import uuid
from dataclasses import dataclass
from typing import Any

import japanese_clip as ja_clip
import torch
from PIL import Image

from db_session import get_db
from s3_session import Bucket
@dataclass
class MLModel:
    """Japanese-CLIP wrapper that indexes images and searches them by text.

    Images are embedded with ``rinna/japanese-clip-vit-b-16``, uploaded through
    ``Bucket`` and their embeddings inserted into the ``embedding`` collection
    of the database returned by ``get_db()``. Text prompts are embedded with
    the same model and matched against the stored vectors using a
    ``$vectorSearch`` aggregation stage.

    Any of ``tokenizer``, ``model``, ``preprocess`` or ``bucket`` may be passed
    in explicitly (e.g. to share an already-loaded model); the ones left as
    ``None`` are created in ``__post_init__``.
    """

    tokenizer: Any = None
    model: Any = None
    preprocess: Any = None
    bucket: Any = None
    # Torch device passed to ja_clip.load / ja_clip.tokenize and used for
    # the preprocessed image tensor.
    device: str = "cpu"

    def __post_init__(self):
        """Load whichever components were not supplied by the caller."""
        if self.tokenizer is None:
            self.tokenizer = ja_clip.load_tokenizer()
        # ja_clip.load returns the model and its preprocessing transform
        # together, so a single call fills in whichever of the two is missing.
        if self.model is None or self.preprocess is None:
            loaded_model, loaded_preprocess = ja_clip.load(
                "rinna/japanese-clip-vit-b-16",
                cache_dir="/tmp/japanese_clip",
                device=self.device,
            )
            if self.model is None:
                self.model = loaded_model
            if self.preprocess is None:
                self.preprocess = loaded_preprocess
        if self.bucket is None:
            self.bucket = Bucket()

    def save(self, image_path: str):
        """Embed the image at *image_path*, upload it and store its embedding.

        The image is uploaded under a freshly generated UUID4 string, and a
        document ``{"uuid": <uuid>, "vectorField": <embedding list>}`` is
        inserted into the ``embedding`` collection.

        Args:
            image_path: Path of an image file readable by PIL.

        Returns:
            The ``inserted_id`` of the new embedding document.
        """
        image_uuid = str(uuid.uuid4())
        # The context manager closes the file handle PIL opened; the upload
        # stays inside it because PIL may read pixel data lazily from the file.
        with Image.open(image_path) as pillow_image:
            image = self.preprocess(pillow_image).unsqueeze(0).to(self.device)
            # Inference only: skip autograd graph construction.
            with torch.no_grad():
                image_features = self.model.get_image_features(image)
            # media upload
            self.bucket.upload_file(pillow_image, image_uuid)
        # db insert
        db = get_db()
        result = db["embedding"].insert_one(
            {"uuid": image_uuid, "vectorField": image_features[0].tolist()}
        )
        return result.inserted_id

    def search(self, prompt: str, limit: int = 10, num_candidates: int = 150):
        """Return presigned URLs of the stored images most similar to *prompt*.

        Args:
            prompt: Text query to embed and search with.
            limit: Maximum number of results to return.
            num_candidates: Candidate pool size for the vector search; raised
                to ``limit`` if smaller, since ``numCandidates`` may not be
                below ``limit``.

        Returns:
            A list of presigned URLs, one per matching document, in the order
            returned by the aggregation.
        """
        encodings = ja_clip.tokenize(
            texts=[prompt],
            max_seq_len=77,
            device=self.device,
            tokenizer=self.tokenizer,
        )
        # Inference only: skip autograd graph construction.
        with torch.no_grad():
            text_features = self.model.get_text_features(**encodings)
        pipeline = [
            {
                "$vectorSearch": {
                    "index": "vector_index",
                    "path": "vectorField",
                    "queryVector": text_features[0].tolist(),
                    "numCandidates": max(num_candidates, limit),
                    "limit": limit,
                }
            },
            {
                "$project": {
                    "_id": {"$toString": "$_id"},
                    "uuid": 1,
                    "score": {"$meta": "vectorSearchScore"},
                }
            },
        ]
        db = get_db()
        result = db["embedding"].aggregate(pipeline)
        return [self.bucket.get_presigned_url(doc["uuid"]) for doc in result]