File size: 3,287 Bytes
fa1e508
 
 
 
e6f2697
fa1e508
 
83f964b
fa1e508
 
5a88e29
fa1e508
 
 
 
5a88e29
fa1e508
 
 
cb94f9b
 
 
fa1e508
ebe0afd
4205025
 
fa1e508
 
 
 
e6f2697
 
 
fa1e508
cc0b700
 
 
fa1e508
cc0b700
fa1e508
4205025
e6f2697
fa1e508
4205025
fa1e508
 
cb94f9b
fa1e508
 
 
cc0b700
fa1e508
e6f2697
 
 
 
 
 
 
 
 
 
 
fa1e508
 
 
 
2866e65
e6f2697
 
 
 
 
 
 
 
8ef1a8f
 
e6f2697
 
8ef1a8f
e6f2697
 
 
 
 
 
fa1e508
8ef1a8f
e6f2697
cb94f9b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional

import torch
from huggingface_hub import hf_hub_download

from helpers import get_embedding_dict, get_card_embeddings, MLP_CrossAttention


# --- Hugging Face Hub locations --------------------------------------------
# Model repository: the network config (keyword arguments passed to
# MLP_CrossAttention) and the saved state dict loaded into it.
MODEL_REPO: str = "TimoBertram/MTG_Model_FIN"
CFG_FILE: str = "config.json"
MODEL_FILE: str = "network.pt"

# Dataset repository: card metadata (a JSON list of card records, each with
# at least "set" and "name" fields) and the card encodings file.
DATA_REPO: str = "TimoBertram/MTG_Drafting_Dataset"
CARD_FILE: str = "cards_eoe.json"
ENCODING_FILE: str = "card_encodings.pt"




class DraftModel:
    """Draft-pick recommender wrapping a pretrained ``MLP_CrossAttention`` network.

    On construction it downloads the network config and weights, the card
    embeddings and the card metadata from the Hugging Face Hub. It keeps them
    in memory for repeated ``predict`` / ``get_p1p1`` calls.
    """

    # Number of all-zero rows in the placeholder deck fed to the network when
    # nothing has been picked yet. NOTE(review): presumably the maximum deck
    # size seen in training (3 packs x 15 cards) — confirm against the model.
    _EMPTY_DECK_SLOTS = 45

    def __init__(self):
        # Inference is pinned to CPU for compatibility. Use
        # torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # to pick up a GPU when one is available.
        self.device = torch.device("cpu")

        # ---- network --------------------------------------------------------
        weight_path = hf_hub_download(
            repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
        )
        cfg_path = hf_hub_download(
            repo_id=MODEL_REPO, filename=CFG_FILE, repo_type="model"
        )

        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(cfg_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        self.net = MLP_CrossAttention(**cfg)
        # NOTE(review): torch.load unpickles a file downloaded from the Hub.
        # Since this is a plain state dict, consider weights_only=True
        # (torch >= 1.13) so arbitrary pickled objects are rejected.
        self.net.load_state_dict(torch.load(weight_path, map_location=self.device))
        self.net.eval()
        self.net.to(self.device)

        # ---- embeddings – one-time load ------------------------------------
        self.embed_dict = get_embedding_dict(
            hf_hub_download(repo_id=DATA_REPO, filename=ENCODING_FILE, repo_type="dataset"),
            add_nontransformed=True,
        )
        # Width of a single card embedding. Assumes every value in embed_dict
        # is a 1-D tensor of the same length — TODO confirm.
        self.emb_size = next(iter(self.embed_dict.values())).shape[0]

        # ---- card metadata, indexed as cards[set_code][card_name] ----------
        card_path = hf_hub_download(
            repo_id=DATA_REPO, filename=CARD_FILE, repo_type="dataset"
        )
        # Context manager closes the handle; explicit UTF-8 keeps non-ASCII
        # card names intact on every platform.
        with open(card_path, "r", encoding="utf-8") as f:
            raw_cards = json.load(f)
        self.cards = defaultdict(dict)
        for card in raw_cards:
            self.cards[card["set"]][card["name"]] = card

    def _embed(self, name):
        """Return the embedding for one card *name*.

        This is the first element of ``get_card_embeddings`` called with a
        1-tuple.
        """
        return get_card_embeddings((name,), embedding_dict=self.embed_dict)[0]

    # --------------------------------------------------------------------- #
    # Public API expected by streamlit_app.py                               #
    # --------------------------------------------------------------------- #
    @torch.no_grad()
    def predict(self, pack: List[str], deck: Optional[List[str]]) -> Dict:
        """Score the cards in *pack* given the cards already drafted.

        Args:
            pack: Names of the cards available to pick. Must be non-empty.
            deck: Names of the cards already picked. ``None`` or an empty
                list means nothing has been picked yet. In that case the
                network receives an all-zero placeholder deck.

        Returns:
            A dict with three keys:
            ``"pick"``: the name of the highest-scoring pack card.
            ``"logits"``: the raw network output with the batch axis removed.
            Presumably one value per pack card, in pack order.
            ``"scores"``: the softmax of those logits.

        Raises:
            ValueError: if *pack* is empty.
        """
        if not pack:
            raise ValueError("pack must contain at least one card")

        card_t = torch.stack([self._embed(c) for c in pack]).unsqueeze(0).to(self.device)
        if deck:
            deck_t = torch.stack([self._embed(c) for c in deck]).unsqueeze(0).to(self.device)
        else:
            # No picks yet. torch.stack([]) would raise, so both None and []
            # get the zero placeholder deck.
            deck_t = torch.zeros(
                (1, self._EMPTY_DECK_SLOTS, self.emb_size), device=self.device
            )

        vals = self.net(deck=deck_t, cards=card_t)
        scores = torch.softmax(vals, dim=1).squeeze(0).cpu().numpy()
        logits = vals.squeeze(0).cpu().numpy()
        return {
            # int() turns the numpy integer into a plain Python index.
            "pick": pack[int(scores.argmax())],
            "logits": logits.tolist(),
            "scores": scores.tolist(),
        }

    @torch.no_grad()
    def get_p1p1(self, set_code: str):
        """Return first-pick logits for every known card of *set_code*.

        The whole card pool of the set is scored as a single "pack" against
        an empty deck.

        Returns:
            A tuple ``(card_names, logits)``. ``card_names`` lists the set's
            cards in insertion order. ``logits`` is the ``"logits"`` list
            produced by ``predict`` for that pack.

        Raises:
            KeyError: if no cards are known for *set_code*.
        """
        # .get() avoids inserting an empty entry into the defaultdict for
        # unknown set codes.
        set_cards = self.cards.get(set_code)
        if not set_cards:
            raise KeyError(f"no cards known for set {set_code!r}")
        keys = list(set_cards)
        vals = self.predict(pack=keys, deck=None)["logits"]
        return keys, vals