diff --git a/litellm/proxy/custom_guardrail.py b/litellm/proxy/custom_guardrail.py
new file mode 100644
index 0000000000..bdcdcee1cb
--- /dev/null
+++ b/litellm/proxy/custom_guardrail.py
@@ -0,0 +1,115 @@
+import os
+import sys
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import asyncio
+import json
+import sys
+import traceback
+import uuid
+from datetime import datetime
+from typing import Any, Dict, List, Literal, Optional, Union
+
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache
+from litellm.integrations.custom_guardrail import CustomGuardrail
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
+from litellm.types.guardrails import GuardrailEventHooks
+
+
+class myCustomGuardrail(CustomGuardrail):
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        # store kwargs as optional_params
+        self.optional_params = kwargs
+
+        super().__init__(**kwargs)
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+        ],
+    ) -> Optional[Union[Exception, str, dict]]:
+        # In this guardrail, if a user inputs `litellm` we will mask it.
+        _messages = data.get("messages")
+        if _messages:
+            for message in _messages:
+                _content = message.get("content")
+                if isinstance(_content, str):
+                    if "litellm" in _content.lower():
+                        _content = _content.replace("litellm", "********")
+                        message["content"] = _content
+
+        verbose_proxy_logger.debug(
+            "async_pre_call_hook: Message after masking %s", _messages
+        )
+
+        return data
+
+    async def async_moderation_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: Literal["completion", "embeddings", "image_generation"],
+    ):
+        """
+        Runs in parallel to LLM API call
+        Runs on only Input
+        """
+
+        # this works the same as async_pre_call_hook, but just runs in parallel as the LLM API Call
+        # In this guardrail, if a user inputs `litellm` we will mask it.
+        _messages = data.get("messages")
+        if _messages:
+            for message in _messages:
+                _content = message.get("content")
+                if isinstance(_content, str):
+                    if "litellm" in _content.lower():
+                        _content = _content.replace("litellm", "********")
+                        message["content"] = _content
+
+        verbose_proxy_logger.debug(
+            "async_pre_call_hook: Message after masking %s", _messages
+        )
+        pass
+
+    async def async_post_call_success_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+    ):
+        """
+        Runs on response from LLM API call
+
+        If a response contains the word "coffee" -> we will raise an exception
+        """
+        verbose_proxy_logger.debug("async_pre_call_hook response: %s", response)
+        if isinstance(response, litellm.ModelResponse):
+            for choice in response.choices:
+                if isinstance(choice, litellm.Choices):
+                    verbose_proxy_logger.debug("async_pre_call_hook choice: %s", choice)
+                    if (
+                        choice.message.content
+                        and isinstance(choice.message.content, str)
+                        and "coffee" in choice.message.content
+                    ):
+                        raise ValueError("Guardrail failed Coffee Detected")
diff --git a/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py b/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py
new file mode 100644
index 0000000000..bdcdcee1cb
--- /dev/null
+++ b/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py
@@ -0,0 +1,115 @@
+import os
+import sys
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import asyncio
+import json
+import sys
+import traceback
+import uuid
+from datetime import datetime
+from typing import Any, Dict, List, Literal, Optional, Union
+
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache
+from litellm.integrations.custom_guardrail import CustomGuardrail
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
+from litellm.types.guardrails import GuardrailEventHooks
+
+
+class myCustomGuardrail(CustomGuardrail):
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        # store kwargs as optional_params
+        self.optional_params = kwargs
+
+        super().__init__(**kwargs)
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+        ],
+    ) -> Optional[Union[Exception, str, dict]]:
+        # In this guardrail, if a user inputs `litellm` we will mask it.
+        _messages = data.get("messages")
+        if _messages:
+            for message in _messages:
+                _content = message.get("content")
+                if isinstance(_content, str):
+                    if "litellm" in _content.lower():
+                        _content = _content.replace("litellm", "********")
+                        message["content"] = _content
+
+        verbose_proxy_logger.debug(
+            "async_pre_call_hook: Message after masking %s", _messages
+        )
+
+        return data
+
+    async def async_moderation_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: Literal["completion", "embeddings", "image_generation"],
+    ):
+        """
+        Runs in parallel to LLM API call
+        Runs on only Input
+        """
+
+        # this works the same as async_pre_call_hook, but just runs in parallel as the LLM API Call
+        # In this guardrail, if a user inputs `litellm` we will mask it.
+        _messages = data.get("messages")
+        if _messages:
+            for message in _messages:
+                _content = message.get("content")
+                if isinstance(_content, str):
+                    if "litellm" in _content.lower():
+                        _content = _content.replace("litellm", "********")
+                        message["content"] = _content
+
+        verbose_proxy_logger.debug(
+            "async_pre_call_hook: Message after masking %s", _messages
+        )
+        pass
+
+    async def async_post_call_success_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+    ):
+        """
+        Runs on response from LLM API call
+
+        If a response contains the word "coffee" -> we will raise an exception
+        """
+        verbose_proxy_logger.debug("async_pre_call_hook response: %s", response)
+        if isinstance(response, litellm.ModelResponse):
+            for choice in response.choices:
+                if isinstance(choice, litellm.Choices):
+                    verbose_proxy_logger.debug("async_pre_call_hook choice: %s", choice)
+                    if (
+                        choice.message.content
+                        and isinstance(choice.message.content, str)
+                        and "coffee" in choice.message.content
+                    ):
+                        raise ValueError("Guardrail failed Coffee Detected")
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 320216a79b..acb792aec9 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -1,17 +1,19 @@
 model_list:
-  - model_name: fake-openai-endpoint
+  - model_name: gpt-4
     litellm_params:
-      model: azure/chatgpt-v-2
-      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
-      api_version: "2023-05-15"
-      tenant_id: os.environ/AZURE_TENANT_ID
-      client_id: os.environ/AZURE_CLIENT_ID
-      client_secret: os.environ/AZURE_CLIENT_SECRET
+      model: openai/gpt-4o
+      api_key: os.environ/OPENAI_API_KEY
 
 guardrails:
-  - guardrail_name: "bedrock-pre-guard"
+  - guardrail_name: "custom-pre-guard"
     litellm_params:
-      guardrail: bedrock  # supported values: "aporia", "bedrock", "lakera"
-      mode: "post_call"
-      guardrailIdentifier: ff6ujrregl1q
-      guardrailVersion: "DRAFT"
\ No newline at end of file
+      guardrail: custom_guardrail.myCustomGuardrail
+      mode: "pre_call"
+  - guardrail_name: "custom-during-guard"
+    litellm_params:
+      guardrail: custom_guardrail.myCustomGuardrail
+      mode: "during_call"
+  - guardrail_name: "custom-post-guard"
+    litellm_params:
+      guardrail: custom_guardrail.myCustomGuardrail
+      mode: "post_call"
\ No newline at end of file
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index a770177179..7bdad862bc 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -30,6 +30,7 @@ from litellm._logging import verbose_proxy_logger
 from litellm._service_logger import ServiceLogging, ServiceTypes
 from litellm.caching import DualCache, RedisCache
 from litellm.exceptions import RejectedRequestError
+from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.slack_alerting import SlackAlerting
 from litellm.litellm_core_utils.core_helpers import (
@@ -344,6 +345,23 @@ class ProxyLogging:
             ttl=alerting_threshold,
         )
 
+    async def process_pre_call_hook_response(self, response, data, call_type):
+        if isinstance(response, Exception):
+            raise response
+        if isinstance(response, dict):
+            return response
+        if isinstance(response, str):
+            if call_type in ["completion", "text_completion"]:
+                raise RejectedRequestError(
+                    message=response,
+                    model=data.get("model", ""),
+                    llm_provider="",
+                    request_data=data,
+                )
+            else:
+                raise HTTPException(status_code=400, detail={"error": response})
+        return data
+
     # The actual implementation of the function
     async def pre_call_hook(
         self,
@@ -382,7 +400,33 @@ class ProxyLogging:
                     )
                 else:
                     _callback = callback  # type: ignore
-                if (
+                if (
+                    _callback is not None
+                    and isinstance(_callback, CustomGuardrail)
+                    and "async_pre_call_hook" in vars(_callback.__class__)
+                ):
+                    from litellm.types.guardrails import GuardrailEventHooks
+
+                    if (
+                        _callback.should_run_guardrail(
+                            data=data, event_type=GuardrailEventHooks.pre_call
+                        )
+                        is not True
+                    ):
+                        continue
+
+                    response = await _callback.async_pre_call_hook(
+                        user_api_key_dict=user_api_key_dict,
+                        cache=self.call_details["user_api_key_cache"],
+                        data=data,
+                        call_type=call_type,
+                    )
+                    if response is not None:
+                        data = await self.process_pre_call_hook_response(
+                            response=response, data=data, call_type=call_type
+                        )
+
+                elif (
                     _callback is not None
                     and isinstance(_callback, CustomLogger)
                     and "async_pre_call_hook" in vars(_callback.__class__)
@@ -394,25 +438,9 @@ class ProxyLogging:
                         call_type=call_type,
                     )
                     if response is not None:
-                        if isinstance(response, Exception):
-                            raise response
-                        elif isinstance(response, dict):
-                            data = response
-                        elif isinstance(response, str):
-                            if (
-                                call_type == "completion"
-                                or call_type == "text_completion"
-                            ):
-                                raise RejectedRequestError(
-                                    message=response,
-                                    model=data.get("model", ""),
-                                    llm_provider="",
-                                    request_data=data,
-                                )
-                            else:
-                                raise HTTPException(
-                                    status_code=400, detail={"error": response}
-                                )
+                        data = await self.process_pre_call_hook_response(
+                            response=response, data=data, call_type=call_type
+                        )
 
             return data
         except Exception as e:
@@ -431,11 +459,30 @@ class ProxyLogging:
         ],
     ):
         """
-        Runs the CustomLogger's async_moderation_hook()
+        Runs the CustomGuardrail's async_moderation_hook()
         """
         for callback in litellm.callbacks:
             try:
-                if isinstance(callback, CustomLogger):
+                if isinstance(callback, CustomGuardrail):
+                    ################################################################
+                    # Check if guardrail should be run for GuardrailEventHooks.during_call hook
+                    ################################################################
+
+                    # V1 implementation - backwards compatibility
+                    if callback.event_hook is None:
+                        if callback.moderation_check == "pre_call":
+                            return
+                    else:
+                        # Main - V2 Guardrails implementation
+                        from litellm.types.guardrails import GuardrailEventHooks
+
+                        if (
+                            callback.should_run_guardrail(
+                                data=data, event_type=GuardrailEventHooks.during_call
+                            )
+                            is not True
+                        ):
+                            continue
                     await callback.async_moderation_hook(
                         data=data,
                         user_api_key_dict=user_api_key_dict,
@@ -737,12 +784,36 @@ class ProxyLogging:
                     )
                 else:
                     _callback = callback  # type: ignore
-                if _callback is not None and isinstance(_callback, CustomLogger):
-                    await _callback.async_post_call_success_hook(
-                        user_api_key_dict=user_api_key_dict,
-                        data=data,
-                        response=response,
-                    )
+
+                if _callback is not None:
+                    ############## Handle Guardrails ########################################
+                    #############################################################################
+                    if isinstance(callback, CustomGuardrail):
+                        # Main - V2 Guardrails implementation
+                        from litellm.types.guardrails import GuardrailEventHooks
+
+                        if (
+                            callback.should_run_guardrail(
+                                data=data, event_type=GuardrailEventHooks.post_call
+                            )
+                            is not True
+                        ):
+                            continue
+
+                        await callback.async_post_call_success_hook(
+                            user_api_key_dict=user_api_key_dict,
+                            data=data,
+                            response=response,
+                        )
+
+                    ############ Handle CustomLogger ###############################
+                    #################################################################
+                    elif isinstance(_callback, CustomLogger):
+                        await _callback.async_post_call_success_hook(
+                            user_api_key_dict=user_api_key_dict,
+                            data=data,
+                            response=response,
+                        )
             except Exception as e:
                 raise e
         return response
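
Usage sketch (not part of the diff): with the proxy started against the proxy_config.yaml above, a request can attach the configured custom guardrails by name. The base_url, API key, and prompt below are illustrative assumptions, not values from this change.

import openai

client = openai.OpenAI(
    api_key="sk-1234",  # assumed proxy master key
    base_url="http://localhost:4000",  # assumed proxy address
)

# "litellm" in the prompt should be masked to "********" by custom-pre-guard
# (async_pre_call_hook) before the request reaches the model.
response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "say hello to litellm"}],
    extra_body={"guardrails": ["custom-pre-guard"]},
)
print(response.choices[0].message.content)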