File size: 1,934 Bytes
5fc69e4
 
 
 
 
 
 
 
 
57578cb
5fc69e4
57578cb
 
5fc69e4
 
 
 
 
 
 
 
57578cb
5fc69e4
57578cb
 
5fc69e4
 
 
 
 
 
 
 
57578cb
5fc69e4
57578cb
5fc69e4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from pathlib import Path
from typing import Optional

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM
)
from sentence_transformers import SentenceTransformer


def load_emotion_model(model_name: str, model_dir: Path, token: Optional[str] = None):
    """Return ``(tokenizer, model)`` for a sequence-classification model cached on disk.

    If ``model_dir`` is missing or empty, the tokenizer and model are fetched with
    ``from_pretrained(model_name, ...)`` and written into ``model_dir`` with
    ``save_pretrained``. Both objects are then (re)loaded from ``model_dir`` with
    ``local_files_only=True``, so the returned objects always come from the local copy.

    Args:
        model_name: Identifier passed to ``from_pretrained`` when no local copy exists.
        model_dir: Directory holding the local copy. A ``str`` path is also accepted.
        token: Optional auth token, forwarded only to the download calls.

    Returns:
        A ``(tokenizer, model)`` tuple loaded from ``model_dir``.

    Raises:
        OSError: if ``model_dir`` exists but cannot be listed as a directory
            (e.g. it is a regular file).
    """
    model_dir = Path(model_dir)
    if not model_dir.exists() or not any(model_dir.iterdir()):
        # NOTE(review): recent transformers releases deprecate ``use_auth_token`` in
        # favour of ``token``; kept as-is for compatibility with older pins — confirm.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
        model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
        tokenizer.save_pretrained(model_dir)
        model.save_pretrained(model_dir)
        # Release the downloaded copies before reloading from disk. Otherwise they
        # stay referenced while the reload below runs, so two full sets of weights
        # would be held in memory at the same time.
        del tokenizer, model

    tokenizer = AutoTokenizer.from_pretrained(str(model_dir), trust_remote_code=True, local_files_only=True)
    model = AutoModelForSequenceClassification.from_pretrained(str(model_dir), trust_remote_code=True, local_files_only=True)
    return tokenizer, model


def load_fallback_model(model_name: str, model_dir: Path, token: Optional[str] = None):
    """Return ``(tokenizer, model)`` for a causal language model cached on disk.

    If ``model_dir`` is missing or empty, the tokenizer and model are fetched with
    ``from_pretrained(model_name, ...)`` and written into ``model_dir`` with
    ``save_pretrained``. Both objects are then (re)loaded from ``model_dir`` with
    ``local_files_only=True``, so the returned objects always come from the local copy.

    Args:
        model_name: Identifier passed to ``from_pretrained`` when no local copy exists.
        model_dir: Directory holding the local copy. A ``str`` path is also accepted.
        token: Optional auth token, forwarded only to the download calls.

    Returns:
        A ``(tokenizer, model)`` tuple loaded from ``model_dir``.

    Raises:
        OSError: if ``model_dir`` exists but cannot be listed as a directory
            (e.g. it is a regular file).
    """
    model_dir = Path(model_dir)
    if not model_dir.exists() or not any(model_dir.iterdir()):
        # NOTE(review): recent transformers releases deprecate ``use_auth_token`` in
        # favour of ``token``; kept as-is for compatibility with older pins — confirm.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
        model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
        tokenizer.save_pretrained(model_dir)
        model.save_pretrained(model_dir)
        # Release the downloaded copies before reloading from disk. Otherwise they
        # stay referenced while the reload below runs, so two full sets of weights
        # would be held in memory at the same time.
        del tokenizer, model

    tokenizer = AutoTokenizer.from_pretrained(str(model_dir), trust_remote_code=True, local_files_only=True)
    model = AutoModelForCausalLM.from_pretrained(str(model_dir), trust_remote_code=True, local_files_only=True)
    return tokenizer, model


def load_embedder(model_name: str, model_dir: Path, token: Optional[str] = None):
    """Return a ``SentenceTransformer`` embedder cached on disk.

    If ``model_dir`` is missing or empty, the model is constructed from
    ``model_name`` and written into ``model_dir`` with ``save``. The embedder is
    then (re)loaded from ``model_dir``, so the returned object always comes from
    the local copy.

    Args:
        model_name: Model identifier used when no local copy exists.
        model_dir: Directory holding the local copy. A ``str`` path is also accepted.
        token: Optional auth token, forwarded only to the download call.

    Returns:
        A ``SentenceTransformer`` loaded from ``model_dir``.

    Raises:
        OSError: if ``model_dir`` exists but cannot be listed as a directory
            (e.g. it is a regular file).
    """
    model_dir = Path(model_dir)
    if not model_dir.exists() or not any(model_dir.iterdir()):
        # NOTE(review): newer sentence-transformers releases deprecate
        # ``use_auth_token`` in favour of ``token``; kept for older pins — confirm.
        downloaded = SentenceTransformer(model_name, use_auth_token=token)
        downloaded.save(str(model_dir))
        # Release the downloaded instance before reloading so two copies of the
        # model are not held in memory at the same time.
        del downloaded

    return SentenceTransformer(str(model_dir))