diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 28182b75ac..516d9dba34 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -199,6 +199,7 @@ class Logging(LiteLLMLoggingBaseClass):
dynamic_async_failure_callbacks: Optional[
List[Union[str, Callable, CustomLogger]]
] = None,
+ applied_guardrails: Optional[List[str]] = None,
kwargs: Optional[Dict] = None,
):
_input: Optional[str] = messages # save original value of messages
@@ -271,6 +272,7 @@ class Logging(LiteLLMLoggingBaseClass):
"litellm_call_id": litellm_call_id,
"input": _input,
"litellm_params": litellm_params,
+ "applied_guardrails": applied_guardrails,
}
def process_dynamic_callbacks(self):
@@ -2852,6 +2854,7 @@ class StandardLoggingPayloadSetup:
metadata: Optional[Dict[str, Any]],
litellm_params: Optional[dict] = None,
prompt_integration: Optional[str] = None,
+ applied_guardrails: Optional[List[str]] = None,
) -> StandardLoggingMetadata:
"""
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@@ -2866,6 +2869,7 @@ class StandardLoggingPayloadSetup:
- If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
- If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
"""
+
prompt_management_metadata: Optional[
StandardLoggingPromptManagementMetadata
] = None
@@ -2895,6 +2899,7 @@ class StandardLoggingPayloadSetup:
requester_metadata=None,
user_api_key_end_user_id=None,
prompt_management_metadata=prompt_management_metadata,
+ applied_guardrails=applied_guardrails,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
@@ -3193,6 +3198,7 @@ def get_standard_logging_object_payload(
metadata=metadata,
litellm_params=litellm_params,
prompt_integration=kwargs.get("prompt_integration", None),
+ applied_guardrails=kwargs.get("applied_guardrails", None),
)
_request_body = proxy_server_request.get("body", {})
@@ -3328,6 +3334,7 @@ def get_standard_logging_metadata(
requester_metadata=None,
user_api_key_end_user_id=None,
prompt_management_metadata=None,
+ applied_guardrails=None,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 75494bcf78..27d04c19c1 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1794,6 +1794,7 @@ class SpendLogsMetadata(TypedDict):
dict
] # special param to log k,v pairs to spendlogs for a call
requester_ip_address: Optional[str]
+ applied_guardrails: Optional[List[str]]
class SpendLogsPayload(TypedDict):
diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py
index ccf0836e05..f12220766b 100644
--- a/litellm/proxy/spend_tracking/spend_tracking_utils.py
+++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py
@@ -3,7 +3,7 @@ import secrets
from datetime import datetime
from datetime import datetime as dt
from datetime import timezone
-from typing import Optional, cast
+from typing import List, Optional, cast
from pydantic import BaseModel
@@ -32,7 +32,9 @@ def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
return False
-def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
+def _get_spend_logs_metadata(
+ metadata: Optional[dict], applied_guardrails: Optional[List[str]] = None
+) -> SpendLogsMetadata:
if metadata is None:
return SpendLogsMetadata(
user_api_key=None,
@@ -44,8 +46,9 @@ def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
spend_logs_metadata=None,
requester_ip_address=None,
additional_usage_values=None,
+            applied_guardrails=applied_guardrails,
)
- verbose_proxy_logger.debug(
+ verbose_proxy_logger.info(
"getting payload for SpendLogs, available keys in metadata: "
+ str(list(metadata.keys()))
)
@@ -58,6 +61,8 @@ def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
if key in metadata
}
)
+ clean_metadata["applied_guardrails"] = applied_guardrails
+
return clean_metadata
@@ -130,7 +135,14 @@ def get_logging_payload( # noqa: PLR0915
_model_group = metadata.get("model_group", "")
# clean up litellm metadata
- clean_metadata = _get_spend_logs_metadata(metadata)
+ clean_metadata = _get_spend_logs_metadata(
+ metadata,
+ applied_guardrails=(
+ standard_logging_payload["metadata"].get("applied_guardrails", None)
+ if standard_logging_payload is not None
+ else None
+ ),
+ )
special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
additional_usage_values = {}
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 556bae94e7..9139f6be4c 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1525,6 +1525,7 @@ class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
requester_ip_address: Optional[str]
requester_metadata: Optional[dict]
prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata]
+ applied_guardrails: Optional[List[str]]
class StandardLoggingAdditionalHeaders(TypedDict, total=False):
diff --git a/litellm/utils.py b/litellm/utils.py
index 7cdfc2ebbe..dbefc90dc6 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -60,6 +60,7 @@ import litellm.litellm_core_utils.json_validation_rule
from litellm.caching._internal_lru_cache import lru_cache_wrapper
from litellm.caching.caching import DualCache
from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
+from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import (
map_finish_reason,
@@ -418,6 +419,35 @@ def _custom_logger_class_exists_in_failure_callbacks(
)
+def get_request_guardrails(kwargs: Dict[str, Any]) -> List[str]:
+ """
+    Get the guardrails requested for this call from kwargs["metadata"]["requester_metadata"]["guardrails"]
+ """
+ metadata = kwargs.get("metadata") or {}
+ requester_metadata = metadata.get("requester_metadata") or {}
+ applied_guardrails = requester_metadata.get("guardrails") or []
+ return applied_guardrails
+
+
+def get_applied_guardrails(kwargs: Dict[str, Any]) -> List[str]:
+    """
+    Return the guardrail names applied to this call: guardrails configured
+    with 'default_on=True' plus any guardrails requested via request metadata.
+    """
+
+ request_guardrails = get_request_guardrails(kwargs)
+ applied_guardrails = []
+ for callback in litellm.callbacks:
+ if callback is not None and isinstance(callback, CustomGuardrail):
+ if callback.guardrail_name is not None:
+ if callback.default_on is True:
+ applied_guardrails.append(callback.guardrail_name)
+ elif callback.guardrail_name in request_guardrails:
+ applied_guardrails.append(callback.guardrail_name)
+
+ return applied_guardrails
+
+
def function_setup( # noqa: PLR0915
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -436,6 +466,9 @@ def function_setup( # noqa: PLR0915
## CUSTOM LLM SETUP ##
custom_llm_setup()
+ ## GET APPLIED GUARDRAILS
+ applied_guardrails = get_applied_guardrails(kwargs)
+
## LOGGING SETUP
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
@@ -677,6 +710,7 @@ def function_setup( # noqa: PLR0915
dynamic_async_success_callbacks=dynamic_async_success_callbacks,
dynamic_async_failure_callbacks=dynamic_async_failure_callbacks,
kwargs=kwargs,
+ applied_guardrails=applied_guardrails,
)
## check if metadata is passed in
diff --git a/tests/litellm_utils_tests/test_utils.py b/tests/litellm_utils_tests/test_utils.py
index 75630c81d8..b19282563c 100644
--- a/tests/litellm_utils_tests/test_utils.py
+++ b/tests/litellm_utils_tests/test_utils.py
@@ -864,17 +864,24 @@ def test_convert_model_response_object():
== '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
)
+
@pytest.mark.parametrize(
- "content, expected_reasoning, expected_content",
+ "content, expected_reasoning, expected_content",
[
(None, None, None),
- ("I am thinking hereThe sky is a canvas of blue", "I am thinking here", "The sky is a canvas of blue"),
+ (
+ "I am thinking hereThe sky is a canvas of blue",
+ "I am thinking here",
+ "The sky is a canvas of blue",
+ ),
("I am a regular response", None, "I am a regular response"),
-
- ]
+ ],
)
def test_parse_content_for_reasoning(content, expected_reasoning, expected_content):
- assert(litellm.utils._parse_content_for_reasoning(content) == (expected_reasoning, expected_content))
+ assert litellm.utils._parse_content_for_reasoning(content) == (
+ expected_reasoning,
+ expected_content,
+ )
@pytest.mark.parametrize(
@@ -1874,3 +1881,82 @@ def test_validate_user_messages_invalid_content_type():
assert "Invalid message" in str(e)
print(e)
+
+
+from litellm.integrations.custom_guardrail import CustomGuardrail
+from litellm.utils import get_applied_guardrails
+from unittest.mock import Mock
+
+
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ {
+ "name": "default_on_guardrail",
+ "callbacks": [
+ CustomGuardrail(guardrail_name="test_guardrail", default_on=True)
+ ],
+ "kwargs": {"metadata": {"requester_metadata": {"guardrails": []}}},
+ "expected": ["test_guardrail"],
+ },
+ {
+ "name": "request_specific_guardrail",
+ "callbacks": [
+ CustomGuardrail(guardrail_name="test_guardrail", default_on=False)
+ ],
+ "kwargs": {
+ "metadata": {"requester_metadata": {"guardrails": ["test_guardrail"]}}
+ },
+ "expected": ["test_guardrail"],
+ },
+ {
+ "name": "multiple_guardrails",
+ "callbacks": [
+ CustomGuardrail(guardrail_name="default_guardrail", default_on=True),
+ CustomGuardrail(guardrail_name="request_guardrail", default_on=False),
+ ],
+ "kwargs": {
+ "metadata": {
+ "requester_metadata": {"guardrails": ["request_guardrail"]}
+ }
+ },
+ "expected": ["default_guardrail", "request_guardrail"],
+ },
+ {
+ "name": "empty_metadata",
+ "callbacks": [
+ CustomGuardrail(guardrail_name="test_guardrail", default_on=False)
+ ],
+ "kwargs": {},
+ "expected": [],
+ },
+ {
+ "name": "none_callback",
+ "callbacks": [
+ None,
+ CustomGuardrail(guardrail_name="test_guardrail", default_on=True),
+ ],
+ "kwargs": {},
+ "expected": ["test_guardrail"],
+ },
+ {
+ "name": "non_guardrail_callback",
+ "callbacks": [
+ Mock(),
+ CustomGuardrail(guardrail_name="test_guardrail", default_on=True),
+ ],
+ "kwargs": {},
+ "expected": ["test_guardrail"],
+ },
+ ],
+)
+def test_get_applied_guardrails(test_case):
+
+ # Setup
+ litellm.callbacks = test_case["callbacks"]
+
+ # Execute
+ result = get_applied_guardrails(test_case["kwargs"])
+
+ # Assert
+ assert sorted(result) == sorted(test_case["expected"])
diff --git a/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json b/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json
index 0e78e60b76..08c6b45183 100644
--- a/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json
+++ b/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json
@@ -9,7 +9,7 @@
"model": "gpt-4o",
"user": "",
"team_id": "",
- "metadata": "{\"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
+ "metadata": "{\"applied_guardrails\": [], \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
"cache_key": "Cache OFF",
"spend": 0.00022500000000000002,
"total_tokens": 30,
diff --git a/tests/logging_callback_tests/test_otel_logging.py b/tests/logging_callback_tests/test_otel_logging.py
index 9c19c9d261..d37e46bf19 100644
--- a/tests/logging_callback_tests/test_otel_logging.py
+++ b/tests/logging_callback_tests/test_otel_logging.py
@@ -272,6 +272,7 @@ def validate_redacted_message_span_attributes(span):
"metadata.user_api_key_user_id",
"metadata.user_api_key_org_id",
"metadata.user_api_key_end_user_id",
+ "metadata.applied_guardrails",
]
_all_attributes = set(