Support post-call guards for stream and non-stream responses

Tomer Bin 2025-01-26 12:28:22 +02:00
parent be35c9a663
commit 4a31b32a88
8 changed files with 297 additions and 33 deletions
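For orientation before the diff: the two hooks this commit wires through the proxy are meant to be implemented on a guardrail class. Below is a minimal, hypothetical sketch of such a guardrail covering both the non-stream and stream paths. It is not part of this commit; the BlockSecretsGuardrail name and the block-by-exception behaviour are made up, and the CustomGuardrail import path and hook signatures are assumptions based on litellm's CustomLogger interface.

from typing import Any, AsyncGenerator, Optional

from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth


class BlockSecretsGuardrail(CustomGuardrail):
    """Hypothetical guard: reject responses that contain a banned marker."""

    banned = "sk-live-"

    async def async_post_call_success_hook(
        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response: Any
    ):
        # Non-stream path: the whole response is available at once
        # (assumes a chat-style ModelResponse).
        text = response.choices[0].message.content or ""
        if self.banned in text:
            raise ValueError("guardrail blocked response")  # illustrative reaction
        return response

    async def async_post_call_streaming_iterator_hook(
        self, user_api_key_dict: UserAPIKeyAuth, response: Any, request_data: dict
    ) -> AsyncGenerator[Any, None]:
        # Stream path: wrap the chunk iterator and scan the accumulated text.
        seen = ""
        async for chunk in response:
            delta: Optional[str] = chunk.choices[0].delta.content if chunk.choices else None
            if delta:
                seen += delta
                if self.banned in seen:
                    raise ValueError("guardrail blocked streamed response")
            yield chunk

With guardrails like this registered in litellm.callbacks, the proxy-side changes below decide when to invoke them: post_call_success_hook for complete responses and the new async_post_call_streaming_iterator_hook for streamed ones.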

@@ -18,6 +18,7 @@ from litellm.proxy._types import (
     ProxyErrorTypes,
     ProxyException,
 )
+from litellm.types.guardrails import GuardrailEventHooks
 
 try:
     import backoff
@@ -31,7 +32,7 @@ from fastapi import HTTPException, status
 import litellm
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
-from litellm import EmbeddingResponse, ImageResponse, ModelResponse, Router
+from litellm import EmbeddingResponse, ImageResponse, ModelResponse, Router, ModelResponseStream
 from litellm._logging import verbose_proxy_logger
 from litellm._service_logger import ServiceLogging, ServiceTypes
 from litellm.caching.caching import DualCache, RedisCache
@@ -972,7 +973,7 @@ class ProxyLogging:
         1. /chat/completions
         """
         response_str: Optional[str] = None
-        if isinstance(response, ModelResponse):
+        if isinstance(response, (ModelResponse, ModelResponseStream)):
             response_str = litellm.get_response_string(response_obj=response)
         if response_str is not None:
             for callback in litellm.callbacks:
@@ -992,6 +993,35 @@
                 raise e
         return response
 
+    def async_post_call_streaming_iterator_hook(
+        self,
+        response,
+        user_api_key_dict: UserAPIKeyAuth,
+        request_data: dict,
+    ):
+        """
+        Allow user to modify outgoing streaming data -> Given a whole response iterator.
+        This hook is best used when you need to modify multiple chunks of the response at once.
+
+        Covers:
+        1. /chat/completions
+        """
+        for callback in litellm.callbacks:
+            _callback: Optional[CustomLogger] = None
+            if isinstance(callback, str):
+                _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(callback)
+            else:
+                _callback = callback  # type: ignore
+            if _callback is not None and isinstance(_callback, CustomLogger):
+                if not isinstance(_callback, CustomGuardrail) or _callback.should_run_guardrail(
+                    data=request_data, event_type=GuardrailEventHooks.post_call
+                ):
+                    response = _callback.async_post_call_streaming_iterator_hook(
+                        user_api_key_dict=user_api_key_dict, response=response, request_data=request_data
+                    )
+        return response
+
     async def post_call_streaming_hook(
         self,
         response: str,
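Note how the loop in the new async_post_call_streaming_iterator_hook above reassigns response for every matching callback, so each guardrail wraps the iterator returned by the previous one and the client only consumes the outermost wrapper. A self-contained sketch of that layering pattern, using plain async generators and made-up chunk strings rather than litellm types:

import asyncio
from typing import AsyncIterator


async def source() -> AsyncIterator[str]:
    # stand-in for the model's raw streamed chunks
    for piece in ["public info ", "SECRET token ", "more text"]:
        yield piece


async def redact(stream: AsyncIterator[str]) -> AsyncIterator[str]:
    # first "guardrail": rewrites chunks in flight
    async for chunk in stream:
        yield chunk.replace("SECRET token ", "[redacted] ")


async def count(stream: AsyncIterator[str]) -> AsyncIterator[str]:
    # second "guardrail": observes chunks without changing them
    n = 0
    async for chunk in stream:
        n += 1
        yield chunk
    print(f"saw {n} chunks")


async def main() -> None:
    response: AsyncIterator[str] = source()
    # mirrors the proxy loop: response = callback.async_post_call_streaming_iterator_hook(...)
    for wrap in (redact, count):
        response = wrap(response)
    text = ""
    async for chunk in response:
        text += chunk
    print(text)


asyncio.run(main())

In the proxy, the same chaining happens per request, and should_run_guardrail decides whether a given callback participates at all.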