Log applied guardrails on LLM API call (#8452)

* fix(litellm_logging.py): support saving applied guardrails in logging object

allows the list of applied guardrails to be logged for the proxy admin's knowledge

* feat(spend_tracking_utils.py): log applied guardrails to spend logs

makes it easy for the admin to know which guardrails were applied to a request

* ci(config.yml): uninstall posthog from ci/cd

* test: fix tests

* test: update test
Krish Dholakia 2025-02-10 22:57:30 -08:00 committed by GitHub
parent 8e32713637
commit ce3ead6f91
8 changed files with 152 additions and 10 deletions
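
Taken together, the changes thread one new piece of state, the list of guardrail names that actually ran on a request, from function_setup through the standard logging payload and into the spend logs. A rough sketch of the resulting shape, with plain dicts standing in for the real litellm objects and an illustrative guardrail name:

# Illustrative end state, not actual litellm objects.
standard_logging_metadata = {
    "user_api_key_hash": None,
    "prompt_management_metadata": None,
    "applied_guardrails": ["aporia-pre-guard"],  # new in this change
}
spend_logs_metadata = {
    "requester_ip_address": None,
    "applied_guardrails": ["aporia-pre-guard"],  # mirrored into spend logs
}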

View file

@@ -199,6 +199,7 @@ class Logging(LiteLLMLoggingBaseClass):
dynamic_async_failure_callbacks: Optional[
List[Union[str, Callable, CustomLogger]]
] = None,
applied_guardrails: Optional[List[str]] = None,
kwargs: Optional[Dict] = None,
):
_input: Optional[str] = messages # save original value of messages
@@ -271,6 +272,7 @@ class Logging(LiteLLMLoggingBaseClass):
"litellm_call_id": litellm_call_id,
"input": _input,
"litellm_params": litellm_params,
"applied_guardrails": applied_guardrails,
}
def process_dynamic_callbacks(self):
@@ -2852,6 +2854,7 @@ class StandardLoggingPayloadSetup:
metadata: Optional[Dict[str, Any]],
litellm_params: Optional[dict] = None,
prompt_integration: Optional[str] = None,
applied_guardrails: Optional[List[str]] = None,
) -> StandardLoggingMetadata:
"""
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@@ -2866,6 +2869,7 @@ class StandardLoggingPayloadSetup:
- If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
- If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
"""
prompt_management_metadata: Optional[
StandardLoggingPromptManagementMetadata
] = None
@@ -2895,6 +2899,7 @@ class StandardLoggingPayloadSetup:
requester_metadata=None,
user_api_key_end_user_id=None,
prompt_management_metadata=prompt_management_metadata,
applied_guardrails=applied_guardrails,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
@@ -3193,6 +3198,7 @@ def get_standard_logging_object_payload(
metadata=metadata,
litellm_params=litellm_params,
prompt_integration=kwargs.get("prompt_integration", None),
applied_guardrails=kwargs.get("applied_guardrails", None),
)
_request_body = proxy_server_request.get("body", {})
@@ -3328,6 +3334,7 @@ def get_standard_logging_metadata(
requester_metadata=None,
user_api_key_end_user_id=None,
prompt_management_metadata=None,
applied_guardrails=None,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
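
The litellm_logging.py changes are pure plumbing: Logging.__init__ accepts the new applied_guardrails argument and stores it with the other call details, and get_standard_logging_object_payload later lifts it out of kwargs into StandardLoggingMetadata. A condensed sketch of that hand-off, using stand-in functions rather than the real classes:

from typing import Any, Dict, List, Optional

def store_applied_guardrails(applied_guardrails: Optional[List[str]]) -> Dict[str, Any]:
    # Mirrors Logging.__init__: keep the list alongside the other call
    # details so every downstream logger can read it.
    return {"applied_guardrails": applied_guardrails}

def build_metadata(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Mirrors get_standard_logging_object_payload: default to None when
    # the caller never set the field.
    return {"applied_guardrails": kwargs.get("applied_guardrails", None)}

print(build_metadata(store_applied_guardrails(["pii_mask"])))
# {'applied_guardrails': ['pii_mask']}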

View file

@@ -1794,6 +1794,7 @@ class SpendLogsMetadata(TypedDict):
dict
] # special param to log k,v pairs to spendlogs for a call
requester_ip_address: Optional[str]
applied_guardrails: Optional[List[str]]
class SpendLogsPayload(TypedDict):
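
Because SpendLogsMetadata is a TypedDict, every construction site must supply the new key explicitly, which is why a None default is added to _get_spend_logs_metadata in the next file. A trimmed sketch of the shape, keeping only the fields relevant here (the real TypedDict carries more keys):

from typing import List, Optional
from typing_extensions import TypedDict

class SpendLogsMetadataSketch(TypedDict):
    # Trimmed illustration of the real SpendLogsMetadata.
    requester_ip_address: Optional[str]
    applied_guardrails: Optional[List[str]]

entry: SpendLogsMetadataSketch = {
    "requester_ip_address": None,
    "applied_guardrails": ["test_guardrail"],
}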

View file

@@ -3,7 +3,7 @@ import secrets
from datetime import datetime
from datetime import datetime as dt
from datetime import timezone
from typing import Optional, cast
from typing import List, Optional, cast
from pydantic import BaseModel
@@ -32,7 +32,9 @@ def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
return False
def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
def _get_spend_logs_metadata(
metadata: Optional[dict], applied_guardrails: Optional[List[str]] = None
) -> SpendLogsMetadata:
if metadata is None:
return SpendLogsMetadata(
user_api_key=None,
@@ -44,8 +46,9 @@ def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
spend_logs_metadata=None,
requester_ip_address=None,
additional_usage_values=None,
applied_guardrails=None,
)
verbose_proxy_logger.debug(
verbose_proxy_logger.info(
"getting payload for SpendLogs, available keys in metadata: "
+ str(list(metadata.keys()))
)
@@ -58,6 +61,8 @@ def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
if key in metadata
}
)
clean_metadata["applied_guardrails"] = applied_guardrails
return clean_metadata
@@ -130,7 +135,14 @@ def get_logging_payload( # noqa: PLR0915
_model_group = metadata.get("model_group", "")
# clean up litellm metadata
clean_metadata = _get_spend_logs_metadata(metadata)
clean_metadata = _get_spend_logs_metadata(
metadata,
applied_guardrails=(
standard_logging_payload["metadata"].get("applied_guardrails", None)
if standard_logging_payload is not None
else None
),
)
special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
additional_usage_values = {}
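
Note the indirection above: the guardrail list is read off standard_logging_payload["metadata"] rather than the raw request metadata, so the spend logs stay consistent with what the standard logging pipeline recorded. A sketch of that call under the same assumption, with a stand-in helper rather than the real import:

from typing import Any, Dict, List, Optional

def get_spend_logs_metadata_sketch(
    metadata: Optional[dict], applied_guardrails: Optional[List[str]] = None
) -> Dict[str, Any]:
    # Stand-in for _get_spend_logs_metadata: copy the allowed keys, then
    # attach the guardrail list so it is always present in the result.
    clean: Dict[str, Any] = dict(metadata or {})
    clean["applied_guardrails"] = applied_guardrails
    return clean

standard_logging_payload = {"metadata": {"applied_guardrails": ["pii_mask"]}}
clean_metadata = get_spend_logs_metadata_sketch(
    {"user_api_key": "hashed-key"},  # illustrative metadata
    applied_guardrails=(
        standard_logging_payload["metadata"].get("applied_guardrails", None)
        if standard_logging_payload is not None
        else None
    ),
)
print(clean_metadata["applied_guardrails"])  # ['pii_mask']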

View file

@@ -1525,6 +1525,7 @@ class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
requester_ip_address: Optional[str]
requester_metadata: Optional[dict]
prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata]
applied_guardrails: Optional[List[str]]
class StandardLoggingAdditionalHeaders(TypedDict, total=False):
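
Since the field lives on StandardLoggingMetadata, any integration that receives the standard logging payload can consume it. A sketch of a custom success callback reading it; the hook name follows litellm's CustomLogger pattern, but treat the exact access path as an assumption:

from litellm.integrations.custom_logger import CustomLogger

class GuardrailAuditLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # The standard logging payload is passed through in kwargs (assumed key).
        payload = kwargs.get("standard_logging_object") or {}
        metadata = payload.get("metadata") or {}
        # New in this change: the guardrails that actually ran on the request.
        print("applied guardrails:", metadata.get("applied_guardrails"))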

View file

@@ -60,6 +60,7 @@ import litellm.litellm_core_utils.json_validation_rule
from litellm.caching._internal_lru_cache import lru_cache_wrapper
from litellm.caching.caching import DualCache
from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import (
map_finish_reason,
@@ -418,6 +419,35 @@ def _custom_logger_class_exists_in_failure_callbacks(
)
def get_request_guardrails(kwargs: Dict[str, Any]) -> List[str]:
"""
Get the request guardrails from the kwargs
"""
metadata = kwargs.get("metadata") or {}
requester_metadata = metadata.get("requester_metadata") or {}
applied_guardrails = requester_metadata.get("guardrails") or []
return applied_guardrails
def get_applied_guardrails(kwargs: Dict[str, Any]) -> List[str]:
"""
- Add 'default_on' guardrails to the list
- Add request guardrails to the list
"""
request_guardrails = get_request_guardrails(kwargs)
applied_guardrails = []
for callback in litellm.callbacks:
if callback is not None and isinstance(callback, CustomGuardrail):
if callback.guardrail_name is not None:
if callback.default_on is True:
applied_guardrails.append(callback.guardrail_name)
elif callback.guardrail_name in request_guardrails:
applied_guardrails.append(callback.guardrail_name)
return applied_guardrails
def function_setup( # noqa: PLR0915
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -436,6 +466,9 @@ def function_setup( # noqa: PLR0915
## CUSTOM LLM SETUP ##
custom_llm_setup()
## GET APPLIED GUARDRAILS
applied_guardrails = get_applied_guardrails(kwargs)
## LOGGING SETUP
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
@@ -677,6 +710,7 @@ def function_setup( # noqa: PLR0915
dynamic_async_success_callbacks=dynamic_async_success_callbacks,
dynamic_async_failure_callbacks=dynamic_async_failure_callbacks,
kwargs=kwargs,
applied_guardrails=applied_guardrails,
)
## check if metadata is passed in
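
The selection logic boils down to: a guardrail is applied if it is default_on, or if the request named it under metadata.requester_metadata.guardrails. A quick sketch of both paths, mirroring the parametrized tests in the next file (guardrail names are illustrative):

import litellm
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.utils import get_applied_guardrails

litellm.callbacks = [
    CustomGuardrail(guardrail_name="always_on", default_on=True),
    CustomGuardrail(guardrail_name="opt_in", default_on=False),
]

# default_on guardrails apply even when the request asks for nothing:
print(get_applied_guardrails({}))  # ['always_on']

# opt-in guardrails apply only when named by the request:
kwargs = {"metadata": {"requester_metadata": {"guardrails": ["opt_in"]}}}
print(get_applied_guardrails(kwargs))  # ['always_on', 'opt_in']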

View file

@@ -864,17 +864,24 @@ def test_convert_model_response_object():
== '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
)
@pytest.mark.parametrize(
"content, expected_reasoning, expected_content",
[
(None, None, None),
("<think>I am thinking here</think>The sky is a canvas of blue", "I am thinking here", "The sky is a canvas of blue"),
(
"<think>I am thinking here</think>The sky is a canvas of blue",
"I am thinking here",
"The sky is a canvas of blue",
),
("I am a regular response", None, "I am a regular response"),
]
],
)
def test_parse_content_for_reasoning(content, expected_reasoning, expected_content):
assert(litellm.utils._parse_content_for_reasoning(content) == (expected_reasoning, expected_content))
assert litellm.utils._parse_content_for_reasoning(content) == (
expected_reasoning,
expected_content,
)
@pytest.mark.parametrize(
@@ -1874,3 +1881,82 @@ def test_validate_user_messages_invalid_content_type():
assert "Invalid message" in str(e)
print(e)
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.utils import get_applied_guardrails
from unittest.mock import Mock
@pytest.mark.parametrize(
"test_case",
[
{
"name": "default_on_guardrail",
"callbacks": [
CustomGuardrail(guardrail_name="test_guardrail", default_on=True)
],
"kwargs": {"metadata": {"requester_metadata": {"guardrails": []}}},
"expected": ["test_guardrail"],
},
{
"name": "request_specific_guardrail",
"callbacks": [
CustomGuardrail(guardrail_name="test_guardrail", default_on=False)
],
"kwargs": {
"metadata": {"requester_metadata": {"guardrails": ["test_guardrail"]}}
},
"expected": ["test_guardrail"],
},
{
"name": "multiple_guardrails",
"callbacks": [
CustomGuardrail(guardrail_name="default_guardrail", default_on=True),
CustomGuardrail(guardrail_name="request_guardrail", default_on=False),
],
"kwargs": {
"metadata": {
"requester_metadata": {"guardrails": ["request_guardrail"]}
}
},
"expected": ["default_guardrail", "request_guardrail"],
},
{
"name": "empty_metadata",
"callbacks": [
CustomGuardrail(guardrail_name="test_guardrail", default_on=False)
],
"kwargs": {},
"expected": [],
},
{
"name": "none_callback",
"callbacks": [
None,
CustomGuardrail(guardrail_name="test_guardrail", default_on=True),
],
"kwargs": {},
"expected": ["test_guardrail"],
},
{
"name": "non_guardrail_callback",
"callbacks": [
Mock(),
CustomGuardrail(guardrail_name="test_guardrail", default_on=True),
],
"kwargs": {},
"expected": ["test_guardrail"],
},
],
)
def test_get_applied_guardrails(test_case):
# Setup
litellm.callbacks = test_case["callbacks"]
# Execute
result = get_applied_guardrails(test_case["kwargs"])
# Assert
assert sorted(result) == sorted(test_case["expected"])

View file

@@ -9,7 +9,7 @@
"model": "gpt-4o",
"user": "",
"team_id": "",
"metadata": "{\"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
"metadata": "{\"applied_guardrails\": [], \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
"cache_key": "Cache OFF",
"spend": 0.00022500000000000002,
"total_tokens": 30,

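The fixture shows how the metadata lands in the spend-logs table: serialized as a JSON string, with applied_guardrails always present (an empty list here, since no guardrails ran). Decoding it back is a stdlib one-liner:

import json

stored = '{"applied_guardrails": [], "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": null}}'
metadata = json.loads(stored)
assert metadata["applied_guardrails"] == []
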
View file

@@ -272,6 +272,7 @@ def validate_redacted_message_span_attributes(span):
"metadata.user_api_key_user_id",
"metadata.user_api_key_org_id",
"metadata.user_api_key_end_user_id",
"metadata.applied_guardrails",
]
_all_attributes = set(