# Rag-dan-Guardrail / Guardrail.py
# Uploaded by Dyraa18 -- commit 20de58b (verified): "Upload 7 files"
# Guardrail.py
# Silence all Python warnings; this filter is installed before transformers
# is imported below.
import warnings
warnings.filterwarnings("ignore")
from functools import lru_cache
# Restrict the transformers library's own logging to error-level messages.
from transformers import logging as hf_logging
hf_logging.set_verbosity_error()
from transformers import pipeline
# Candidate labels (Indonesian) that validate_input accepts as the top
# prediction: "history question", "sports question", "nature question".
SAFE_LABELS = ["pertanyaan sejarah", "pertanyaan olahraga", "pertanyaan alam"]
# Candidate labels (Indonesian) that cause rejection when they rank first:
# "rude", "insult", "dangerous".
UNSAFE_LABELS = ["kasar", "penghinaan", "berbahaya"]
@lru_cache(maxsize=1)
def _clf():
    """Return the shared zero-shot classification pipeline.

    The function takes no arguments and is wrapped in
    ``lru_cache(maxsize=1)``, so the pipeline is constructed on the first
    call and the same object is handed back on every later call.

    Returns:
        The object produced by ``transformers.pipeline`` for the
        ``"zero-shot-classification"`` task using the
        ``joeddav/xlm-roberta-large-xnli`` model.
    """
    # device=-1 => CPU. Per the original author's note, the model is expected
    # to be picked up automatically from the cache filled by
    # prepare_assets.py.
    classifier = pipeline(
        "zero-shot-classification",
        model="joeddav/xlm-roberta-large-xnli",
        device=-1,
    )
    return classifier
def classify_text(text: str):
    """Score *text* against every safe and unsafe candidate label.

    Args:
        text: The input to classify.

    Returns:
        A 3-tuple ``(top_label, top_score, scores)`` where ``top_label`` and
        ``top_score`` are the first entries of the pipeline's ``"labels"``
        and ``"scores"`` lists, and ``scores`` maps every returned label to
        its score.
    """
    classifier = _clf()
    candidates = SAFE_LABELS + UNSAFE_LABELS
    result = classifier(text, candidate_labels=candidates)

    ranked_labels = result["labels"]
    ranked_scores = result["scores"]
    # The first entry is treated as the best match; this presumes the
    # pipeline returns labels ordered by descending score (see the
    # transformers zero-shot pipeline docs).
    score_by_label = {
        label: score for label, score in zip(ranked_labels, ranked_scores)
    }
    return ranked_labels[0], ranked_scores[0], score_by_label
def validate_input(text: str, threshold: float = 0.2) -> bool:
    """Decide whether *text* passes the guardrail.

    Args:
        text: User input. ``None`` or any other falsy value is treated as an
            empty string.
        threshold: Score the top label must strictly exceed.

    Returns:
        ``False`` when the input is empty after stripping whitespace, or when
        the top-ranked label is not one of ``SAFE_LABELS``. Otherwise
        ``True`` exactly when the top score is greater than ``threshold``.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        # Nothing to classify; reject without touching the model.
        return False

    best_label, best_score, _ = classify_text(cleaned)
    if best_label not in SAFE_LABELS:
        return False
    # bool() normalises whatever comparison result the score type yields
    # into a plain Python bool.
    return bool(best_score > threshold)
if __name__ == "__main__":
    # Manual smoke check with an Indonesian history question
    # ("when did the Dutch colonize Indonesia?").
    sample_question = "kapan belanda menjajah indonesia?"
    print(validate_input(sample_question))