Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
Support post-call guards for stream and non-stream responses
This commit is contained in:
parent be35c9a663, commit 4a31b32a88

8 changed files with 297 additions and 33 deletions
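The diff below wires guardrails into both response paths: the existing post-call string check now also accepts single stream chunks, and a new iterator-level hook hands guardrails the entire stream. For orientation, here is a hedged sketch of the two surfaces a proxy guardrail can implement. The streaming signature matches the call site in this diff; the non-stream hook name follows litellm's CustomLogger conventions, and the class name is illustrative:

    from typing import Any, AsyncGenerator

    from litellm.integrations.custom_guardrail import CustomGuardrail
    from litellm.proxy._types import UserAPIKeyAuth


    class MyOutputGuard(CustomGuardrail):
        async def async_post_call_success_hook(
            self, data: dict, user_api_key_dict: UserAPIKeyAuth, response: Any
        ) -> Any:
            # Non-stream path: inspect or rewrite the finished response, or raise.
            return response

        async def async_post_call_streaming_iterator_hook(
            self, user_api_key_dict: UserAPIKeyAuth, response: Any, request_data: dict
        ) -> AsyncGenerator[Any, None]:
            # Stream path: wrap the whole iterator and re-yield (possibly edited) chunks.
            async for chunk in response:
                yield chunk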
@@ -18,6 +18,7 @@ from litellm.proxy._types import (
     ProxyErrorTypes,
     ProxyException,
 )
+from litellm.types.guardrails import GuardrailEventHooks
 
 try:
     import backoff
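The only substantive change in this hunk is the GuardrailEventHooks import, which the gate added further down compares against. For orientation, a hedged sketch of how a guardrail declares the stage it runs at; the constructor kwargs follow litellm's CustomGuardrail conventions as I understand them, so treat the exact names as assumptions:

    from litellm.integrations.custom_guardrail import CustomGuardrail
    from litellm.types.guardrails import GuardrailEventHooks

    # A guard configured for the post-call stage; should_run_guardrail()
    # later compares this event_hook against GuardrailEventHooks.post_call.
    guard = CustomGuardrail(
        guardrail_name="my-output-guard",  # illustrative name
        event_hook=GuardrailEventHooks.post_call,
    )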
@@ -31,7 +32,7 @@ from fastapi import HTTPException, status
 import litellm
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
-from litellm import EmbeddingResponse, ImageResponse, ModelResponse, Router
+from litellm import EmbeddingResponse, ImageResponse, ModelResponse, Router, ModelResponseStream
 from litellm._logging import verbose_proxy_logger
 from litellm._service_logger import ServiceLogging, ServiceTypes
 from litellm.caching.caching import DualCache, RedisCache
@@ -972,7 +973,7 @@ class ProxyLogging:
         1. /chat/completions
         """
         response_str: Optional[str] = None
-        if isinstance(response, ModelResponse):
+        if isinstance(response, (ModelResponse, ModelResponseStream)):
             response_str = litellm.get_response_string(response_obj=response)
         if response_str is not None:
             for callback in litellm.callbacks:
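With this change the non-stream guard path extracts text from single stream chunks as well as complete responses. get_response_string itself is not part of this diff; a minimal sketch of what such an extraction has to handle, assuming OpenAI-style shapes (message.content on full responses, delta.content on stream chunks) and an illustrative function name:

    from typing import Optional, Union

    from litellm import ModelResponse, ModelResponseStream


    def response_text(response: Union[ModelResponse, ModelResponseStream]) -> Optional[str]:
        # Full responses carry choices[i].message.content; stream chunks
        # carry choices[i].delta.content. Join whatever text is present.
        parts = []
        for choice in response.choices:
            if isinstance(response, ModelResponseStream):
                content = choice.delta.content
            else:
                content = choice.message.content
            if content is not None:
                parts.append(content)
        return "".join(parts) if parts else None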
@@ -992,6 +993,35 @@ class ProxyLogging:
                 raise e
         return response
+
+    def async_post_call_streaming_iterator_hook(
+        self,
+        response,
+        user_api_key_dict: UserAPIKeyAuth,
+        request_data: dict,
+    ):
+        """
+        Allow user to modify outgoing streaming data -> Given a whole response iterator.
+        This hook is best used when you need to modify multiple chunks of the response at once.
+
+        Covers:
+        1. /chat/completions
+        """
+        for callback in litellm.callbacks:
+            _callback: Optional[CustomLogger] = None
+            if isinstance(callback, str):
+                _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(callback)
+            else:
+                _callback = callback  # type: ignore
+            if _callback is not None and isinstance(_callback, CustomLogger):
+                if not isinstance(_callback, CustomGuardrail) or _callback.should_run_guardrail(
+                    data=request_data, event_type=GuardrailEventHooks.post_call
+                ):
+                    response = _callback.async_post_call_streaming_iterator_hook(
+                        user_api_key_dict=user_api_key_dict, response=response, request_data=request_data
+                    )
+        return response
+
 
     async def post_call_streaming_hook(
         self,
         response: str,
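Because the new hook receives the entire async iterator, a guardrail can see assembled output before the client does, at the cost of buffering. A hedged sketch of a consumer; the class name and the trivial "forbidden" check are placeholders, while the hook signature matches the call site above:

    from typing import Any, AsyncGenerator

    from fastapi import HTTPException

    from litellm.integrations.custom_guardrail import CustomGuardrail
    from litellm.proxy._types import UserAPIKeyAuth


    class BufferedOutputGuard(CustomGuardrail):
        """Illustrative guard: drain the stream, scan it, then replay it."""

        async def async_post_call_streaming_iterator_hook(
            self,
            user_api_key_dict: UserAPIKeyAuth,
            response: Any,
            request_data: dict,
        ) -> AsyncGenerator[Any, None]:
            chunks, text_parts = [], []
            async for chunk in response:  # drain the upstream iterator
                chunks.append(chunk)
                for choice in getattr(chunk, "choices", []):
                    content = getattr(choice.delta, "content", None)
                    if content:
                        text_parts.append(content)
            if "forbidden" in "".join(text_parts):  # placeholder policy check
                raise HTTPException(status_code=400, detail="Blocked by post-call guardrail")
            for chunk in chunks:  # nothing flagged: replay the buffered stream
                yield chunk

Since the proxy-side method reassigns response in its callback loop, multiple guardrails compose by wrapping one another's iterators. Buffering trades time-to-first-token for whole-answer inspection; a production guard might scan a sliding window of chunks instead.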