Spaces:
Sleeping
Sleeping
File size: 3,287 Bytes
fa1e508 e6f2697 fa1e508 83f964b fa1e508 5a88e29 fa1e508 5a88e29 fa1e508 cb94f9b fa1e508 ebe0afd 4205025 fa1e508 e6f2697 fa1e508 cc0b700 fa1e508 cc0b700 fa1e508 4205025 e6f2697 fa1e508 4205025 fa1e508 cb94f9b fa1e508 cc0b700 fa1e508 e6f2697 fa1e508 2866e65 e6f2697 8ef1a8f e6f2697 8ef1a8f e6f2697 fa1e508 8ef1a8f e6f2697 cb94f9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional

import torch
from huggingface_hub import hf_hub_download

from helpers import get_embedding_dict, get_card_embeddings, MLP_CrossAttention
# --- Hugging Face Hub model repo: network hyper-parameters and weights ------
MODEL_REPO: str = "TimoBertram/MTG_Model_FIN"
CFG_FILE: str = "config.json"      # kwargs for MLP_CrossAttention
MODEL_FILE: str = "network.pt"     # state dict loaded with torch.load

# --- Hugging Face Hub dataset repo: card metadata and card encodings -------
# NOTE(review): the model repo suffix (FIN) and the card file (eoe) appear to
# name different sets — confirm this pairing is intended.
DATA_REPO: str = "TimoBertram/MTG_Drafting_Dataset"
CARD_FILE: str = "cards_eoe.json"          # JSON list of card records
ENCODING_FILE: str = "card_encodings.pt"   # passed to get_embedding_dict
class DraftModel:
    """Draft-pick model backed by an ``MLP_CrossAttention`` network.

    On construction, downloads the network config and weights from
    ``MODEL_REPO`` and the card encodings and card list from ``DATA_REPO``
    (Hugging Face Hub). Cards are indexed as ``self.cards[set][name] -> record``.
    """

    # Length of the all-zeros placeholder deck used when no deck is given.
    # NOTE(review): presumably the model's maximum deck size — TODO confirm.
    EMPTY_DECK_LEN = 45

    def __init__(self):
        # Always run on CPU for compatibility (CUDA auto-selection was
        # intentionally disabled).
        self.device = torch.device("cpu")

        # ---- load network ---------------------------------------------------
        weight_path = hf_hub_download(
            repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
        )
        cfg_path = hf_hub_download(
            repo_id=MODEL_REPO, filename=CFG_FILE, repo_type="model"
        )
        with open(cfg_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        self.net = MLP_CrossAttention(**cfg)
        # NOTE(review): torch.load unpickles the file; acceptable only because
        # the weights come from a fixed, trusted repo. Consider
        # weights_only=True if the installed torch supports it.
        self.net.load_state_dict(torch.load(weight_path, map_location=self.device))
        self.net.eval()
        self.net.to(self.device)

        # ---- embeddings – one-time load ------------------------------------
        encoding_path = hf_hub_download(
            repo_id=DATA_REPO, filename=ENCODING_FILE, repo_type="dataset"
        )
        self.embed_dict = get_embedding_dict(encoding_path, add_nontransformed=True)
        # Embedding width is read from the first entry (assumes all entries
        # share one size).
        self.emb_size = next(iter(self.embed_dict.values())).shape[0]

        # ---- card metadata, indexed as cards[set][name] -> record -----------
        card_path = hf_hub_download(
            repo_id=DATA_REPO, filename=CARD_FILE, repo_type="dataset"
        )
        # `with` closes the file (the previous json.load(open(...)) leaked it).
        with open(card_path, "r", encoding="utf-8") as f:
            raw_card_file = json.load(f)
        self.cards = defaultdict(dict)
        for card in raw_card_file:
            self.cards[card["set"]][card["name"]] = card

    def _embed(self, name):
        """Return the embedding tensor for a single card name."""
        return get_card_embeddings((name,), embedding_dict=self.embed_dict)[0]

    def _stack_embeddings(self, names):
        """Embed each name and stack into a batch of one: (1, len(names), ...)."""
        return torch.stack([self._embed(n) for n in names]).unsqueeze(0).to(self.device)

    # --------------------------------------------------------------------- #
    #                Public API expected by streamlit_app.py                #
    # --------------------------------------------------------------------- #
    @torch.no_grad()
    def predict(self, pack: List[str], deck: Optional[List[str]] = None) -> Dict:
        """Score every card in ``pack`` given the cards already in ``deck``.

        Args:
            pack: Card names available to pick; must be non-empty.
            deck: Card names already drafted. ``None`` or an empty list is
                replaced by an all-zeros placeholder deck of
                ``EMPTY_DECK_LEN`` slots.

        Returns:
            Dict with ``"pick"`` (the pack entry at the highest score),
            ``"logits"`` (raw network outputs as a list) and ``"scores"``
            (softmax of the logits over dim 1, as a list).
            NOTE(review): assumes the network returns one logit per pack card,
            shaped (1, len(pack)) — TODO confirm.

        Raises:
            ValueError: If ``pack`` is empty.
        """
        if not pack:
            raise ValueError("pack must contain at least one card")
        card_t = self._stack_embeddings(pack)
        # An empty deck would make torch.stack fail, so it shares the
        # placeholder path with deck=None.
        if not deck:
            deck_t = torch.zeros(
                (1, self.EMPTY_DECK_LEN, self.emb_size), device=self.device
            )
        else:
            deck_t = self._stack_embeddings(deck)
        vals = self.net(deck=deck_t, cards=card_t)
        scores = torch.softmax(vals, dim=1).squeeze(0).cpu().numpy()
        logits = vals.squeeze(0).cpu().numpy()
        return {
            "pick": pack[int(scores.argmax())],
            "logits": logits.tolist(),
            "scores": scores.tolist(),
        }

    @torch.no_grad()
    def get_p1p1(self, set_code: str):
        """Score every known card of ``set_code`` against an empty deck.

        Returns:
            ``(names, logits)``: the set's card names in load order and the
            matching raw logits from :meth:`predict` with ``deck=None``.

        Raises:
            ValueError: If no cards are loaded for ``set_code``.
        """
        # .get avoids inserting an empty entry into the defaultdict for
        # unknown set codes.
        set_cards = self.cards.get(set_code)
        if not set_cards:
            raise ValueError(f"No cards loaded for set {set_code!r}")
        keys = list(set_cards)
        vals = self.predict(pack=keys, deck=None)["logits"]
        return keys, vals
|