From 7d30188f84ab482677d1ab2bc251393f7fa251e8 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 23 Aug 2024 09:52:52 -0700
Subject: [PATCH] custom_callbacks

---
 .circleci/config.yml                          |   1 +
 .../example_config_yaml/custom_guardrail.py   | 107 +++++++++++++++++++
 .../example_config_yaml/otel_test_config.yaml |  14 ++-
 3 files changed, 121 insertions(+), 1 deletion(-)
 create mode 100644 litellm/proxy/example_config_yaml/custom_guardrail.py

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 27ab837c9..b562fbdd5 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -328,6 +328,7 @@ jobs:
           -e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
           --name my-app \
           -v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \
+          -v $(pwd)/litellm/proxy/example_config_yaml/custom_guardrail.py:/app/custom_guardrail.py \
           my-app:latest \
           --config /app/config.yaml \
           --port 4000 \
diff --git a/litellm/proxy/example_config_yaml/custom_guardrail.py b/litellm/proxy/example_config_yaml/custom_guardrail.py
new file mode 100644
index 000000000..2ed989cfd
--- /dev/null
+++ b/litellm/proxy/example_config_yaml/custom_guardrail.py
@@ -0,0 +1,107 @@
+from typing import Any, Dict, List, Literal, Optional, Union
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache
+from litellm.integrations.custom_guardrail import CustomGuardrail
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
+from litellm.types.guardrails import GuardrailEventHooks
+
+
+class myCustomGuardrail(CustomGuardrail):
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        # store kwargs as optional_params
+        self.optional_params = kwargs
+
+        super().__init__(**kwargs)
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+        ],
+    ) -> Optional[Union[Exception, str, dict]]:
+        """
+        Runs before the LLM API call.
+        Runs only on input.
+        Use this if you want to MODIFY the input.
+        """
+
+        # In this guardrail, if a user inputs `litellm`, we mask it before sending it to the LLM
+        _messages = data.get("messages")
+        if _messages:
+            for message in _messages:
+                _content = message.get("content")
+                if isinstance(_content, str):
+                    if "litellm" in _content.lower():
+                        _content = _content.replace("litellm", "********")
+                        message["content"] = _content
+
+        verbose_proxy_logger.debug(
+            "async_pre_call_hook: Message after masking %s", _messages
+        )
+
+        return data
+
+    async def async_moderation_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: Literal["completion", "embeddings", "image_generation"],
+    ):
+        """
+        Runs in parallel to the LLM API call.
+        Runs only on input.
+
+        This can NOT modify the input; it is only used to accept or reject a call before the response is returned.
+        """
+
+        # This performs the same check as async_pre_call_hook, but runs in parallel to the LLM API call.
+        # In this guardrail, if a user inputs `litellm`, we reject the request instead of masking it.
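+        # Raising an exception from this hook rejects the request; the input can
+        # no longer be modified here since the LLM call is already in flight.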
+        _messages = data.get("messages")
+        if _messages:
+            for message in _messages:
+                _content = message.get("content")
+                if isinstance(_content, str):
+                    if "litellm" in _content.lower():
+                        raise ValueError("Guardrail failed words - `litellm` detected")
+
+    async def async_post_call_success_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+    ):
+        """
+        Runs on the response from the LLM API call.
+
+        It can be used to reject a response.
+
+        If the response contains the word "coffee", we raise an exception.
+        """
+        verbose_proxy_logger.debug("async_post_call_success_hook response: %s", response)
+        if isinstance(response, litellm.ModelResponse):
+            for choice in response.choices:
+                if isinstance(choice, litellm.Choices):
+                    verbose_proxy_logger.debug("async_post_call_success_hook choice: %s", choice)
+                    if (
+                        choice.message.content
+                        and isinstance(choice.message.content, str)
+                        and "coffee" in choice.message.content
+                    ):
+                        raise ValueError("Guardrail failed Coffee Detected")
diff --git a/litellm/proxy/example_config_yaml/otel_test_config.yaml b/litellm/proxy/example_config_yaml/otel_test_config.yaml
index 8ca4f37fd..a041a2bd0 100644
--- a/litellm/proxy/example_config_yaml/otel_test_config.yaml
+++ b/litellm/proxy/example_config_yaml/otel_test_config.yaml
@@ -27,4 +27,16 @@ guardrails:
       guardrail: bedrock  # supported values: "aporia", "bedrock", "lakera"
       mode: "pre_call"
       guardrailIdentifier: ff6ujrregl1q
-      guardrailVersion: "DRAFT"
\ No newline at end of file
+      guardrailVersion: "DRAFT"
+  - guardrail_name: "custom-pre-guard"
+    litellm_params:
+      guardrail: custom_guardrail.myCustomGuardrail
+      mode: "pre_call"
+  - guardrail_name: "custom-during-guard"
+    litellm_params:
+      guardrail: custom_guardrail.myCustomGuardrail
+      mode: "during_call"
+  - guardrail_name: "custom-post-guard"
+    litellm_params:
+      guardrail: custom_guardrail.myCustomGuardrail
+      mode: "post_call"
\ No newline at end of file
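
Usage sketch for the guardrails added above. The CircleCI job starts the proxy
on port 4000; beyond that, this is a minimal sketch that assumes the proxy is
reachable at http://localhost:4000, that "sk-1234" is a valid proxy key, and
that a model named "gpt-3.5-turbo" exists in the config's model_list (none of
these last values are confirmed by the diff itself):

import openai

# Point the OpenAI SDK at the LiteLLM proxy (api_key and base_url are assumptions)
client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

# "guardrails" in extra_body selects which configured guardrails run for this
# request. "custom-pre-guard" masks the word "litellm" before the LLM call;
# "custom-during-guard" rejects the request if "litellm" appears in the input;
# "custom-post-guard" rejects any response that contains the word "coffee".
response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # assumed entry in the proxy's model_list
    messages=[{"role": "user", "content": "say the word litellm"}],
    extra_body={"guardrails": ["custom-pre-guard"]},
)
print(response.choices[0].message.content)  # expect "litellm" to be masked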