File size: 1,179 Bytes
20de58b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Guardrail.py
import warnings
warnings.filterwarnings("ignore")
from functools import lru_cache

from transformers import logging as hf_logging
hf_logging.set_verbosity_error()
from transformers import pipeline

# Candidate labels handed to the zero-shot classifier (written in Indonesian).
# Both lists are offered to the model together; validate_input only accepts
# text whose top-ranked label is one of SAFE_LABELS.
SAFE_LABELS = [
    "pertanyaan sejarah",   # "history question"
    "pertanyaan olahraga",  # "sports question"
    "pertanyaan alam",      # "nature question"
]
UNSAFE_LABELS = [
    "kasar",       # "rude / vulgar"
    "penghinaan",  # "insult"
    "berbahaya",   # "dangerous"
]

@lru_cache(maxsize=1)
def _clf():
    """Build the zero-shot classification pipeline once and reuse it.

    ``lru_cache(maxsize=1)`` on this zero-argument function memoizes the
    single pipeline object, so every caller after the first gets the same
    instance without reloading the model.

    Returns:
        A transformers ``zero-shot-classification`` pipeline backed by
        ``joeddav/xlm-roberta-large-xnli``, placed on CPU (``device=-1``).
    """
    # NOTE(review): the original author noted the weights come from the
    # local cache populated by prepare_assets.py — confirm that script runs
    # before this module is used offline.
    task_name = "zero-shot-classification"
    model_id = "joeddav/xlm-roberta-large-xnli"
    cpu_device = -1  # transformers convention: -1 selects the CPU
    return pipeline(task_name, model=model_id, device=cpu_device)

def classify_text(text: str):
    """Classify *text* against every safe and unsafe candidate label.

    Args:
        text: The input string passed straight to the zero-shot pipeline.

    Returns:
        A 3-tuple ``(label, score, scores)``. ``label`` and ``score`` are the
        first entries of the pipeline's ranked output (transformers documents
        these lists as sorted by likelihood). ``scores`` maps each candidate
        label to the score the pipeline reported for it.
    """
    classifier = _clf()
    candidates = [*SAFE_LABELS, *UNSAFE_LABELS]
    outcome = classifier(text, candidate_labels=candidates)

    ranked_labels = outcome["labels"]
    ranked_scores = outcome["scores"]
    score_by_label = {
        label: score for label, score in zip(ranked_labels, ranked_scores)
    }
    return ranked_labels[0], ranked_scores[0], score_by_label

def validate_input(text: str, threshold: float = 0.2) -> bool:
    """Decide whether *text* is allowed through the guardrail.

    Args:
        text: User input. ``None`` or whitespace-only input is rejected
            without running the classifier.
        threshold: The top label's score must be strictly greater than this.

    Returns:
        True only when the stripped text is non-empty, its top-ranked label
        is one of SAFE_LABELS, and that label's score exceeds *threshold*.
    """
    cleaned = text.strip() if text else ""
    if cleaned == "":
        return False

    best_label, best_score, _unused_scores = classify_text(cleaned)
    if best_label not in SAFE_LABELS:
        return False
    # bool() so a numpy-style comparison result still comes back as a plain bool.
    return bool(best_score > threshold)

if __name__ == "__main__":
    # Manual smoke check: run one Indonesian question
    # ("when did the Dutch colonize Indonesia?") through the guardrail
    # and print the True/False verdict.
    sample_question = "kapan belanda menjajah indonesia?"
    verdict = validate_input(sample_question)
    print(verdict)