# Captured from a Hugging Face Spaces page (Space status at capture: Sleeping).
import json
from pathlib import Path
from typing import List, Dict

import torch
from huggingface_hub import hf_hub_download

from helpers import get_embedding_dict, get_card_embeddings, MLP_CrossAttention
# ---- Hugging Face Hub model repository: network config + weights -------
MODEL_REPO = "TimoBertram/MTG_Model_FIN"
CFG_FILE = "config.json"
MODEL_FILE = "network.pt"

# ---- Hugging Face Hub dataset repository -------------------------------
DATA_REPO = "TimoBertram/MTG_Drafting_Dataset"
# NOTE(review): CARD_FILE is not referenced by the code visible here.
CARD_FILE = "cards_eoe.json"
# Card encodings file, loaded through helpers.get_embedding_dict.
ENCODING_FILE = "card_encodings.pt"
class DraftModel:
    """Wrap the draft network: download it from the Hub and score pack cards.

    On construction, downloads the network config and weights from
    ``MODEL_REPO`` and the card encodings from ``DATA_REPO``. It then builds an
    ``MLP_CrossAttention`` network in eval mode on CUDA when available,
    otherwise on the CPU.
    """

    # Number of deck slots in the context tensor fed to the network. The
    # original code hard-coded 45; presumably the maximum pool size used in
    # training -- TODO confirm.
    _DECK_SLOTS = 45

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        cfg_path = hf_hub_download(
            repo_id=MODEL_REPO, filename=CFG_FILE, repo_type="model"
        )
        weight_path = hf_hub_download(
            repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
        )

        # ---- load network ---------------------------------------------------
        # Parse the JSON config into a dict. A bare file handle has no .pop(),
        # and it would also leak the open file.
        with open(cfg_path, "r", encoding="utf-8") as fh:
            cfg = json.load(fh)
        # Drop the optional "name" key before passing the remaining keys as
        # constructor kwargs (presumably not an MLP_CrossAttention argument).
        cfg.pop("name", None)
        self.net = MLP_CrossAttention(**cfg).to(self.device)
        # load_state_dict() takes a mapping, not a path; map_location is a
        # torch.load() argument.
        # NOTE(review): torch.load unpickles the downloaded file. Consider
        # weights_only=True if the installed torch version supports it.
        state_dict = torch.load(weight_path, map_location=self.device)
        self.net.load_state_dict(state_dict)
        self.net.eval()

        # ---- embeddings -- one-time load ------------------------------------
        self.embed_dict = get_embedding_dict(
            hf_hub_download(repo_id=DATA_REPO, filename=ENCODING_FILE, repo_type="dataset"),
            add_nontransformed=True,
        )

    # --------------------------------------------------------------------- #
    # Public API expected by streamlit_app.py                                #
    # --------------------------------------------------------------------- #
    def predict(self, pack: List[Dict], picks: List[Dict], deck: List[Dict]) -> List[float]:
        """Score the cards in ``pack`` with the network.

        Args:
            pack: Cards on offer; each dict must have a ``"name"`` key.
            picks: Currently unused.
            deck: Currently unused.

        Returns:
            The network output after softmax over dim 1, with the batch
            dimension squeezed, as a Python list. Presumably one probability
            per pack card, in ``pack`` order (depends on the
            MLP_CrossAttention output shape). Returns ``[]`` for an empty
            pack.
        """
        # torch.stack() rejects an empty list, so short-circuit here.
        if not pack:
            return []

        names = [card["name"] for card in pack]
        card_t = torch.stack(
            [get_card_embeddings((n,), embedding_dict=self.embed_dict)[0] for n in names]
        ).unsqueeze(0).to(self.device)

        # The original read self.emb_size, which is never set. Use the card
        # embedding width instead (assumes the deck input shares it -- confirm).
        # TODO(review): picks/deck are ignored; the deck context is all zeros.
        deck_t = torch.zeros(
            (1, self._DECK_SLOTS, card_t.shape[-1]), device=self.device
        )

        # Without no_grad the output tracks gradients (the network's
        # parameters require grad), and converting it to a list/numpy raises.
        with torch.no_grad():
            logits = self.net(card_t, deck_t)
        return torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()