diff --git a/docs/my-website/docs/proxy/call_hooks.md b/docs/my-website/docs/proxy/call_hooks.md index ce34e5ad6..25a46609d 100644 --- a/docs/my-website/docs/proxy/call_hooks.md +++ b/docs/my-website/docs/proxy/call_hooks.md @@ -47,6 +47,7 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit async def async_post_call_success_hook( self, + data: dict, user_api_key_dict: UserAPIKeyAuth, response, ): diff --git a/enterprise/enterprise_hooks/aporio_ai.py b/enterprise/enterprise_hooks/aporio_ai.py index 6529ddcba..5d1b081fe 100644 --- a/enterprise/enterprise_hooks/aporio_ai.py +++ b/enterprise/enterprise_hooks/aporio_ai.py @@ -133,6 +133,7 @@ class _ENTERPRISE_Aporio(CustomLogger): async def async_post_call_success_hook( self, + data: dict, user_api_key_dict: UserAPIKeyAuth, response, ): diff --git a/enterprise/enterprise_hooks/banned_keywords.py b/enterprise/enterprise_hooks/banned_keywords.py index 4d6545eb0..e282ee5ab 100644 --- a/enterprise/enterprise_hooks/banned_keywords.py +++ b/enterprise/enterprise_hooks/banned_keywords.py @@ -90,6 +90,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger): async def async_post_call_success_hook( self, + data: dict, user_api_key_dict: UserAPIKeyAuth, response, ): diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index 98b0da25c..47d28ab56 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -122,6 +122,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac async def async_post_call_success_hook( self, + data: dict, user_api_key_dict: UserAPIKeyAuth, response, ): diff --git a/litellm/proxy/custom_callbacks1.py b/litellm/proxy/custom_callbacks1.py index 37e4a6cdb..05028f033 100644 --- a/litellm/proxy/custom_callbacks1.py +++ b/litellm/proxy/custom_callbacks1.py @@ -40,6 +40,7 @@ class MyCustomHandler( async def async_post_call_success_hook( self, + data: dict, user_api_key_dict: UserAPIKeyAuth, response, ): diff --git a/litellm/proxy/hooks/azure_content_safety.py b/litellm/proxy/hooks/azure_content_safety.py index 972ac9992..ccadafaf2 100644 --- a/litellm/proxy/hooks/azure_content_safety.py +++ b/litellm/proxy/hooks/azure_content_safety.py @@ -1,11 +1,16 @@ -from litellm.integrations.custom_logger import CustomLogger -from litellm.caching import DualCache -from litellm.proxy._types import UserAPIKeyAuth -import litellm, traceback, sys, uuid -from fastapi import HTTPException -from litellm._logging import verbose_proxy_logger +import sys +import traceback +import uuid from typing import Optional +from fastapi import HTTPException + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import UserAPIKeyAuth + class _PROXY_AzureContentSafety( CustomLogger @@ -15,12 +20,12 @@ class _PROXY_AzureContentSafety( def __init__(self, endpoint, api_key, thresholds=None): try: from azure.ai.contentsafety.aio import ContentSafetyClient - from azure.core.credentials import AzureKeyCredential from azure.ai.contentsafety.models import ( - TextCategory, AnalyzeTextOptions, AnalyzeTextOutputType, + TextCategory, ) + from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import HttpResponseError except Exception as e: raise Exception( @@ -132,6 +137,7 @@ class _PROXY_AzureContentSafety( async def async_post_call_success_hook( self, + data: dict, user_api_key_dict: UserAPIKeyAuth, response, ): diff --git a/litellm/proxy/hooks/dynamic_rate_limiter.py b/litellm/proxy/hooks/dynamic_rate_limiter.py index 4bf08998a..57985e9a6 100644 --- a/litellm/proxy/hooks/dynamic_rate_limiter.py +++ b/litellm/proxy/hooks/dynamic_rate_limiter.py @@ -254,7 +254,7 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger): return None async def async_post_call_success_hook( - self, user_api_key_dict: UserAPIKeyAuth, response + self, data: dict, user_api_key_dict: UserAPIKeyAuth, response ): try: if isinstance(response, ModelResponse): @@ -287,7 +287,9 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger): return response return await super().async_post_call_success_hook( - user_api_key_dict, response + data=data, + user_api_key_dict=user_api_key_dict, + response=response, ) except Exception as e: verbose_proxy_logger.exception( diff --git a/litellm/proxy/hooks/presidio_pii_masking.py b/litellm/proxy/hooks/presidio_pii_masking.py index 933d92550..6af7e3d1e 100644 --- a/litellm/proxy/hooks/presidio_pii_masking.py +++ b/litellm/proxy/hooks/presidio_pii_masking.py @@ -322,6 +322,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger): async def async_post_call_success_hook( self, + data: dict, user_api_key_dict: UserAPIKeyAuth, response: Union[ModelResponse, EmbeddingResponse, ImageResponse], ): diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 659dc350f..adecb2eb6 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -3136,7 +3136,7 @@ async def chat_completion( ### CALL HOOKS ### - modify outgoing data response = await proxy_logging_obj.post_call_success_hook( - user_api_key_dict=user_api_key_dict, response=response + data=data, user_api_key_dict=user_api_key_dict, response=response ) hidden_params = ( @@ -3350,6 +3350,11 @@ async def completion( media_type="text/event-stream", headers=custom_headers, ) + ### CALL HOOKS ### - modify outgoing data + response = await proxy_logging_obj.post_call_success_hook( + data=data, user_api_key_dict=user_api_key_dict, response=response + ) + fastapi_response.headers.update( get_custom_headers( user_api_key_dict=user_api_key_dict, diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index d1d17d0ef..94862db9a 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -717,6 +717,7 @@ class ProxyLogging: async def post_call_success_hook( self, + data: dict, response: Union[ModelResponse, EmbeddingResponse, ImageResponse], user_api_key_dict: UserAPIKeyAuth, ): @@ -738,7 +739,9 @@ class ProxyLogging: _callback = callback # type: ignore if _callback is not None and isinstance(_callback, CustomLogger): await _callback.async_post_call_success_hook( - user_api_key_dict=user_api_key_dict, response=response + user_api_key_dict=user_api_key_dict, + data=data, + response=response, ) except Exception as e: raise e diff --git a/litellm/tests/test_azure_content_safety.py b/litellm/tests/test_azure_content_safety.py index 7b040fb25..dc80c163c 100644 --- a/litellm/tests/test_azure_content_safety.py +++ b/litellm/tests/test_azure_content_safety.py @@ -1,8 +1,13 @@ # What is this? ## Unit test for azure content safety -import sys, os, asyncio, time, random -from datetime import datetime +import asyncio +import os +import random +import sys +import time import traceback +from datetime import datetime + from dotenv import load_dotenv from fastapi import HTTPException @@ -13,11 +18,12 @@ sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import pytest + import litellm from litellm import Router, mock_completion -from litellm.proxy.utils import ProxyLogging -from litellm.proxy._types import UserAPIKeyAuth from litellm.caching import DualCache +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.utils import ProxyLogging @pytest.mark.asyncio @@ -177,7 +183,13 @@ async def test_strict_output_filtering_01(): with pytest.raises(HTTPException) as exc_info: await azure_content_safety.async_post_call_success_hook( - user_api_key_dict=UserAPIKeyAuth(), response=response + user_api_key_dict=UserAPIKeyAuth(), + data={ + "messages": [ + {"role": "system", "content": "You are an helpfull assistant"} + ] + }, + response=response, ) assert exc_info.value.detail["source"] == "output" @@ -216,7 +228,11 @@ async def test_strict_output_filtering_02(): ) await azure_content_safety.async_post_call_success_hook( - user_api_key_dict=UserAPIKeyAuth(), response=response + user_api_key_dict=UserAPIKeyAuth(), + data={ + "messages": [{"role": "system", "content": "You are an helpfull assistant"}] + }, + response=response, ) @@ -251,7 +267,11 @@ async def test_loose_output_filtering_01(): ) await azure_content_safety.async_post_call_success_hook( - user_api_key_dict=UserAPIKeyAuth(), response=response + user_api_key_dict=UserAPIKeyAuth(), + data={ + "messages": [{"role": "system", "content": "You are an helpfull assistant"}] + }, + response=response, ) @@ -286,5 +306,9 @@ async def test_loose_output_filtering_02(): ) await azure_content_safety.async_post_call_success_hook( - user_api_key_dict=UserAPIKeyAuth(), response=response + user_api_key_dict=UserAPIKeyAuth(), + data={ + "messages": [{"role": "system", "content": "You are an helpfull assistant"}] + }, + response=response, ) diff --git a/litellm/tests/test_presidio_masking.py b/litellm/tests/test_presidio_masking.py index 193fcf113..35a03ea5e 100644 --- a/litellm/tests/test_presidio_masking.py +++ b/litellm/tests/test_presidio_masking.py @@ -88,7 +88,11 @@ async def test_output_parsing(): mock_response="Hello ! How can I assist you today?", ) new_response = await pii_masking.async_post_call_success_hook( - user_api_key_dict=UserAPIKeyAuth(), response=response + user_api_key_dict=UserAPIKeyAuth(), + data={ + "messages": [{"role": "system", "content": "You are an helpfull assistant"}] + }, + response=response, ) assert (