# +------------------------+
#
#          LLM Guard
#    https://llm-guard.com/
#
# +------------------------+
#   Thank you users! We ❤️ you! - Krrish & Ishaan

## This provides an LLM Guard Integration for content moderation on the proxy

from typing import Literal, Optional

import aiohttp
from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.secret_managers.main import get_secret_str
from litellm.utils import get_formatted_prompt

litellm.set_verbose = True
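
# Configuration notes (summarized from the code below):
# - `LLM_GUARD_API_BASE` must be set in the environment and should point at a
#   running LLM Guard API server; a trailing "/" is appended if missing.
# - `litellm.llm_guard_mode` controls when the moderation check runs:
#     "all"              -> every request is checked
#     "key-specific"     -> only keys whose permissions set `enable_llm_guard_check: true`
#     "request-specific" -> only requests whose `metadata.permissions` set
#                           `enable_llm_guard_check: true`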


class _ENTERPRISE_LLMGuard(CustomLogger):
    # Class variables or attributes
    def __init__(
        self,
        mock_testing: bool = False,
        mock_redacted_text: Optional[dict] = None,
    ):
        self.mock_redacted_text = mock_redacted_text
        self.llm_guard_mode = litellm.llm_guard_mode
        if mock_testing is True:  # for testing purposes only
            return
        self.llm_guard_api_base = get_secret_str("LLM_GUARD_API_BASE", None)
        if self.llm_guard_api_base is None:
            raise Exception("Missing `LLM_GUARD_API_BASE` from environment")
        elif not self.llm_guard_api_base.endswith("/"):
            self.llm_guard_api_base += "/"

    def print_verbose(self, print_statement):
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
        except Exception:
            pass

    async def moderation_check(self, text: str):
        """
        [TODO] make this more performant for high-throughput scenarios
        """
        try:
            async with aiohttp.ClientSession() as session:
                if self.mock_redacted_text is not None:
                    # Mock response, used for testing without a running LLM Guard server
                    redacted_text = self.mock_redacted_text
                else:
                    # Make the request to the LLM Guard /analyze/prompt endpoint
                    analyze_url = f"{self.llm_guard_api_base}analyze/prompt"
                    verbose_proxy_logger.debug("Making request to: %s", analyze_url)
                    analyze_payload = {"prompt": text}
                    redacted_text = None

                    async with session.post(
                        analyze_url, json=analyze_payload
                    ) as response:
                        redacted_text = await response.json()

                verbose_proxy_logger.info(
                    f"LLM Guard: Received response - {redacted_text}"
                )

                if redacted_text is not None:
                    if (
                        redacted_text.get("is_valid", None) is not None
                        and redacted_text["is_valid"] is False
                    ):
                        raise HTTPException(
                            status_code=400,
                            detail={"error": "Violated content safety policy"},
                        )
                    else:
                        pass
                else:
                    raise HTTPException(
                        status_code=500,
                        detail={
                            "error": f"Invalid content moderation response: {redacted_text}"
                        },
                    )
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.enterprise.enterprise_hooks.llm_guard::moderation_check - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool:
        if self.llm_guard_mode == "key-specific":
            # check if llm guard enabled for specific keys only
            self.print_verbose(
                f"user_api_key_dict.permissions: {user_api_key_dict.permissions}"
            )
            if (
                user_api_key_dict.permissions.get("enable_llm_guard_check", False)
                is True
            ):
                return True
        elif self.llm_guard_mode == "all":
            return True
        elif self.llm_guard_mode == "request-specific":
            self.print_verbose(f"received metadata: {data.get('metadata', {})}")
            metadata = data.get("metadata", {})
            permissions = metadata.get("permissions", {})
            if (
                "enable_llm_guard_check" in permissions
                and permissions["enable_llm_guard_check"] is True
            ):
                return True
        return False

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "responses",
        ],
    ):
        """
        - Calls the LLM Guard Endpoint
        - Rejects the request if it fails the safety check
        - Uses the sanitized prompt returned
        - LLM Guard can handle things like PII masking, etc.
        """
        self.print_verbose(
            f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
        )

        _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict, data=data)
        if _proceed is False:
            return

        self.print_verbose("Making LLM Guard check")

        try:
            assert call_type in [
                "completion",
                "embeddings",
                "image_generation",
                "moderation",
                "audio_transcription",
            ]
        except Exception:
            self.print_verbose(
                f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
            )
            return data

        formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)  # type: ignore
        self.print_verbose(f"LLM Guard, formatted_prompt: {formatted_prompt}")
        return await self.moderation_check(text=formatted_prompt)

    async def async_post_call_streaming_hook(
        self, user_api_key_dict: UserAPIKeyAuth, response: str
    ):
        if response is not None:
            await self.moderation_check(text=response)

        return response


# llm_guard = _ENTERPRISE_LLMGuard()
# asyncio.run(
#     llm_guard.async_moderation_hook(
#         data={"messages": [{"role": "user", "content": "Hey how's it going?"}]}
#     )
# )
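

# A runnable variant of the scratch example above: a minimal smoke-test sketch
# that uses mock mode, so no LLM Guard server is needed. The bare
# `UserAPIKeyAuth()` construction and the `{"is_valid": ...}` mock shape are
# illustrative assumptions; the proxy normally builds these objects itself.
if __name__ == "__main__":
    import asyncio

    async def _smoke_test():
        litellm.llm_guard_mode = "all"  # check every request, regardless of key

        # Passing case: the mocked LLM Guard response marks the prompt as valid
        passing_guard = _ENTERPRISE_LLMGuard(
            mock_testing=True,
            mock_redacted_text={"is_valid": True},
        )
        await passing_guard.async_moderation_hook(
            data={"messages": [{"role": "user", "content": "Hey how's it going?"}]},
            user_api_key_dict=UserAPIKeyAuth(),
            call_type="completion",
        )
        print("passing prompt was allowed")  # noqa

        # Failing case: the mocked response marks the prompt as invalid -> HTTP 400
        failing_guard = _ENTERPRISE_LLMGuard(
            mock_testing=True,
            mock_redacted_text={"is_valid": False},
        )
        try:
            await failing_guard.async_moderation_hook(
                data={"messages": [{"role": "user", "content": "a policy-violating prompt"}]},
                user_api_key_dict=UserAPIKeyAuth(),
                call_type="completion",
            )
        except HTTPException as e:
            print(f"failing prompt was blocked: {e.detail}")  # noqa

    asyncio.run(_smoke_test())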