feat - allow accessing data post success call

This commit is contained in:
Ishaan Jaff 2024-08-19 11:35:33 -07:00
parent 6af497e383
commit b4bca8db82
12 changed files with 71 additions and 21 deletions

View file

@ -47,6 +47,7 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
async def async_post_call_success_hook( async def async_post_call_success_hook(
self, self,
data: dict,
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
response, response,
): ):

View file

@ -133,6 +133,7 @@ class _ENTERPRISE_Aporio(CustomLogger):
async def async_post_call_success_hook( async def async_post_call_success_hook(
self, self,
data: dict,
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
response, response,
): ):

View file

@ -90,6 +90,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger):
async def async_post_call_success_hook( async def async_post_call_success_hook(
self, self,
data: dict,
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
response, response,
): ):

View file

@ -122,6 +122,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
async def async_post_call_success_hook( async def async_post_call_success_hook(
self, self,
data: dict,
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
response, response,
): ):

View file

@ -40,6 +40,7 @@ class MyCustomHandler(
async def async_post_call_success_hook( async def async_post_call_success_hook(
self, self,
data: dict,
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
response, response,
): ):

View file

@ -1,11 +1,16 @@
from litellm.integrations.custom_logger import CustomLogger import sys
from litellm.caching import DualCache import traceback
from litellm.proxy._types import UserAPIKeyAuth import uuid
import litellm, traceback, sys, uuid
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from typing import Optional from typing import Optional
from fastapi import HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
class _PROXY_AzureContentSafety( class _PROXY_AzureContentSafety(
CustomLogger CustomLogger
@ -15,12 +20,12 @@ class _PROXY_AzureContentSafety(
def __init__(self, endpoint, api_key, thresholds=None): def __init__(self, endpoint, api_key, thresholds=None):
try: try:
from azure.ai.contentsafety.aio import ContentSafetyClient from azure.ai.contentsafety.aio import ContentSafetyClient
from azure.core.credentials import AzureKeyCredential
from azure.ai.contentsafety.models import ( from azure.ai.contentsafety.models import (
TextCategory,
AnalyzeTextOptions, AnalyzeTextOptions,
AnalyzeTextOutputType, AnalyzeTextOutputType,
TextCategory,
) )
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import HttpResponseError from azure.core.exceptions import HttpResponseError
except Exception as e: except Exception as e:
raise Exception( raise Exception(
@ -132,6 +137,7 @@ class _PROXY_AzureContentSafety(
async def async_post_call_success_hook( async def async_post_call_success_hook(
self, self,
data: dict,
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
response, response,
): ):

View file

@ -254,7 +254,7 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
return None return None
async def async_post_call_success_hook( async def async_post_call_success_hook(
self, user_api_key_dict: UserAPIKeyAuth, response self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
): ):
try: try:
if isinstance(response, ModelResponse): if isinstance(response, ModelResponse):
@ -287,7 +287,9 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
return response return response
return await super().async_post_call_success_hook( return await super().async_post_call_success_hook(
user_api_key_dict, response data=data,
user_api_key_dict=user_api_key_dict,
response=response,
) )
except Exception as e: except Exception as e:
verbose_proxy_logger.exception( verbose_proxy_logger.exception(

View file

@ -322,6 +322,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
async def async_post_call_success_hook( async def async_post_call_success_hook(
self, self,
data: dict,
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse], response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
): ):

View file

@ -3136,7 +3136,7 @@ async def chat_completion(
### CALL HOOKS ### - modify outgoing data ### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook( response = await proxy_logging_obj.post_call_success_hook(
user_api_key_dict=user_api_key_dict, response=response data=data, user_api_key_dict=user_api_key_dict, response=response
) )
hidden_params = ( hidden_params = (
@ -3350,6 +3350,11 @@ async def completion(
media_type="text/event-stream", media_type="text/event-stream",
headers=custom_headers, headers=custom_headers,
) )
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response
)
fastapi_response.headers.update( fastapi_response.headers.update(
get_custom_headers( get_custom_headers(
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,

View file

@ -717,6 +717,7 @@ class ProxyLogging:
async def post_call_success_hook( async def post_call_success_hook(
self, self,
data: dict,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse], response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
): ):
@ -738,7 +739,9 @@ class ProxyLogging:
_callback = callback # type: ignore _callback = callback # type: ignore
if _callback is not None and isinstance(_callback, CustomLogger): if _callback is not None and isinstance(_callback, CustomLogger):
await _callback.async_post_call_success_hook( await _callback.async_post_call_success_hook(
user_api_key_dict=user_api_key_dict, response=response user_api_key_dict=user_api_key_dict,
data=data,
response=response,
) )
except Exception as e: except Exception as e:
raise e raise e

View file

@ -1,8 +1,13 @@
# What is this? # What is this?
## Unit test for azure content safety ## Unit test for azure content safety
import sys, os, asyncio, time, random import asyncio
from datetime import datetime import os
import random
import sys
import time
import traceback import traceback
from datetime import datetime
from dotenv import load_dotenv from dotenv import load_dotenv
from fastapi import HTTPException from fastapi import HTTPException
@ -13,11 +18,12 @@ sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import pytest import pytest
import litellm import litellm
from litellm import Router, mock_completion from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.utils import ProxyLogging
@pytest.mark.asyncio @pytest.mark.asyncio
@ -177,7 +183,13 @@ async def test_strict_output_filtering_01():
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await azure_content_safety.async_post_call_success_hook( await azure_content_safety.async_post_call_success_hook(
user_api_key_dict=UserAPIKeyAuth(), response=response user_api_key_dict=UserAPIKeyAuth(),
data={
"messages": [
{"role": "system", "content": "You are an helpfull assistant"}
]
},
response=response,
) )
assert exc_info.value.detail["source"] == "output" assert exc_info.value.detail["source"] == "output"
@ -216,7 +228,11 @@ async def test_strict_output_filtering_02():
) )
await azure_content_safety.async_post_call_success_hook( await azure_content_safety.async_post_call_success_hook(
user_api_key_dict=UserAPIKeyAuth(), response=response user_api_key_dict=UserAPIKeyAuth(),
data={
"messages": [{"role": "system", "content": "You are an helpfull assistant"}]
},
response=response,
) )
@ -251,7 +267,11 @@ async def test_loose_output_filtering_01():
) )
await azure_content_safety.async_post_call_success_hook( await azure_content_safety.async_post_call_success_hook(
user_api_key_dict=UserAPIKeyAuth(), response=response user_api_key_dict=UserAPIKeyAuth(),
data={
"messages": [{"role": "system", "content": "You are an helpfull assistant"}]
},
response=response,
) )
@ -286,5 +306,9 @@ async def test_loose_output_filtering_02():
) )
await azure_content_safety.async_post_call_success_hook( await azure_content_safety.async_post_call_success_hook(
user_api_key_dict=UserAPIKeyAuth(), response=response user_api_key_dict=UserAPIKeyAuth(),
data={
"messages": [{"role": "system", "content": "You are an helpfull assistant"}]
},
response=response,
) )

View file

@ -88,7 +88,11 @@ async def test_output_parsing():
mock_response="Hello <PERSON>! How can I assist you today?", mock_response="Hello <PERSON>! How can I assist you today?",
) )
new_response = await pii_masking.async_post_call_success_hook( new_response = await pii_masking.async_post_call_success_hook(
user_api_key_dict=UserAPIKeyAuth(), response=response user_api_key_dict=UserAPIKeyAuth(),
data={
"messages": [{"role": "system", "content": "You are an helpfull assistant"}]
},
response=response,
) )
assert ( assert (