diff --git a/docs/my-website/docs/proxy/guardrails/quick_start.md b/docs/my-website/docs/proxy/guardrails/quick_start.md index 046b8ac422..29f8a7b551 100644 --- a/docs/my-website/docs/proxy/guardrails/quick_start.md +++ b/docs/my-website/docs/proxy/guardrails/quick_start.md @@ -114,6 +114,88 @@ curl -i http://localhost:4000/v1/chat/completions \ ## Advanced + +### ✨ Pass additional parameters to guardrail + +:::info + +✨ This is an Enterprise only feature [Contact us to get a free trial](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat) + +::: + + + +Use this to pass additional parameters to the guardrail API call. e.g. things like success threshold. **[See `guardrails` spec for more details](#spec-guardrails-parameter)** + + + + + + +Set `guardrails={"aporia-pre-guard": {"extra_body": {"success_threshold": 0.9}}}` to pass additional parameters to the guardrail + +In this example `success_threshold=0.9` is passed to the `aporia-pre-guard` guardrail request body + +```python +import openai +client = openai.OpenAI( + api_key="anything", + base_url="http://0.0.0.0:4000" +) + +response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages = [ + { + "role": "user", + "content": "this is a test request, write a short poem" + } + ], + extra_body={ + "guardrails": [ + "aporia-pre-guard": { + "extra_body": { + "success_threshold": 0.9 + } + } + ] + } + +) + +print(response) +``` + + + + + +```shell +curl --location 'http://0.0.0.0:4000/chat/completions' \ + --header 'Content-Type: application/json' \ + --data '{ + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "user", + "content": "what llm are you" + } + ], + "guardrails": [ + "aporia-pre-guard": { + "extra_body": { + "success_threshold": 0.9 + } + } + ] +}' +``` + + + + + + ### ✨ Control Guardrails per Project (API Key) :::info @@ -252,4 +334,43 @@ Expected response { "guardrails": ["aporia-pre-guard", "aporia-post-guard"] } +``` + +## Spec: `guardrails` Parameter + +The `guardrails` parameter can be passed to any LiteLLM Proxy endpoint (`/chat/completions`, `/completions`, `/embeddings`). + +### Format Options + +1. Simple List Format: +```python +"guardrails": [ + "aporia-pre-guard", + "aporia-post-guard" +] +``` + +2. Advanced Dictionary Format: + +In this format the dictionary key is `guardrail_name` you want to run +```python +"guardrails": { + "aporia-pre-guard": { + "extra_body": { + "success_threshold": 0.9, + "other_param": "value" + } + } +} +``` + +### Type Definition +```python +guardrails: Union[ + List[str], # Simple list of guardrail names + Dict[str, DynamicGuardrailParams] # Advanced configuration +] + +class DynamicGuardrailParams: + extra_body: Dict[str, Any] # Additional parameters for the guardrail ``` \ No newline at end of file diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py index 39f762533d..7706d9ab63 100644 --- a/litellm/integrations/custom_guardrail.py +++ b/litellm/integrations/custom_guardrail.py @@ -1,8 +1,8 @@ -from typing import List, Optional +from typing import Dict, List, Optional, Union from litellm._logging import verbose_logger from litellm.integrations.custom_logger import CustomLogger -from litellm.types.guardrails import GuardrailEventHooks +from litellm.types.guardrails import DynamicGuardrailParams, GuardrailEventHooks class CustomGuardrail(CustomLogger): @@ -26,9 +26,31 @@ class CustomGuardrail(CustomLogger): ) super().__init__(**kwargs) - def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool: + def get_guardrail_from_metadata( + self, data: dict + ) -> Union[List[str], List[Dict[str, DynamicGuardrailParams]]]: + """ + Returns the guardrail(s) to be run from the metadata + """ metadata = data.get("metadata") or {} requested_guardrails = metadata.get("guardrails") or [] + return requested_guardrails + + def _guardrail_is_in_requested_guardrails( + self, + requested_guardrails: Union[List[str], List[Dict[str, DynamicGuardrailParams]]], + ) -> bool: + for _guardrail in requested_guardrails: + if isinstance(_guardrail, dict): + if self.guardrail_name in _guardrail: + return True + elif isinstance(_guardrail, str): + if self.guardrail_name == _guardrail: + return True + return False + + def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool: + requested_guardrails = self.get_guardrail_from_metadata(data) verbose_logger.debug( "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s", @@ -40,7 +62,7 @@ class CustomGuardrail(CustomLogger): if ( self.event_hook - and self.guardrail_name not in requested_guardrails + and not self._guardrail_is_in_requested_guardrails(requested_guardrails) and event_type.value != "logging_only" ): return False @@ -49,3 +71,51 @@ class CustomGuardrail(CustomLogger): return False return True + + def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict: + """ + Returns `extra_body` to be added to the request body for the Guardrail API call + + Use this to pass dynamic params to the guardrail API call - eg. success_threshold, failure_threshold, etc. + + ``` + [{"lakera_guard": {"extra_body": {"foo": "bar"}}}] + ``` + + Will return: for guardrail=`lakera-guard`: + { + "foo": "bar" + } + + Args: + request_data: The original `request_data` passed to LiteLLM Proxy + """ + requested_guardrails = self.get_guardrail_from_metadata(request_data) + + # Look for the guardrail configuration matching self.guardrail_name + for guardrail in requested_guardrails: + if isinstance(guardrail, dict) and self.guardrail_name in guardrail: + # Get the configuration for this guardrail + guardrail_config: DynamicGuardrailParams = DynamicGuardrailParams( + **guardrail[self.guardrail_name] + ) + if self._validate_premium_user() is not True: + return {} + + # Return the extra_body if it exists, otherwise empty dict + return guardrail_config.get("extra_body", {}) + + return {} + + def _validate_premium_user(self) -> bool: + """ + Returns True if the user is a premium user + """ + from litellm.proxy.proxy_server import CommonProxyErrors, premium_user + + if premium_user is not True: + verbose_logger.warning( + f"Trying to use premium guardrail without premium user {CommonProxyErrors.not_premium_user.value}" + ) + return False + return True diff --git a/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py b/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py index 6ead4f0d02..9e3fdde3be 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py +++ b/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py @@ -86,12 +86,19 @@ class AporiaGuardrail(CustomGuardrail): return data async def make_aporia_api_request( - self, new_messages: List[dict], response_string: Optional[str] = None + self, + request_data: dict, + new_messages: List[dict], + response_string: Optional[str] = None, ): data = await self.prepare_aporia_request( new_messages=new_messages, response_string=response_string ) + data.update( + self.get_guardrail_dynamic_request_body_params(request_data=request_data) + ) + _json_data = json.dumps(data) """ @@ -155,7 +162,9 @@ class AporiaGuardrail(CustomGuardrail): response_str: Optional[str] = convert_litellm_response_object_to_str(response) if response_str is not None: await self.make_aporia_api_request( - response_string=response_str, new_messages=data.get("messages", []) + request_data=data, + response_string=response_str, + new_messages=data.get("messages", []), ) add_guardrail_to_applied_guardrails_header( @@ -199,7 +208,10 @@ class AporiaGuardrail(CustomGuardrail): new_messages = self.transform_messages(messages=data["messages"]) if new_messages is not None: - await self.make_aporia_api_request(new_messages=new_messages) + await self.make_aporia_api_request( + request_data=data, + new_messages=new_messages, + ) add_guardrail_to_applied_guardrails_header( request_data=data, guardrail_name=self.guardrail_name ) diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py index 4668b17284..1b6880581c 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py +++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py @@ -149,7 +149,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): def _prepare_request( self, credentials, - data: BedrockRequest, + data: dict, optional_params: dict, aws_region_name: str, extra_headers: Optional[dict] = None, @@ -186,18 +186,23 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): ): credentials, aws_region_name = self._load_credentials() - request_data: BedrockRequest = self.convert_to_bedrock_format( - messages=kwargs.get("messages"), response=response + bedrock_request_data: dict = dict( + self.convert_to_bedrock_format( + messages=kwargs.get("messages"), response=response + ) + ) + bedrock_request_data.update( + self.get_guardrail_dynamic_request_body_params(request_data=kwargs) ) prepared_request = self._prepare_request( credentials=credentials, - data=request_data, + data=bedrock_request_data, optional_params=self.optional_params, aws_region_name=aws_region_name, ) verbose_proxy_logger.debug( "Bedrock AI request body: %s, url %s, headers: %s", - request_data, + bedrock_request_data, prepared_request.url, prepared_request.headers, ) diff --git a/litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py b/litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py index 092fbe8ea5..1500e8c25a 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py +++ b/litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py @@ -48,10 +48,13 @@ class GuardrailsAI(CustomGuardrail): supported_event_hooks = [GuardrailEventHooks.post_call] super().__init__(supported_event_hooks=supported_event_hooks, **kwargs) - async def make_guardrails_ai_api_request(self, llm_output: str): + async def make_guardrails_ai_api_request(self, llm_output: str, request_data: dict): from httpx import URL - data = {"llmOutput": llm_output} + data = { + "llmOutput": llm_output, + **self.get_guardrail_dynamic_request_body_params(request_data=request_data), + } _json_data = json.dumps(data) response = await litellm.module_level_aclient.post( url=str( @@ -96,7 +99,9 @@ class GuardrailsAI(CustomGuardrail): response_str: str = get_content_from_model_response(response) if response_str is not None and len(response_str) > 0: - await self.make_guardrails_ai_api_request(llm_output=response_str) + await self.make_guardrails_ai_api_request( + llm_output=response_str, request_data=data + ) add_guardrail_to_applied_guardrails_header( request_data=data, guardrail_name=self.guardrail_name diff --git a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py index 14e0a7eee6..6f05d366fa 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py +++ b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py @@ -216,14 +216,27 @@ class lakeraAI_Moderation(CustomGuardrail): "Skipping lakera prompt injection, no roles with messages found" ) return - data = {"input": lakera_input} - _json_data = json.dumps(data) + _data = {"input": lakera_input} + _json_data = json.dumps( + _data, + **self.get_guardrail_dynamic_request_body_params(request_data=data), + ) elif "input" in data and isinstance(data["input"], str): text = data["input"] - _json_data = json.dumps({"input": text}) + _json_data = json.dumps( + { + "input": text, + **self.get_guardrail_dynamic_request_body_params(request_data=data), + } + ) elif "input" in data and isinstance(data["input"], list): text = "\n".join(data["input"]) - _json_data = json.dumps({"input": text}) + _json_data = json.dumps( + { + "input": text, + **self.get_guardrail_dynamic_request_body_params(request_data=data), + } + ) verbose_proxy_logger.debug("Lakera AI Request Args %s", _json_data) diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py index a585d43e6d..fb58170cf5 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py +++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py @@ -132,6 +132,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): text: str, output_parse_pii: bool, presidio_config: Optional[PresidioPerRequestConfig], + request_data: dict, ) -> str: """ [TODO] make this more performant for high-throughput scenario @@ -150,7 +151,11 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): if self.ad_hoc_recognizers is not None: analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers # End of constructing Request 1 - + analyze_payload.update( + self.get_guardrail_dynamic_request_body_params( + request_data=request_data + ) + ) redacted_text = None verbose_proxy_logger.debug( "Making request to: %s with payload: %s", @@ -235,6 +240,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): text=m["content"], output_parse_pii=self.output_parse_pii, presidio_config=presidio_config, + request_data=data, ) ) responses = await asyncio.gather(*tasks) @@ -311,6 +317,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): text=text_str, output_parse_pii=False, presidio_config=presidio_config, + request_data=kwargs, ) ) # need to pass separately b/c presidio has context window limits responses = await asyncio.gather(*tasks) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index c1f6665b51..fc44c3a24d 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -12,6 +12,14 @@ model_list: model: bedrock/* +guardrails: + - guardrail_name: "bedrock-pre-guard" + litellm_params: + guardrail: bedrock # supported values: "aporia", "bedrock", "lakera" + mode: "during_call" + guardrailIdentifier: ff6ujrregl1q + guardrailVersion: "DRAFT" + # for /files endpoints # For /fine_tuning/jobs endpoints finetune_settings: diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py index 29e7321ab9..cdf2cdb9d8 100644 --- a/litellm/types/guardrails.py +++ b/litellm/types/guardrails.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Dict, List, Literal, Optional, TypedDict +from typing import Any, Dict, List, Literal, Optional, TypedDict from pydantic import BaseModel, ConfigDict from typing_extensions import Required, TypedDict @@ -132,3 +132,7 @@ class BedrockContentItem(TypedDict, total=False): class BedrockRequest(TypedDict, total=False): source: Literal["INPUT", "OUTPUT"] content: List[BedrockContentItem] + + +class DynamicGuardrailParams(TypedDict): + extra_body: Dict[str, Any] diff --git a/tests/logging_callback_tests/test_custom_guardrail.py b/tests/logging_callback_tests/test_custom_guardrail.py new file mode 100644 index 0000000000..f8995e5624 --- /dev/null +++ b/tests/logging_callback_tests/test_custom_guardrail.py @@ -0,0 +1,145 @@ +import io +import os +import sys + + +sys.path.insert(0, os.path.abspath("../..")) + +import asyncio +import gzip +import json +import logging +import time +from unittest.mock import AsyncMock, patch + +import pytest + +import litellm +from litellm import completion +from litellm._logging import verbose_logger +from litellm.integrations.custom_guardrail import CustomGuardrail + + +from typing import Any, Dict, List, Literal, Optional, Union + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_guardrail import CustomGuardrail +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata +from litellm.types.guardrails import GuardrailEventHooks + + +def test_get_guardrail_from_metadata(): + guardrail = CustomGuardrail(guardrail_name="test-guardrail") + + # Test with empty metadata + assert guardrail.get_guardrail_from_metadata({}) == [] + + # Test with guardrails in metadata + data = {"metadata": {"guardrails": ["guardrail1", "guardrail2"]}} + assert guardrail.get_guardrail_from_metadata(data) == ["guardrail1", "guardrail2"] + + # Test with dict guardrails + data = { + "metadata": { + "guardrails": [{"test-guardrail": {"extra_body": {"key": "value"}}}] + } + } + assert guardrail.get_guardrail_from_metadata(data) == [ + {"test-guardrail": {"extra_body": {"key": "value"}}} + ] + + +def test_guardrail_is_in_requested_guardrails(): + guardrail = CustomGuardrail(guardrail_name="test-guardrail") + + # Test with string list + assert ( + guardrail._guardrail_is_in_requested_guardrails(["test-guardrail", "other"]) + == True + ) + assert guardrail._guardrail_is_in_requested_guardrails(["other"]) == False + + # Test with dict list + assert ( + guardrail._guardrail_is_in_requested_guardrails( + [{"test-guardrail": {"extra_body": {"extra_key": "extra_value"}}}] + ) + == True + ) + assert ( + guardrail._guardrail_is_in_requested_guardrails( + [ + { + "other-guardrail": {"extra_body": {"extra_key": "extra_value"}}, + "test-guardrail": {"extra_body": {"extra_key": "extra_value"}}, + } + ] + ) + == True + ) + assert ( + guardrail._guardrail_is_in_requested_guardrails( + [{"other-guardrail": {"extra_body": {"extra_key": "extra_value"}}}] + ) + == False + ) + + +def test_should_run_guardrail(): + guardrail = CustomGuardrail( + guardrail_name="test-guardrail", event_hook=GuardrailEventHooks.pre_call + ) + + # Test matching event hook and guardrail + assert ( + guardrail.should_run_guardrail( + {"metadata": {"guardrails": ["test-guardrail"]}}, + GuardrailEventHooks.pre_call, + ) + == True + ) + + # Test non-matching event hook + assert ( + guardrail.should_run_guardrail( + {"metadata": {"guardrails": ["test-guardrail"]}}, + GuardrailEventHooks.during_call, + ) + == False + ) + + # Test guardrail not in requested list + assert ( + guardrail.should_run_guardrail( + {"metadata": {"guardrails": ["other-guardrail"]}}, + GuardrailEventHooks.pre_call, + ) + == False + ) + + +def test_get_guardrail_dynamic_request_body_params(): + guardrail = CustomGuardrail(guardrail_name="test-guardrail") + + # Test with no extra_body + data = {"metadata": {"guardrails": [{"test-guardrail": {}}]}} + assert guardrail.get_guardrail_dynamic_request_body_params(data) == {} + + # Test with extra_body + data = { + "metadata": { + "guardrails": [{"test-guardrail": {"extra_body": {"key": "value"}}}] + } + } + assert guardrail.get_guardrail_dynamic_request_body_params(data) == {"key": "value"} + + # Test with non-matching guardrail + data = { + "metadata": { + "guardrails": [{"other-guardrail": {"extra_body": {"key": "value"}}}] + } + } + assert guardrail.get_guardrail_dynamic_request_body_params(data) == {}