# +------------------------+
#
#        LLM Guard
#   https://llm-guard.com/
#
# +------------------------+

#  Thank you users! We ❤️ you! - Krrish & Ishaan

## This provides an LLM Guard integration for content moderation on the proxy

from typing import Literal, Optional

import aiohttp
import litellm
from fastapi import HTTPException

from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.secret_managers.main import get_secret_str
from litellm.utils import get_formatted_prompt

litellm.set_verbose = True


class _ENTERPRISE_LLMGuard(CustomLogger):
    def __init__(
        self,
        mock_testing: bool = False,
        mock_redacted_text: Optional[dict] = None,
    ):
        self.mock_redacted_text = mock_redacted_text
        self.llm_guard_mode = litellm.llm_guard_mode
        if mock_testing:  # for testing purposes only
            return
        self.llm_guard_api_base = get_secret_str("LLM_GUARD_API_BASE", None)
        if self.llm_guard_api_base is None:
            raise Exception("Missing `LLM_GUARD_API_BASE` from environment")
        elif not self.llm_guard_api_base.endswith("/"):
            self.llm_guard_api_base += "/"

    def print_verbose(self, print_statement):
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
        except Exception:
            pass

    async def moderation_check(self, text: str):
        """
        Send `text` to the LLM Guard `/analyze/prompt` endpoint and raise an
        HTTPException if it fails the safety check.

        [TODO] make this more performant for high-throughput scenarios
        """
        try:
            if self.mock_redacted_text is not None:
                redacted_text = self.mock_redacted_text
            else:
                # Make the request to /analyze/prompt
                analyze_url = f"{self.llm_guard_api_base}analyze/prompt"
                verbose_proxy_logger.debug("Making request to: %s", analyze_url)
                analyze_payload = {"prompt": text}

                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        analyze_url, json=analyze_payload
                    ) as response:
                        redacted_text = await response.json()

            verbose_proxy_logger.info(
                f"LLM Guard: Received response - {redacted_text}"
            )
            if redacted_text is not None:
                is_valid = redacted_text.get("is_valid")
                if is_valid is not None and is_valid is not True:
                    raise HTTPException(
                        status_code=400,
                        detail={"error": "Violated content safety policy"},
                    )
            else:
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": f"Invalid content moderation response: {redacted_text}"
                    },
                )
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.enterprise.enterprise_hooks.llm_guard::moderation_check - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool:
        if self.llm_guard_mode == "key-specific":
            # check if llm guard is enabled for this specific key only
            self.print_verbose(
                f"user_api_key_dict.permissions: {user_api_key_dict.permissions}"
            )
            if (
                user_api_key_dict.permissions.get("enable_llm_guard_check", False)
                is True
            ):
                return True
        elif self.llm_guard_mode == "all":
            return True
        elif self.llm_guard_mode == "request-specific":
            self.print_verbose(f"received metadata: {data.get('metadata', {})}")
            metadata = data.get("metadata", {})
            permissions = metadata.get("permissions", {})
            if permissions.get("enable_llm_guard_check", False) is True:
                return True
        return False
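
    # Illustrative shapes consulted by `should_proceed` above (the values shown
    # mirror the checks in this module; the actual dicts come from the key and
    # request configuration on the proxy):
    #   llm_guard_mode == "key-specific":
    #       user_api_key_dict.permissions -> {"enable_llm_guard_check": True}
    #   llm_guard_mode == "request-specific":
    #       data["metadata"] -> {"permissions": {"enable_llm_guard_check": True}}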
    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ):
        """
        - Calls the LLM Guard endpoint
        - Rejects the request if it fails the safety check
        - Uses the sanitized prompt returned - LLM Guard can handle things like PII masking, etc.
        """
        self.print_verbose(
            f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
        )

        _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict, data=data)
        if not _proceed:
            return

        self.print_verbose("Making LLM Guard check")

        accepted_call_types = [
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ]
        if call_type not in accepted_call_types:
            self.print_verbose(
                f"Call Type - {call_type}, not in accepted list - {accepted_call_types}"
            )
            return data

        formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)  # type: ignore
        self.print_verbose(f"LLM Guard, formatted_prompt: {formatted_prompt}")
        return await self.moderation_check(text=formatted_prompt)

    async def async_post_call_streaming_hook(
        self, user_api_key_dict: UserAPIKeyAuth, response: str
    ):
        if response is not None:
            await self.moderation_check(text=response)

        return response


# Local smoke test (commented out) - `async_moderation_hook` also needs
# `user_api_key_dict` and `call_type`; mock values keep it runnable without a
# deployed LLM Guard server:
#
# import asyncio
#
# llm_guard = _ENTERPRISE_LLMGuard(
#     mock_testing=True, mock_redacted_text={"is_valid": True}
# )
# asyncio.run(
#     llm_guard.async_moderation_hook(
#         data={"messages": [{"role": "user", "content": "Hey how's it going?"}]},
#         user_api_key_dict=UserAPIKeyAuth(),
#         call_type="completion",
#     )
# )
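
# A sketch of the rejection path under the same mocked setup (the
# `{"is_valid": ...}` shape simply mirrors what `moderation_check` inspects;
# in a real deployment it comes from the LLM Guard /analyze/prompt response).
# An invalid prompt surfaces as an HTTPException with status_code=400:
#
# import asyncio
#
# async def _demo_blocked():
#     guard = _ENTERPRISE_LLMGuard(
#         mock_testing=True, mock_redacted_text={"is_valid": False}
#     )
#     try:
#         await guard.async_moderation_hook(
#             data={"messages": [{"role": "user", "content": "an unsafe prompt"}]},
#             user_api_key_dict=UserAPIKeyAuth(),
#             call_type="completion",
#         )
#     except HTTPException as e:
#         print(e.status_code, e.detail)  # 400 {'error': 'Violated content safety policy'}
#
# asyncio.run(_demo_blocked())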