Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
Log applied guardrails on LLM API call (#8452)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 40s
* fix(litellm_logging.py): support saving applied guardrails in logging object
  allows list of applied guardrails to be logged for proxy admin's knowledge
* feat(spend_tracking_utils.py): log applied guardrails to spend logs
  makes it easy for admin to know what guardrails were applied on a request
* ci(config.yml): uninstall posthog from ci/cd
* test: fix tests
* test: update test
This commit is contained in:
parent 8e32713637
commit ce3ead6f91

8 changed files with 152 additions and 10 deletions
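
At a glance, the change works like this: guardrail callbacks registered with default_on=True are recorded on every request, while request-scoped guardrails are picked up from metadata.requester_metadata.guardrails; the resulting list is then threaded into the standard logging payload and into spend logs. Below is a minimal, self-contained sketch of that selection logic. It mirrors get_applied_guardrails from litellm/utils.py further down this page, but the Guardrail stand-in class and the guardrail names are illustrative, not litellm's actual CustomGuardrail.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class Guardrail:
    guardrail_name: Optional[str]
    default_on: bool = False


def applied_guardrails(callbacks: List[Any], kwargs: Dict[str, Any]) -> List[str]:
    metadata = kwargs.get("metadata") or {}
    requester_metadata = metadata.get("requester_metadata") or {}
    requested = requester_metadata.get("guardrails") or []

    applied = []
    for cb in callbacks:
        if isinstance(cb, Guardrail) and cb.guardrail_name is not None:
            # default_on guardrails always apply; others only when requested
            if cb.default_on or cb.guardrail_name in requested:
                applied.append(cb.guardrail_name)
    return applied


callbacks = [
    Guardrail("pii_masking", default_on=True),      # hypothetical name
    Guardrail("prompt_injection", default_on=False),  # hypothetical name
]
kwargs = {"metadata": {"requester_metadata": {"guardrails": ["prompt_injection"]}}}
print(applied_guardrails(callbacks, kwargs))  # ['pii_masking', 'prompt_injection']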
litellm_logging.py

@@ -199,6 +199,7 @@ class Logging(LiteLLMLoggingBaseClass):
         dynamic_async_failure_callbacks: Optional[
             List[Union[str, Callable, CustomLogger]]
         ] = None,
+        applied_guardrails: Optional[List[str]] = None,
         kwargs: Optional[Dict] = None,
     ):
         _input: Optional[str] = messages # save original value of messages

@@ -271,6 +272,7 @@ class Logging(LiteLLMLoggingBaseClass):
             "litellm_call_id": litellm_call_id,
             "input": _input,
             "litellm_params": litellm_params,
+            "applied_guardrails": applied_guardrails,
         }

     def process_dynamic_callbacks(self):

@@ -2852,6 +2854,7 @@ class StandardLoggingPayloadSetup:
         metadata: Optional[Dict[str, Any]],
         litellm_params: Optional[dict] = None,
         prompt_integration: Optional[str] = None,
+        applied_guardrails: Optional[List[str]] = None,
     ) -> StandardLoggingMetadata:
         """
         Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.

@@ -2866,6 +2869,7 @@ class StandardLoggingPayloadSetup:
         - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
         - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
         """
+
         prompt_management_metadata: Optional[
             StandardLoggingPromptManagementMetadata
         ] = None

@@ -2895,6 +2899,7 @@ class StandardLoggingPayloadSetup:
             requester_metadata=None,
             user_api_key_end_user_id=None,
             prompt_management_metadata=prompt_management_metadata,
+            applied_guardrails=applied_guardrails,
         )
         if isinstance(metadata, dict):
             # Filter the metadata dictionary to include only the specified keys

@@ -3193,6 +3198,7 @@ def get_standard_logging_object_payload(
             metadata=metadata,
             litellm_params=litellm_params,
             prompt_integration=kwargs.get("prompt_integration", None),
+            applied_guardrails=kwargs.get("applied_guardrails", None),
         )

         _request_body = proxy_server_request.get("body", {})

@@ -3328,6 +3334,7 @@ def get_standard_logging_metadata(
         requester_metadata=None,
         user_api_key_end_user_id=None,
         prompt_management_metadata=None,
+        applied_guardrails=None,
     )
    if isinstance(metadata, dict):
        # Filter the metadata dictionary to include only the specified keys

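Once the payload carries the field, a proxy admin can read it back from any custom callback. A hedged sketch follows, assuming the standard logging payload is exposed to callbacks under kwargs["standard_logging_object"] as litellm's logging docs describe; treat that key and this wiring as assumptions rather than part of this commit.

from litellm.integrations.custom_logger import CustomLogger


class GuardrailAuditLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # "standard_logging_object" is where litellm exposes the standard
        # logging payload to callbacks (assumption; see litellm logging docs)
        payload = kwargs.get("standard_logging_object") or {}
        metadata = payload.get("metadata") or {}
        # New in this commit: the list of guardrails applied to the request
        print("applied guardrails:", metadata.get("applied_guardrails"))
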
@@ -1794,6 +1794,7 @@ class SpendLogsMetadata(TypedDict):
         dict
     ] # special param to log k,v pairs to spendlogs for a call
     requester_ip_address: Optional[str]
+    applied_guardrails: Optional[List[str]]


 class SpendLogsPayload(TypedDict):

spend_tracking_utils.py

@@ -3,7 +3,7 @@ import secrets
 from datetime import datetime
 from datetime import datetime as dt
 from datetime import timezone
-from typing import Optional, cast
+from typing import List, Optional, cast

 from pydantic import BaseModel


@@ -32,7 +32,9 @@ def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
     return False


-def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
+def _get_spend_logs_metadata(
+    metadata: Optional[dict], applied_guardrails: Optional[List[str]] = None
+) -> SpendLogsMetadata:
     if metadata is None:
         return SpendLogsMetadata(
             user_api_key=None,

@@ -44,8 +46,9 @@ def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
             spend_logs_metadata=None,
             requester_ip_address=None,
             additional_usage_values=None,
+            applied_guardrails=None,
         )
-    verbose_proxy_logger.debug(
+    verbose_proxy_logger.info(
         "getting payload for SpendLogs, available keys in metadata: "
         + str(list(metadata.keys()))
     )

@@ -58,6 +61,8 @@ def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
             if key in metadata
         }
     )
+    clean_metadata["applied_guardrails"] = applied_guardrails
+
     return clean_metadata


@@ -130,7 +135,14 @@ def get_logging_payload( # noqa: PLR0915
     _model_group = metadata.get("model_group", "")

     # clean up litellm metadata
-    clean_metadata = _get_spend_logs_metadata(metadata)
+    clean_metadata = _get_spend_logs_metadata(
+        metadata,
+        applied_guardrails=(
+            standard_logging_payload["metadata"].get("applied_guardrails", None)
+            if standard_logging_payload is not None
+            else None
+        ),
+    )

     special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
     additional_usage_values = {}

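The net effect on a spend-logs row is that applied_guardrails is serialized into the metadata JSON string; the updated test fixture further down this page shows the empty-list case. A small sketch of that shape (values illustrative, mirroring the fixture):

import json

# Shape of the spend-logs "metadata" column after this change; the empty
# list means no guardrail ran on the request.
spend_log_metadata = json.dumps(
    {
        "applied_guardrails": [],
        "additional_usage_values": {
            "completion_tokens_details": None,
            "prompt_tokens_details": None,
        },
    }
)
print(spend_log_metadata)
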
@@ -1525,6 +1525,7 @@ class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
     requester_ip_address: Optional[str]
     requester_metadata: Optional[dict]
     prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata]
+    applied_guardrails: Optional[List[str]]


 class StandardLoggingAdditionalHeaders(TypedDict, total=False):

litellm/utils.py

@@ -60,6 +60,7 @@ import litellm.litellm_core_utils.json_validation_rule
 from litellm.caching._internal_lru_cache import lru_cache_wrapper
 from litellm.caching.caching import DualCache
 from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
+from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import (
     map_finish_reason,

@@ -418,6 +419,35 @@ def _custom_logger_class_exists_in_failure_callbacks(
     )


+def get_request_guardrails(kwargs: Dict[str, Any]) -> List[str]:
+    """
+    Get the request guardrails from the kwargs
+    """
+    metadata = kwargs.get("metadata") or {}
+    requester_metadata = metadata.get("requester_metadata") or {}
+    applied_guardrails = requester_metadata.get("guardrails") or []
+    return applied_guardrails
+
+
+def get_applied_guardrails(kwargs: Dict[str, Any]) -> List[str]:
+    """
+    - Add 'default_on' guardrails to the list
+    - Add request guardrails to the list
+    """
+
+    request_guardrails = get_request_guardrails(kwargs)
+    applied_guardrails = []
+    for callback in litellm.callbacks:
+        if callback is not None and isinstance(callback, CustomGuardrail):
+            if callback.guardrail_name is not None:
+                if callback.default_on is True:
+                    applied_guardrails.append(callback.guardrail_name)
+                elif callback.guardrail_name in request_guardrails:
+                    applied_guardrails.append(callback.guardrail_name)
+
+    return applied_guardrails
+
+
 def function_setup( # noqa: PLR0915
     original_function: str, rules_obj, start_time, *args, **kwargs
 ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.

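A usage sketch of the new helper, mirroring the parametrized test added further below (the guardrail names are illustrative; the constructor kwargs are exactly those the test uses):

import litellm
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.utils import get_applied_guardrails

# One guardrail that is always on, one that must be requested per-call.
litellm.callbacks = [
    CustomGuardrail(guardrail_name="default_guardrail", default_on=True),
    CustomGuardrail(guardrail_name="request_guardrail", default_on=False),
]
kwargs = {"metadata": {"requester_metadata": {"guardrails": ["request_guardrail"]}}}
print(get_applied_guardrails(kwargs))
# expected: ["default_guardrail", "request_guardrail"]
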
@@ -436,6 +466,9 @@ def function_setup( # noqa: PLR0915
         ## CUSTOM LLM SETUP ##
         custom_llm_setup()

+        ## GET APPLIED GUARDRAILS
+        applied_guardrails = get_applied_guardrails(kwargs)
+
         ## LOGGING SETUP
         function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None

@@ -677,6 +710,7 @@ def function_setup( # noqa: PLR0915
             dynamic_async_success_callbacks=dynamic_async_success_callbacks,
             dynamic_async_failure_callbacks=dynamic_async_failure_callbacks,
             kwargs=kwargs,
+            applied_guardrails=applied_guardrails,
         )

         ## check if metadata is passed in

@@ -864,17 +864,24 @@ def test_convert_model_response_object():
         == '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
     )


 @pytest.mark.parametrize(
     "content, expected_reasoning, expected_content",
     [
         (None, None, None),
-        ("<think>I am thinking here</think>The sky is a canvas of blue", "I am thinking here", "The sky is a canvas of blue"),
+        (
+            "<think>I am thinking here</think>The sky is a canvas of blue",
+            "I am thinking here",
+            "The sky is a canvas of blue",
+        ),
         ("I am a regular response", None, "I am a regular response"),
-    ]
+    ],
 )
 def test_parse_content_for_reasoning(content, expected_reasoning, expected_content):
-    assert(litellm.utils._parse_content_for_reasoning(content) == (expected_reasoning, expected_content))
+    assert litellm.utils._parse_content_for_reasoning(content) == (
+        expected_reasoning,
+        expected_content,
+    )


 @pytest.mark.parametrize(

@@ -1874,3 +1881,82 @@ def test_validate_user_messages_invalid_content_type():

     assert "Invalid message" in str(e)
     print(e)
+
+
+from litellm.integrations.custom_guardrail import CustomGuardrail
+from litellm.utils import get_applied_guardrails
+from unittest.mock import Mock
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        {
+            "name": "default_on_guardrail",
+            "callbacks": [
+                CustomGuardrail(guardrail_name="test_guardrail", default_on=True)
+            ],
+            "kwargs": {"metadata": {"requester_metadata": {"guardrails": []}}},
+            "expected": ["test_guardrail"],
+        },
+        {
+            "name": "request_specific_guardrail",
+            "callbacks": [
+                CustomGuardrail(guardrail_name="test_guardrail", default_on=False)
+            ],
+            "kwargs": {
+                "metadata": {"requester_metadata": {"guardrails": ["test_guardrail"]}}
+            },
+            "expected": ["test_guardrail"],
+        },
+        {
+            "name": "multiple_guardrails",
+            "callbacks": [
+                CustomGuardrail(guardrail_name="default_guardrail", default_on=True),
+                CustomGuardrail(guardrail_name="request_guardrail", default_on=False),
+            ],
+            "kwargs": {
+                "metadata": {
+                    "requester_metadata": {"guardrails": ["request_guardrail"]}
+                }
+            },
+            "expected": ["default_guardrail", "request_guardrail"],
+        },
+        {
+            "name": "empty_metadata",
+            "callbacks": [
+                CustomGuardrail(guardrail_name="test_guardrail", default_on=False)
+            ],
+            "kwargs": {},
+            "expected": [],
+        },
+        {
+            "name": "none_callback",
+            "callbacks": [
+                None,
+                CustomGuardrail(guardrail_name="test_guardrail", default_on=True),
+            ],
+            "kwargs": {},
+            "expected": ["test_guardrail"],
+        },
+        {
+            "name": "non_guardrail_callback",
+            "callbacks": [
+                Mock(),
+                CustomGuardrail(guardrail_name="test_guardrail", default_on=True),
+            ],
+            "kwargs": {},
+            "expected": ["test_guardrail"],
+        },
+    ],
+)
+def test_get_applied_guardrails(test_case):
+
+    # Setup
+    litellm.callbacks = test_case["callbacks"]
+
+    # Execute
+    result = get_applied_guardrails(test_case["kwargs"])
+
+    # Assert
+    assert sorted(result) == sorted(test_case["expected"])

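To exercise only the new unit test, something like the following should work from the repo checkout; the -k keyword filter is standard pytest, but the exact test-file layout is an assumption:

import pytest

# Run just the new guardrail test by keyword; invoke from the tests directory.
pytest.main(["-q", "-k", "test_get_applied_guardrails"])
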
@@ -9,7 +9,7 @@
     "model": "gpt-4o",
     "user": "",
     "team_id": "",
-    "metadata": "{\"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
+    "metadata": "{\"applied_guardrails\": [], \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
     "cache_key": "Cache OFF",
     "spend": 0.00022500000000000002,
     "total_tokens": 30,

@@ -272,6 +272,7 @@ def validate_redacted_message_span_attributes(span):
         "metadata.user_api_key_user_id",
         "metadata.user_api_key_org_id",
         "metadata.user_api_key_end_user_id",
+        "metadata.applied_guardrails",
     ]

     _all_attributes = set(