Spaces:
Running
Running
| # Guardrail.py | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| from functools import lru_cache | |
| from transformers import logging as hf_logging | |
| hf_logging.set_verbosity_error() | |
| from transformers import pipeline | |
# Candidate labels for the zero-shot classifier (Indonesian).
# Accepted topics: history question, sports question, nature question.
SAFE_LABELS = [
    "pertanyaan sejarah",
    "pertanyaan olahraga",
    "pertanyaan alam",
]
# Rejected categories: rude, insult, dangerous.
UNSAFE_LABELS = [
    "kasar",
    "penghinaan",
    "berbahaya",
]
@lru_cache(maxsize=1)
def _clf():
    """Return the shared zero-shot classification pipeline.

    The pipeline wraps ``joeddav/xlm-roberta-large-xnli`` running on CPU.
    Constructing it loads the full model. The result is therefore memoized:
    the first call builds the pipeline, and every later call returns that
    same instance. Without this, the model would be reloaded for each
    classified text.
    """
    # device=-1 selects CPU. Per the original author's note, the model weights
    # are expected to already be in the local cache (populated by
    # prepare_assets.py).
    return pipeline(
        "zero-shot-classification",
        model="joeddav/xlm-roberta-large-xnli",
        device=-1,
    )
def classify_text(text: str):
    """Classify *text* against every guardrail label with zero-shot NLI.

    The candidate set is ``SAFE_LABELS`` followed by ``UNSAFE_LABELS``.

    Returns:
        A 3-tuple ``(top_label, top_score, scores)``. ``top_label`` and
        ``top_score`` are the first entries of the pipeline's ``labels`` and
        ``scores`` outputs (the pipeline returns them ranked by descending
        score). ``scores`` maps every candidate label to its score.
    """
    classifier = _clf()
    candidates = [*SAFE_LABELS, *UNSAFE_LABELS]
    result = classifier(text, candidate_labels=candidates)

    ranked_labels = result["labels"]
    ranked_scores = result["scores"]
    score_by_label = {
        label: score for label, score in zip(ranked_labels, ranked_scores)
    }
    return ranked_labels[0], ranked_scores[0], score_by_label
def validate_input(text: str, threshold: float = 0.2) -> bool:
    """Return True if *text* looks like an acceptable, on-topic question.

    Empty input, ``None`` and whitespace-only strings are rejected
    immediately, without running the classifier. Otherwise the stripped
    text is classified. It is accepted only when the top-ranked label is
    one of ``SAFE_LABELS`` and that label's score is strictly greater than
    *threshold*.

    NOTE(review): in the pipeline's default single-label mode the scores
    appear to be normalized across all six candidate labels. A uniform
    split would then give 1/6 (~0.167) each, so 0.2 is only slightly above
    chance. Confirm the default threshold is intended.
    """
    stripped = text.strip() if text else ""
    if not stripped:
        return False

    label, score, _all_scores = classify_text(stripped)
    if label not in SAFE_LABELS:
        return False
    return bool(score > threshold)
if __name__ == "__main__":
    # Manual smoke check with an Indonesian history question
    # ("when did the Dutch colonize Indonesia?").
    demo_query = "kapan belanda menjajah indonesia?"
    print(validate_input(demo_query))