Spaces:
Runtime error
Runtime error
| from typing import Optional | |
| import torch | |
| import weave | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline | |
| from transformers.pipelines.base import Pipeline | |
| import wandb | |
| from ..base import Guardrail | |
| class PromptInjectionClassifierGuardrail(Guardrail): | |
| """ | |
| A guardrail that uses a pre-trained text-classification model to classify prompts | |
| for potential injection attacks. | |
| Args: | |
| model_name (str): The name of the HuggingFace model or a WandB | |
| checkpoint artifact path to use for classification. | |
| """ | |
| model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2" | |
| _classifier: Optional[Pipeline] = None | |
| def model_post_init(self, __context): | |
| if self.model_name.startswith("wandb://"): | |
| api = wandb.Api() | |
| artifact = api.artifact(self.model_name.removeprefix("wandb://")) | |
| artifact_dir = artifact.download() | |
| tokenizer = AutoTokenizer.from_pretrained(artifact_dir) | |
| model = AutoModelForSequenceClassification.from_pretrained(artifact_dir) | |
| else: | |
| tokenizer = AutoTokenizer.from_pretrained(self.model_name) | |
| model = AutoModelForSequenceClassification.from_pretrained(self.model_name) | |
| self._classifier = pipeline( | |
| "text-classification", | |
| model=model, | |
| tokenizer=tokenizer, | |
| truncation=True, | |
| max_length=512, | |
| device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), | |
| ) | |
| def classify(self, prompt: str): | |
| return self._classifier(prompt) | |
| def guard(self, prompt: str): | |
| """ | |
| Analyzes the given prompt to determine if it is safe or potentially an injection attack. | |
| This function uses a pre-trained text-classification model to classify the prompt. | |
| It calls the `classify` method to get the classification result, which includes a label | |
| and a confidence score. The function then calculates the confidence percentage and | |
| returns a dictionary with two keys: | |
| - "safe": A boolean indicating whether the prompt is safe (True) or an injection (False). | |
| - "summary": A string summarizing the classification result, including the label and the | |
| confidence percentage. | |
| Args: | |
| prompt (str): The input prompt to be classified. | |
| Returns: | |
| dict: A dictionary containing the safety status and a summary of the classification result. | |
| """ | |
| response = self.classify(prompt) | |
| confidence_percentage = round(response[0]["score"] * 100, 2) | |
| return { | |
| "safe": response[0]["label"] != "INJECTION", | |
| "summary": f"Prompt is deemed {response[0]['label']} with {confidence_percentage}% confidence.", | |
| } | |
| def predict(self, prompt: str): | |
| return self.guard(prompt) | |