feat - return applied guardrails in response headers

This commit is contained in:
Ishaan Jaff 2024-08-19 11:56:20 -07:00
parent 4685b9909a
commit 613bd1babd
3 changed files with 38 additions and 1 deletions

View file

@ -137,13 +137,21 @@ class _ENTERPRISE_Aporio(CustomLogger):
user_api_key_dict: UserAPIKeyAuth,
response,
):
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,
)
"""
Use this for the post call moderation with Guardrails
"""
response_str: Optional[str] = convert_litellm_response_object_to_str(response)
if response_str is not None:
await self.make_aporia_api_request(
response_string=response_str, new_messages=[]
response_string=response_str, new_messages=data.get("messages", [])
)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=f"post_call_{GUARDRAIL_NAME}"
)
pass
@ -154,6 +162,9 @@ class _ENTERPRISE_Aporio(CustomLogger):
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,
)
if (
await should_proceed_based_on_metadata(
@ -170,6 +181,9 @@ class _ENTERPRISE_Aporio(CustomLogger):
if new_messages is not None:
await self.make_aporia_api_request(new_messages=new_messages)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=f"during_call_{GUARDRAIL_NAME}"
)
else:
verbose_proxy_logger.warning(
"Aporia AI: not running guardrail. No messages in data"

View file

@ -295,3 +295,21 @@ def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str,
headers[f"x-litellm-key-remaining-tokens-{model_group}"] = remaining_tokens
return headers
def get_applied_guardrails_header(request_data: Dict) -> Optional[Dict]:
_metadata = request_data.get("metadata", None) or {}
if "applied_guardrails" in _metadata:
return {
"x-litellm-applied-guardrails": ",".join(_metadata["applied_guardrails"]),
}
return None
def add_guardrail_to_applied_guardrails_header(request_data: Dict, guardrail_name: str):
_metadata = request_data.get("metadata", None) or {}
if "applied_guardrails" in _metadata:
_metadata["applied_guardrails"].append(guardrail_name)
else:
_metadata["applied_guardrails"] = [guardrail_name]

View file

@ -149,6 +149,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
show_missing_vars_in_env,
)
from litellm.proxy.common_utils.callback_utils import (
get_applied_guardrails_header,
get_remaining_tokens_and_requests_from_request_data,
initialize_callbacks_on_proxy,
)
@ -536,6 +537,10 @@ def get_custom_headers(
)
headers.update(remaining_tokens_header)
applied_guardrails = get_applied_guardrails_header(request_data)
if applied_guardrails:
headers.update(applied_guardrails)
try:
return {
key: value for key, value in headers.items() if value not in exclude_values