(Feat) Log Guardrails run, guardrail response on logging integrations (#7445)

* add guardrail_information to SLP

* use standard_logging_guardrail_information

* track StandardLoggingGuardrailInformation

* use log_guardrail_information

* use log_guardrail_information

* docs guardrails

* docs guardrails

* update quick start

* fix presidio logging for sync functions

* update Guardrail type

* enforce add_standard_logging_guardrail_information_to_request_data

* update gd docs
This commit is contained in:
Ishaan Jaff 2024-12-27 15:01:56 -08:00 committed by GitHub
parent 9efb076037
commit 6ec5ed8b3c
14 changed files with 223 additions and 29 deletions

View file

@ -112,14 +112,49 @@ curl -i http://localhost:4000/v1/chat/completions \
</Tabs> </Tabs>
## **Using Guardrails client side**
### ✨ View available guardrails (/guardrails/list)
List the guardrails configured on the proxy server, so developers can easily see which guardrails are available to use.
```shell
curl -X GET 'http://0.0.0.0:4000/guardrails/list'
```
Expected response
```json
{
"guardrails": [
{
"guardrail_name": "bedrock-pre-guard",
"guardrail_info": {
"params": [
{
"name": "toxicity_score",
"type": "float",
"description": "Score between 0-1 indicating content toxicity level"
},
{
"name": "pii_detection",
"type": "boolean"
}
]
}
}
]
}
```
## Advanced
### ✨ Pass additional parameters to guardrail ### ✨ Pass additional parameters to guardrail
:::info :::info
✨ This is an Enterprise only feature [Contact us to get a free trial](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat) ✨ This is an Enterprise only feature [Get a free trial](https://www.litellm.ai/#trial)
::: :::
@ -196,11 +231,40 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
</Tabs> </Tabs>
## **Proxy Admin Controls**
### ✨ Monitoring Guardrails
Monitor which guardrails were executed and whether they passed or failed — e.g. to catch a guardrail going rogue and failing requests we don't intend to fail.
:::info
✨ This is an Enterprise only feature [Get a free trial](https://www.litellm.ai/#trial)
:::
### Setup
1. Connect LiteLLM to a [supported logging provider](../logging)
2. Make a request with a `guardrails` parameter
3. Check your logging provider for the guardrail trace
#### Traced Guardrail Success
<Image img={require('../../../img/gd_success.png')} />
#### Traced Guardrail Failure
<Image img={require('../../../img/gd_fail.png')} />
### ✨ Control Guardrails per Project (API Key) ### ✨ Control Guardrails per Project (API Key)
:::info :::info
✨ This is an Enterprise only feature [Contact us to get a free trial](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat) ✨ This is an Enterprise only feature [Get a free trial](https://www.litellm.ai/#trial)
::: :::
@ -262,7 +326,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
:::info :::info
✨ This is an Enterprise only feature [Contact us to get a free trial](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat) ✨ This is an Enterprise only feature [Get a free trial](https://www.litellm.ai/#trial)
::: :::
@ -320,22 +384,6 @@ The `pii_masking` guardrail ran on this request because api key=sk-jNm1Zar7XfNdZ
### ✨ List guardrails
Show available guardrails on the proxy server. This makes it easier for developers to know what guardrails are available / can be used.
```shell
curl -X GET 'http://0.0.0.0:4000/guardrails/list'
```
Expected response
```json
{
"guardrails": ["aporia-pre-guard", "aporia-post-guard"]
}
```
## Spec: `guardrails` Parameter ## Spec: `guardrails` Parameter
The `guardrails` parameter can be passed to any LiteLLM Proxy endpoint (`/chat/completions`, `/completions`, `/embeddings`). The `guardrails` parameter can be passed to any LiteLLM Proxy endpoint (`/chat/completions`, `/completions`, `/embeddings`).

Binary file not shown.

After

Width:  |  Height:  |  Size: 380 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 229 KiB

View file

@ -1,8 +1,9 @@
from typing import Dict, List, Optional, Union from typing import Dict, List, Literal, Optional, Union
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.types.guardrails import DynamicGuardrailParams, GuardrailEventHooks from litellm.types.guardrails import DynamicGuardrailParams, GuardrailEventHooks
from litellm.types.utils import StandardLoggingGuardrailInformation
class CustomGuardrail(CustomLogger): class CustomGuardrail(CustomLogger):
@ -119,3 +120,101 @@ class CustomGuardrail(CustomLogger):
) )
return False return False
return True return True
def add_standard_logging_guardrail_information_to_request_data(
    self,
    guardrail_json_response: Union[Exception, str, dict],
    request_data: dict,
    guardrail_status: Literal["success", "failure"],
) -> None:
    """
    Builds `StandardLoggingGuardrailInformation` and adds it to the request metadata so it can be used for logging to DataDog, Langfuse, etc.

    Args:
        guardrail_json_response: The guardrail's raw output (or the Exception it raised).
        request_data: The incoming request dict; mutated in place to carry the guardrail info.
        guardrail_status: Whether the guardrail run succeeded or failed.
    """
    from litellm.proxy.proxy_server import premium_user

    # Guardrail tracing is gated behind the premium (enterprise) tier.
    if premium_user is not True:
        verbose_logger.warning(
            f"Guardrail Tracing is only available for premium users. Skipping guardrail logging for guardrail={self.guardrail_name} event_hook={self.event_hook}"
        )
        return

    # Exceptions are not JSON-serializable; store their string form instead.
    response_payload: Union[str, dict] = (
        str(guardrail_json_response)
        if isinstance(guardrail_json_response, Exception)
        else guardrail_json_response
    )

    guardrail_info = StandardLoggingGuardrailInformation(
        guardrail_name=self.guardrail_name,
        guardrail_mode=self.event_hook,
        guardrail_response=response_payload,
        guardrail_status=guardrail_status,
    )

    # Attach to whichever metadata container the request carries,
    # checking "metadata" first, then "litellm_metadata".
    for metadata_key in ("metadata", "litellm_metadata"):
        if metadata_key in request_data:
            request_data[metadata_key][
                "standard_logging_guardrail_information"
            ] = guardrail_info
            return

    verbose_logger.warning(
        "unable to log guardrail information. No metadata found in request_data"
    )
def log_guardrail_information(func):
    """
    Decorator to add standard logging guardrail information to any function

    Add this decorator to ensure your guardrail response is logged to DataDog, OTEL, s3, GCS etc.

    Logs for:
        - pre_call
        - during_call
        - TODO: log post_call. This is more involved since the logs are sent to DD, s3 before the guardrail is even run
    """
    import asyncio
    import functools

    def process_response(self, response, request_data):
        # Record the guardrail output on the request so downstream logging
        # integrations can pick it up; the response is passed through unchanged.
        self.add_standard_logging_guardrail_information_to_request_data(
            guardrail_json_response=response,
            request_data=request_data,
            guardrail_status="success",
        )
        return response

    def process_error(self, e, request_data):
        # Record the failure, then re-raise so callers still see the exception.
        self.add_standard_logging_guardrail_information_to_request_data(
            guardrail_json_response=e,
            request_data=request_data,
            guardrail_status="failure",
        )
        raise e

    @functools.wraps(func)
    async def async_wrapper(*args, **kwargs):
        self: CustomGuardrail = args[0]
        # Hooks receive the request dict as either `data` or `request_data`.
        request_data: dict = kwargs.get("data") or kwargs.get("request_data") or {}
        try:
            response = await func(*args, **kwargs)
            return process_response(self, response, request_data)
        except Exception as e:
            return process_error(self, e, request_data)

    @functools.wraps(func)
    def sync_wrapper(*args, **kwargs):
        self: CustomGuardrail = args[0]
        # Hooks receive the request dict as either `data` or `request_data`.
        request_data: dict = kwargs.get("data") or kwargs.get("request_data") or {}
        try:
            response = func(*args, **kwargs)
            return process_response(self, response, request_data)
        except Exception as e:
            return process_error(self, e, request_data)

    # Select the wrapper once at decoration time. Returning the async wrapper
    # directly (instead of dispatching inside an extra sync wrapper per call)
    # keeps `asyncio.iscoroutinefunction(decorated)` truthful for async hooks
    # and avoids a needless extra stack frame on every invocation.
    if asyncio.iscoroutinefunction(func):
        return async_wrapper
    return sync_wrapper

View file

@ -3038,6 +3038,9 @@ def get_standard_logging_object_payload(
response_cost_failure_debug_info=kwargs.get( response_cost_failure_debug_info=kwargs.get(
"response_cost_failure_debug_information" "response_cost_failure_debug_information"
), ),
guardrail_information=metadata.get(
"standard_logging_guardrail_information", None
),
) )
return payload return payload

View file

@ -19,7 +19,10 @@ from fastapi import HTTPException
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.litellm_core_utils.logging_utils import ( from litellm.litellm_core_utils.logging_utils import (
convert_litellm_response_object_to_str, convert_litellm_response_object_to_str,
) )
@ -142,6 +145,7 @@ class AporiaGuardrail(CustomGuardrail):
}, },
) )
@log_guardrail_information
async def async_post_call_success_hook( async def async_post_call_success_hook(
self, self,
data: dict, data: dict,
@ -173,6 +177,7 @@ class AporiaGuardrail(CustomGuardrail):
pass pass
@log_guardrail_information
async def async_moderation_hook( ### 👈 KEY CHANGE ### async def async_moderation_hook( ### 👈 KEY CHANGE ###
self, self,
data: dict, data: dict,

View file

@ -19,7 +19,10 @@ from fastapi import HTTPException
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client, get_async_httpx_client,
@ -231,6 +234,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
response.text, response.text,
) )
@log_guardrail_information
async def async_moderation_hook( ### 👈 KEY CHANGE ### async def async_moderation_hook( ### 👈 KEY CHANGE ###
self, self,
data: dict, data: dict,
@ -263,6 +267,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
) )
pass pass
@log_guardrail_information
async def async_post_call_success_hook( async def async_post_call_success_hook(
self, self,
data: dict, data: dict,

View file

@ -3,7 +3,10 @@ from typing import Literal, Optional, Union
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
@ -17,6 +20,7 @@ class myCustomGuardrail(CustomGuardrail):
super().__init__(**kwargs) super().__init__(**kwargs)
@log_guardrail_information
async def async_pre_call_hook( async def async_pre_call_hook(
self, self,
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
@ -55,6 +59,7 @@ class myCustomGuardrail(CustomGuardrail):
return data return data
@log_guardrail_information
async def async_moderation_hook( async def async_moderation_hook(
self, self,
data: dict, data: dict,
@ -84,6 +89,7 @@ class myCustomGuardrail(CustomGuardrail):
if "litellm" in _content.lower(): if "litellm" in _content.lower():
raise ValueError("Guardrail failed words - `litellm` detected") raise ValueError("Guardrail failed words - `litellm` detected")
@log_guardrail_information
async def async_post_call_success_hook( async def async_post_call_success_hook(
self, self,
data: dict, data: dict,

View file

@ -12,7 +12,10 @@ from fastapi import HTTPException
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import ( from litellm.litellm_core_utils.prompt_templates.common_utils import (
get_content_from_model_response, get_content_from_model_response,
) )
@ -79,6 +82,7 @@ class GuardrailsAI(CustomGuardrail):
) )
return _json_response return _json_response
@log_guardrail_information
async def async_post_call_success_hook( async def async_post_call_success_hook(
self, self,
data: dict, data: dict,

View file

@ -20,7 +20,10 @@ from fastapi import HTTPException
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client, get_async_httpx_client,
httpxSpecialProvider, httpxSpecialProvider,
@ -294,6 +297,7 @@ class lakeraAI_Moderation(CustomGuardrail):
""" """
self._check_response_flagged(response=response.json()) self._check_response_flagged(response=response.json())
@log_guardrail_information
async def async_pre_call_hook( async def async_pre_call_hook(
self, self,
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
@ -330,6 +334,7 @@ class lakeraAI_Moderation(CustomGuardrail):
data=data, user_api_key_dict=user_api_key_dict, call_type=call_type data=data, user_api_key_dict=user_api_key_dict, call_type=call_type
) )
@log_guardrail_information
async def async_moderation_hook( ### 👈 KEY CHANGE ### async def async_moderation_hook( ### 👈 KEY CHANGE ###
self, self,
data: dict, data: dict,

View file

@ -20,7 +20,10 @@ import litellm # noqa: E401
from litellm import get_secret from litellm import get_secret
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.guardrails import GuardrailEventHooks from litellm.types.guardrails import GuardrailEventHooks
from litellm.utils import ( from litellm.utils import (
@ -205,6 +208,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
except Exception as e: except Exception as e:
raise e raise e
@log_guardrail_information
async def async_pre_call_hook( async def async_pre_call_hook(
self, self,
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
@ -257,6 +261,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
except Exception as e: except Exception as e:
raise e raise e
@log_guardrail_information
def logging_hook( def logging_hook(
self, kwargs: dict, result: Any, call_type: str self, kwargs: dict, result: Any, call_type: str
) -> Tuple[dict, Any]: ) -> Tuple[dict, Any]:
@ -289,6 +294,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
# No running event loop, we can safely run in this thread # No running event loop, we can safely run in this thread
return run_in_new_loop() return run_in_new_loop()
@log_guardrail_information
async def async_logging_hook( async def async_logging_hook(
self, kwargs: dict, result: Any, call_type: str self, kwargs: dict, result: Any, call_type: str
) -> Tuple[dict, Any]: ) -> Tuple[dict, Any]:
@ -333,6 +339,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
return kwargs, result return kwargs, result
@log_guardrail_information
async def async_post_call_success_hook( # type: ignore async def async_post_call_success_hook( # type: ignore
self, self,
data: dict, data: dict,

View file

@ -11,11 +11,13 @@ model_list:
litellm_params: litellm_params:
model: bedrock/* model: bedrock/*
guardrails: guardrails:
- guardrail_name: "bedrock-pre-guard" - guardrail_name: "bedrock-pre-guard"
litellm_params: litellm_params:
guardrail: bedrock # supported values: "aporia", "bedrock", "lakera" guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
mode: "during_call"
guardrailIdentifier: ff6ujrregl1q # your guardrail ID on bedrock
guardrailVersion: "DRAFT" # your guardrail version on bedrock
mode: "post_call" mode: "post_call"
guardrailIdentifier: ff6ujrregl1q guardrailIdentifier: ff6ujrregl1q
guardrailVersion: "DRAFT" guardrailVersion: "DRAFT"

View file

@ -105,9 +105,10 @@ class LitellmParams(TypedDict):
guard_name: Optional[str] guard_name: Optional[str]
class Guardrail(TypedDict): class Guardrail(TypedDict, total=False):
guardrail_name: str guardrail_name: str
litellm_params: LitellmParams litellm_params: LitellmParams
guardrail_info: Optional[Dict]
class guardrailConfig(TypedDict): class guardrailConfig(TypedDict):

View file

@ -21,6 +21,7 @@ from pydantic import BaseModel, ConfigDict, PrivateAttr
from typing_extensions import Callable, Dict, Required, TypedDict, override from typing_extensions import Callable, Dict, Required, TypedDict, override
from ..litellm_core_utils.core_helpers import map_finish_reason from ..litellm_core_utils.core_helpers import map_finish_reason
from .guardrails import GuardrailEventHooks
from .llms.openai import ( from .llms.openai import (
ChatCompletionToolCallChunk, ChatCompletionToolCallChunk,
ChatCompletionUsageBlock, ChatCompletionUsageBlock,
@ -1500,6 +1501,13 @@ class StandardLoggingPayloadErrorInformation(TypedDict, total=False):
llm_provider: Optional[str] llm_provider: Optional[str]
class StandardLoggingGuardrailInformation(TypedDict, total=False):
guardrail_name: Optional[str]
guardrail_mode: Optional[GuardrailEventHooks]
guardrail_response: Optional[Union[dict, str]]
guardrail_status: Literal["success", "failure"]
StandardLoggingPayloadStatus = Literal["success", "failure"] StandardLoggingPayloadStatus = Literal["success", "failure"]
@ -1539,6 +1547,7 @@ class StandardLoggingPayload(TypedDict):
error_information: Optional[StandardLoggingPayloadErrorInformation] error_information: Optional[StandardLoggingPayloadErrorInformation]
model_parameters: dict model_parameters: dict
hidden_params: StandardLoggingHiddenParams hidden_params: StandardLoggingHiddenParams
guardrail_information: Optional[StandardLoggingGuardrailInformation]
from typing import AsyncIterator, Iterator from typing import AsyncIterator, Iterator