From 613bd1babdd7000dc844bf8634fa212ac35bbd7a Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 19 Aug 2024 11:56:20 -0700
Subject: [PATCH] feat - return applied guardrails in response headers

---
 enterprise/enterprise_hooks/aporio_ai.py     | 16 +++++++++++++++-
 litellm/proxy/common_utils/callback_utils.py | 18 ++++++++++++++++++
 litellm/proxy/proxy_server.py                |  5 +++++
 3 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/enterprise/enterprise_hooks/aporio_ai.py b/enterprise/enterprise_hooks/aporio_ai.py
index 5d1b081fe..b0b1b50c9 100644
--- a/enterprise/enterprise_hooks/aporio_ai.py
+++ b/enterprise/enterprise_hooks/aporio_ai.py
@@ -137,13 +137,21 @@ class _ENTERPRISE_Aporio(CustomLogger):
         user_api_key_dict: UserAPIKeyAuth,
         response,
     ):
+        from litellm.proxy.common_utils.callback_utils import (
+            add_guardrail_to_applied_guardrails_header,
+        )
+
         """
         Use this for the post call moderation with Guardrails
         """
         response_str: Optional[str] = convert_litellm_response_object_to_str(response)
         if response_str is not None:
             await self.make_aporia_api_request(
-                response_string=response_str, new_messages=[]
+                response_string=response_str, new_messages=data.get("messages", [])
+            )
+
+            add_guardrail_to_applied_guardrails_header(
+                request_data=data, guardrail_name=f"post_call_{GUARDRAIL_NAME}"
             )

         pass
@@ -154,6 +162,9 @@ class _ENTERPRISE_Aporio(CustomLogger):
         user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
+        from litellm.proxy.common_utils.callback_utils import (
+            add_guardrail_to_applied_guardrails_header,
+        )

         if (
             await should_proceed_based_on_metadata(
@@ -170,6 +181,9 @@ class _ENTERPRISE_Aporio(CustomLogger):

             if new_messages is not None:
                 await self.make_aporia_api_request(new_messages=new_messages)
+                add_guardrail_to_applied_guardrails_header(
+                    request_data=data, guardrail_name=f"during_call_{GUARDRAIL_NAME}"
+                )
             else:
                 verbose_proxy_logger.warning(
                     "Aporia AI: not running guardrail. No messages in data"
diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py
index 6b000b148..ebf3194e2 100644
--- a/litellm/proxy/common_utils/callback_utils.py
+++ b/litellm/proxy/common_utils/callback_utils.py
@@ -295,3 +295,21 @@ def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str,
         headers[f"x-litellm-key-remaining-tokens-{model_group}"] = remaining_tokens

     return headers
+
+
+def get_applied_guardrails_header(request_data: Dict) -> Optional[Dict]:
+    _metadata = request_data.get("metadata", None) or {}
+    if "applied_guardrails" in _metadata:
+        return {
+            "x-litellm-applied-guardrails": ",".join(_metadata["applied_guardrails"]),
+        }
+
+    return None
+
+
+def add_guardrail_to_applied_guardrails_header(request_data: Dict, guardrail_name: str):
+    _metadata = request_data.get("metadata", None) or {}
+    if "applied_guardrails" in _metadata:
+        _metadata["applied_guardrails"].append(guardrail_name)
+    else:
+        _metadata["applied_guardrails"] = [guardrail_name]
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index adecb2eb6..a759dd973 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -149,6 +149,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
     show_missing_vars_in_env,
 )
 from litellm.proxy.common_utils.callback_utils import (
+    get_applied_guardrails_header,
     get_remaining_tokens_and_requests_from_request_data,
     initialize_callbacks_on_proxy,
 )
@@ -536,6 +537,10 @@ def get_custom_headers(
     )
     headers.update(remaining_tokens_header)

+    applied_guardrails = get_applied_guardrails_header(request_data)
+    if applied_guardrails:
+        headers.update(applied_guardrails)
+
     try:
         return {
             key: value for key, value in headers.items() if value not in exclude_values