feat - allow accessing data post success call

This commit is contained in:
Ishaan Jaff 2024-08-19 11:35:33 -07:00
parent 8cb62213e1
commit 4685b9909a
12 changed files with 71 additions and 21 deletions

View file

@ -47,6 +47,7 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):

View file

@ -133,6 +133,7 @@ class _ENTERPRISE_Aporio(CustomLogger):
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):

View file

@ -90,6 +90,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger):
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):

View file

@ -122,6 +122,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):

View file

@ -40,6 +40,7 @@ class MyCustomHandler(
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):

View file

@ -1,11 +1,16 @@
from litellm.integrations.custom_logger import CustomLogger
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
import litellm, traceback, sys, uuid
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
import sys
import traceback
import uuid
from typing import Optional
from fastapi import HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
class _PROXY_AzureContentSafety(
CustomLogger
@ -15,12 +20,12 @@ class _PROXY_AzureContentSafety(
def __init__(self, endpoint, api_key, thresholds=None):
try:
from azure.ai.contentsafety.aio import ContentSafetyClient
from azure.core.credentials import AzureKeyCredential
from azure.ai.contentsafety.models import (
TextCategory,
AnalyzeTextOptions,
AnalyzeTextOutputType,
TextCategory,
)
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import HttpResponseError
except Exception as e:
raise Exception(
@ -132,6 +137,7 @@ class _PROXY_AzureContentSafety(
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):

View file

@ -254,7 +254,7 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
return None
async def async_post_call_success_hook(
self, user_api_key_dict: UserAPIKeyAuth, response
self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
):
try:
if isinstance(response, ModelResponse):
@ -287,7 +287,9 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
return response
return await super().async_post_call_success_hook(
user_api_key_dict, response
data=data,
user_api_key_dict=user_api_key_dict,
response=response,
)
except Exception as e:
verbose_proxy_logger.exception(

View file

@ -322,6 +322,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
):

View file

@ -3136,7 +3136,7 @@ async def chat_completion(
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
user_api_key_dict=user_api_key_dict, response=response
data=data, user_api_key_dict=user_api_key_dict, response=response
)
hidden_params = (
@ -3350,6 +3350,11 @@ async def completion(
media_type="text/event-stream",
headers=custom_headers,
)
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response
)
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,

View file

@ -717,6 +717,7 @@ class ProxyLogging:
async def post_call_success_hook(
self,
data: dict,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
user_api_key_dict: UserAPIKeyAuth,
):
@ -738,7 +739,9 @@ class ProxyLogging:
_callback = callback # type: ignore
if _callback is not None and isinstance(_callback, CustomLogger):
await _callback.async_post_call_success_hook(
user_api_key_dict=user_api_key_dict, response=response
user_api_key_dict=user_api_key_dict,
data=data,
response=response,
)
except Exception as e:
raise e

View file

@ -1,8 +1,13 @@
# What is this?
## Unit test for azure content safety
import sys, os, asyncio, time, random
from datetime import datetime
import asyncio
import os
import random
import sys
import time
import traceback
from datetime import datetime
from dotenv import load_dotenv
from fastapi import HTTPException
@ -13,11 +18,12 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.utils import ProxyLogging
@pytest.mark.asyncio
@ -177,7 +183,13 @@ async def test_strict_output_filtering_01():
with pytest.raises(HTTPException) as exc_info:
await azure_content_safety.async_post_call_success_hook(
user_api_key_dict=UserAPIKeyAuth(), response=response
user_api_key_dict=UserAPIKeyAuth(),
data={
"messages": [
{"role": "system", "content": "You are an helpfull assistant"}
]
},
response=response,
)
assert exc_info.value.detail["source"] == "output"
@ -216,7 +228,11 @@ async def test_strict_output_filtering_02():
)
await azure_content_safety.async_post_call_success_hook(
user_api_key_dict=UserAPIKeyAuth(), response=response
user_api_key_dict=UserAPIKeyAuth(),
data={
"messages": [{"role": "system", "content": "You are an helpfull assistant"}]
},
response=response,
)
@ -251,7 +267,11 @@ async def test_loose_output_filtering_01():
)
await azure_content_safety.async_post_call_success_hook(
user_api_key_dict=UserAPIKeyAuth(), response=response
user_api_key_dict=UserAPIKeyAuth(),
data={
"messages": [{"role": "system", "content": "You are an helpfull assistant"}]
},
response=response,
)
@ -286,5 +306,9 @@ async def test_loose_output_filtering_02():
)
await azure_content_safety.async_post_call_success_hook(
user_api_key_dict=UserAPIKeyAuth(), response=response
user_api_key_dict=UserAPIKeyAuth(),
data={
"messages": [{"role": "system", "content": "You are an helpfull assistant"}]
},
response=response,
)

View file

@ -88,7 +88,11 @@ async def test_output_parsing():
mock_response="Hello <PERSON>! How can I assist you today?",
)
new_response = await pii_masking.async_post_call_success_hook(
user_api_key_dict=UserAPIKeyAuth(), response=response
user_api_key_dict=UserAPIKeyAuth(),
data={
"messages": [{"role": "system", "content": "You are an helpfull assistant"}]
},
response=response,
)
assert (