feat(utils.py): add async success callbacks for custom functions

Krrish Dholakia 2023-12-04 16:36:21 -08:00
parent b90fcbdac4
commit e0ccb281d8
8 changed files with 232 additions and 138 deletions
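
The feature added here lets a caller register a plain `async def` function in `litellm.success_callback`; at call time it is routed into the internal `litellm._async_success_callback` list and awaited on the running event loop instead of being pushed into a thread. A minimal usage sketch, modeled on the test added further down in this diff (the model name and callback body are illustrative, and an OpenAI key is assumed to be configured):

import asyncio
import litellm

async def async_log_success(kwargs, completion_response, start_time, end_time):
    # Coroutine callbacks are detected with inspect.iscoroutinefunction and moved to
    # litellm._async_success_callback, where they are awaited rather than run in a thread.
    print("async success callback fired for model:", kwargs.get("model"))

litellm.success_callback = [async_log_success]

async def main():
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",  # illustrative model; assumes OPENAI_API_KEY is set
        messages=[{"role": "user", "content": "Hi"}],
        stream=True,
    )
    async for _ in response:
        pass

asyncio.run(main())
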

View file

@@ -8,6 +8,7 @@ input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = []
_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
set_verbose = False
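
The new `_async_success_callback` list is internal: when a request is set up, any coroutine functions found in `litellm.success_callback` are moved into it (see the `client()` hunk in utils.py at the bottom of this diff). A standalone sketch of that routing, written against hypothetical list arguments rather than the module-level globals:

import inspect

def route_async_callbacks(success_callback, _async_success_callback):
    # Move coroutine functions into the internal async list; sync callables stay where they are.
    removed_async_items = []
    for index, callback in enumerate(success_callback):
        if inspect.iscoroutinefunction(callback):
            _async_success_callback.append(callback)
            removed_async_items.append(index)
    # Pop in reverse order so the earlier indices remain valid.
    for index in reversed(removed_async_items):
        success_callback.pop(index)
    return success_callback, _async_success_callback

Called with `litellm.success_callback` and `litellm._async_success_callback`, this keeps sync callbacks on the existing threaded path and gives coroutines their own awaited path.
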

View file

@@ -8,7 +8,7 @@ dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
class CustomLogger:
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
# Class variables or attributes
def __init__(self):
pass
@@ -29,7 +29,7 @@ class CustomLogger:
pass
#### DEPRECATED ####
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
try:
@@ -63,3 +63,21 @@ class CustomLogger:
# traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass
async def async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
# Method definition
try:
kwargs["log_event_type"] = "post_api_call"
await callback_func(
kwargs, # kwargs to func
response_obj,
start_time,
end_time,
)
print_verbose(
f"Custom Logger - final response object: {response_obj}"
)
except:
# traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass
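
`async_log_event` tags the call with `log_event_type = "post_api_call"` and awaits the user-supplied coroutine. A hedged sketch of exercising it directly with a dummy coroutine; the kwargs and response values here are placeholders, not real call details:

import asyncio
from litellm.integrations.custom_logger import CustomLogger

async def my_callback(kwargs, response_obj, start_time, end_time):
    print("async callback fired:", kwargs.get("log_event_type"), response_obj)

async def main():
    await CustomLogger().async_log_event(
        kwargs={"model": "gpt-3.5-turbo"},  # placeholder call details
        response_obj={"choices": []},       # placeholder response
        start_time=0.0,
        end_time=1.0,
        print_verbose=print,
        callback_func=my_callback,
    )

asyncio.run(main())
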

View file

@@ -272,9 +272,15 @@ api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
global master_key, prisma_client, llm_model_list
print(f"master_key - {master_key}; api_key - {api_key}")
if master_key is None:
if isinstance(api_key, str):
return {
"api_key": None
"api_key": api_key.replace("Bearer ", "")
}
else:
return {
"api_key": api_key
}
try:
if api_key is None:
@@ -382,8 +388,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
print("Error when loading keys from Azure Key Vault. Ensure you run `pip install azure-identity azure-keyvault-secrets`")
def cost_tracking():
global prisma_client, master_key
if prisma_client is not None and master_key is not None:
global prisma_client
if prisma_client is not None:
if isinstance(litellm.success_callback, list):
print("setting litellm success callback to track cost")
if (track_cost_callback) not in litellm.success_callback: # type: ignore
@@ -391,7 +397,7 @@ def cost_tracking():
else:
litellm.success_callback = track_cost_callback # type: ignore
def track_cost_callback(
async def track_cost_callback(
kwargs, # kwargs to completion
completion_response: litellm.ModelResponse, # response from completion
start_time = None,
@@ -420,31 +426,13 @@ def track_cost_callback(
response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
print("regular response_cost", response_cost)
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
if user_api_key and prisma_client:
# asyncio.run(update_prisma_database(user_api_key, response_cost))
# Create new event loop for async function execution in the new thread
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
# Run the async function using the newly created event loop
existing_spend_obj = new_loop.run_until_complete(prisma_client.get_data(token=user_api_key))
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
print(f"new cost: {new_spend}")
# Update the cost column for the given token
new_loop.run_until_complete(prisma_client.update_data(token=user_api_key, data={"spend": new_spend}))
print(f"Prisma database updated for token {user_api_key}. New cost: {new_spend}")
except Exception as e:
print(f"error in creating async loop - {str(e)}")
await update_prisma_database(token=user_api_key, response_cost=response_cost)
except Exception as e:
print(f"error in tracking cost callback - {str(e)}")
async def update_prisma_database(token, response_cost):
try:
print(f"Enters prisma db call, token: {token}")
# Fetch the existing cost for the given token
@@ -460,8 +448,6 @@ async def update_prisma_database(token, response_cost):
print(f"new cost: {new_spend}")
# Update the cost column for the given token
await prisma_client.update_data(token=token, data={"spend": new_spend})
print(f"Prisma database updated for token {token}. New cost: {new_spend}")
except Exception as e:
print(f"Error updating Prisma database: {traceback.format_exc()}")
pass
@@ -648,7 +634,7 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": new_verification_token.token, "expires": new_verification_token.expires, "user_id": user_id}
return {"token": token, "expires": new_verification_token.expires, "user_id": user_id}
async def delete_verification_token(tokens: List):
global prisma_client
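
The net effect of the proxy changes: `track_cost_callback` is now a coroutine, so it awaits `update_prisma_database` directly on the running loop instead of building a fresh event loop inside a worker thread. A self-contained sketch of the pattern, with a stub standing in for the async Prisma client (the stub class and the hard-coded cost are illustrative, not the proxy's real objects):

import asyncio

class FakeSpendStore:
    """Stand-in for the proxy's async Prisma client."""
    def __init__(self):
        self._spend = {}
    async def get_data(self, token):
        return self._spend.get(token, 0.0)
    async def update_data(self, token, data):
        self._spend[token] = data["spend"]

store = FakeSpendStore()

async def update_spend(token, response_cost):
    # Same shape as update_prisma_database: read existing spend, add the new cost, write it back.
    existing_spend = await store.get_data(token=token)
    await store.update_data(token=token, data={"spend": existing_spend + response_cost})

async def track_cost_callback(kwargs, completion_response, start_time=None, end_time=None):
    # Because this is a coroutine, litellm awaits it via _async_success_callback --
    # no asyncio.new_event_loop() / run_until_complete dance inside a thread.
    user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key")
    await update_spend(token=user_api_key, response_cost=0.002)  # illustrative cost

asyncio.run(track_cost_callback(
    kwargs={"litellm_params": {"metadata": {"user_api_key": "sk-test"}}},
    completion_response=None,
))
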

View file

@@ -876,6 +876,7 @@ class Router:
self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}")
if "azure" in model_name:
self.print_verbose(f"Initializing Azure OpenAI Client for {model_name}, {str(api_base)}, {api_key}")
if api_version is None:
api_version = "2023-07-01-preview"
if "gateway.ai.cloudflare.com" in api_base:
@@ -913,6 +914,7 @@
max_retries=max_retries
)
else:
self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}")
model["async_client"] = openai.AsyncOpenAI(
api_key=api_key,
base_url=api_base,

View file

@@ -1,5 +1,5 @@
### What this tests ####
import sys, os, time
import sys, os, time, inspect, asyncio
import pytest
sys.path.insert(0, os.path.abspath('../..'))
@@ -7,6 +7,7 @@ from litellm import completion, embedding
import litellm
from litellm.integrations.custom_logger import CustomLogger
async_success = False
class MyCustomHandler(CustomLogger):
success: bool = False
failure: bool = False
@@ -28,24 +29,29 @@ class MyCustomHandler(CustomLogger):
print(f"On Failure")
self.failure = True
# def test_chat_openai():
# try:
# customHandler = MyCustomHandler()
# litellm.callbacks = [customHandler]
# response = completion(model="gpt-3.5-turbo",
# messages=[{
# "role": "user",
# "content": "Hi 👋 - i'm openai"
# }],
# stream=True)
# time.sleep(1)
# assert customHandler.success == True
# except Exception as e:
# pytest.fail(f"An error occurred - {str(e)}")
# pass
async def async_test_logging_fn(kwargs, completion_obj, start_time, end_time):
global async_success
print(f"ON ASYNC LOGGING")
async_success = True
# test_chat_openai()
@pytest.mark.asyncio
async def test_chat_openai():
try:
# litellm.set_verbose = True
litellm.success_callback = [async_test_logging_fn]
response = await litellm.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}],
stream=True)
async for chunk in response:
continue
assert async_success == True
except Exception as e:
print(e)
pytest.fail(f"An error occurred - {str(e)}")
def test_completion_azure_stream_moderation_failure():
try:
@@ -71,76 +77,3 @@ def test_completion_azure_stream_moderation_failure():
assert customHandler.failure == True
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_azure_stream_moderation_failure()
# def custom_callback(
# kwargs,
# completion_response,
# start_time,
# end_time,
# ):
# print(
# "in custom callback func"
# )
# print("kwargs", kwargs)
# print(completion_response)
# print(start_time)
# print(end_time)
# if "complete_streaming_response" in kwargs:
# print("\n\n complete response\n\n")
# complete_streaming_response = kwargs["complete_streaming_response"]
# print(kwargs["complete_streaming_response"])
# usage = complete_streaming_response["usage"]
# print("usage", usage)
# def send_slack_alert(
# kwargs,
# completion_response,
# start_time,
# end_time,
# ):
# print(
# "in custom slack callback func"
# )
# import requests
# import json
# # Define the Slack webhook URL
# slack_webhook_url = os.environ['SLACK_WEBHOOK_URL'] # "https://hooks.slack.com/services/<>/<>/<>"
# # Define the text payload, send data available in litellm custom_callbacks
# text_payload = f"""LiteLLM Logging: kwargs: {str(kwargs)}\n\n, response: {str(completion_response)}\n\n, start time{str(start_time)} end time: {str(end_time)}
# """
# payload = {
# "text": text_payload
# }
# # Set the headers
# headers = {
# "Content-type": "application/json"
# }
# # Make the POST request
# response = requests.post(slack_webhook_url, json=payload, headers=headers)
# # Check the response status
# if response.status_code == 200:
# print("Message sent successfully to Slack!")
# else:
# print(f"Failed to send message to Slack. Status code: {response.status_code}")
# print(response.json())
# def get_transformed_inputs(
# kwargs,
# ):
# params_to_model = kwargs["additional_args"]["complete_input_dict"]
# print("params to model", params_to_model)
# litellm.success_callback = [custom_callback, send_slack_alert]
# litellm.failure_callback = [send_slack_alert]
# litellm.set_verbose = False
# # litellm.input_callback = [get_transformed_inputs]

View file

@@ -1,27 +1,138 @@
# #### What this tests ####
# # This tests the cost tracking function works with consecutive calls (~10 consecutive calls)
# import sys, os
# import sys, os, asyncio
# import traceback
# import pytest
# sys.path.insert(
# 0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import dotenv
# dotenv.load_dotenv()
# import litellm
# from fastapi.testclient import TestClient
# from fastapi import FastAPI
# from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined
# filepath = os.path.dirname(os.path.abspath(__file__))
# config_fp = f"{filepath}/test_config.yaml"
# save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=True, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
# app = FastAPI()
# app.include_router(router) # Include your router in the test app
# @app.on_event("startup")
# async def wrapper_startup_event():
# await startup_event()
# async def test_proxy_cost_tracking():
# # Here you create a fixture that will be used by your tests
# # Make sure the fixture returns TestClient(app)
# @pytest.fixture(autouse=True)
# def client():
# with TestClient(app) as client:
# yield client
# @pytest.mark.asyncio
# async def test_proxy_cost_tracking(client):
# """
# Get expected cost.
# Get min cost.
# Create new key.
# Run 10 parallel calls.
# Check cost for key at the end.
# assert it's = expected cost.
# assert it's > min cost.
# """
# model = "gpt-3.5-turbo"
# messages = [{"role": "user", "content": "Hey, how's it going?"}]
# number_of_calls = 10
# expected_cost = litellm.completion_cost(model=model, messages=messages) * number_of_calls
# async def litellm_acompletion():
# number_of_calls = 1
# min_cost = litellm.completion_cost(model=model, messages=messages) * number_of_calls
# try:
# ### CREATE NEW KEY ###
# test_data = {
# "models": ["azure-model"],
# }
# # Your bearer token
# token = os.getenv("PROXY_MASTER_KEY")
# headers = {
# "Authorization": f"Bearer {token}"
# }
# create_new_key = client.post("/key/generate", json=test_data, headers=headers)
# key = create_new_key.json()["key"]
# print(f"received key: {key}")
# ### MAKE PARALLEL CALLS ###
# async def test_chat_completions():
# # Your test data
# test_data = {
# "model": "azure-model",
# "messages": messages
# }
# tmp_headers = {
# "Authorization": f"Bearer {key}"
# }
# response = client.post("/v1/chat/completions", json=test_data, headers=tmp_headers)
# assert response.status_code == 200
# result = response.json()
# print(f"Received response: {result}")
# tasks = [test_chat_completions() for _ in range(number_of_calls)]
# chat_completions = await asyncio.gather(*tasks)
# ### CHECK SPEND ###
# get_key_spend = client.get(f"/key/info?key={key}", headers=headers)
# assert get_key_spend.json()["info"]["spend"] > min_cost
# # print(f"chat_completions: {chat_completions}")
# # except Exception as e:
# # pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
# #### JUST TEST LOCAL PROXY SERVER
# import requests, os
# from concurrent.futures import ThreadPoolExecutor
# import dotenv
# dotenv.load_dotenv()
# api_url = "http://0.0.0.0:8000/chat/completions"
# def make_api_call(api_url):
# # Your test data
# test_data = {
# "model": "azure-model",
# "messages": [
# {
# "role": "user",
# "content": "hi"
# },
# ],
# "max_tokens": 10,
# }
# # Your bearer token
# token = os.getenv("PROXY_MASTER_KEY")
# headers = {
# "Authorization": f"Bearer {token}"
# }
# print("testing proxy server")
# response = requests.post(api_url, json=test_data, headers=headers)
# return response.json()
# # Number of parallel API calls
# num_parallel_calls = 3
# # List to store results
# results = []
# # Create a ThreadPoolExecutor
# with ThreadPoolExecutor() as executor:
# # Submit the API calls concurrently
# futures = [executor.submit(make_api_call, api_url) for _ in range(num_parallel_calls)]
# # Gather the results as they become available
# for future in futures:
# try:
# result = future.result()
# results.append(result)
# except Exception as e:
# print(f"Error: {e}")
# # Print the results
# for idx, result in enumerate(results, start=1):
# print(f"Result {idx}: {result}")

View file

@@ -59,6 +59,7 @@ def test_add_new_key(client):
print(f"response: {response.text}")
assert response.status_code == 200
result = response.json()
assert result["key"].startswith("sk-")
print(f"Received response: {result}")
except Exception as e:
pytest.fail("LiteLLM Proxy test failed. Exception", e)

View file

@@ -742,11 +742,7 @@ class Logging:
)
pass
def success_handler(self, result=None, start_time=None, end_time=None, **kwargs):
print_verbose(
f"Logging Details LiteLLM-Success Call"
)
def _success_handler_helper_fn(self, result=None, start_time=None, end_time=None):
try:
if start_time is None:
start_time = self.start_time
@@ -776,6 +772,18 @@ class Logging:
float_diff = float(time_diff)
litellm._current_cost += litellm.completion_cost(model=self.model, prompt="", completion=result["content"], total_time=float_diff)
return start_time, end_time, result, complete_streaming_response
except:
pass
def success_handler(self, result=None, start_time=None, end_time=None, **kwargs):
print_verbose(
f"Logging Details LiteLLM-Success Call"
)
try:
start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
print_verbose(f"success callbacks: {litellm.success_callback}")
for callback in litellm.success_callback:
try:
if callback == "lite_debugger":
@@ -969,6 +977,29 @@ class Logging:
)
pass
async def async_success_handler(self, result=None, start_time=None, end_time=None, **kwargs):
"""
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
print_verbose(f"success callbacks: {litellm.success_callback}")
for callback in litellm._async_success_callback:
try:
if callable(callback): # custom logger functions
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=result,
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
callback_func=callback
)
except:
print_verbose(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
)
def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
print_verbose(
f"Logging Details LiteLLM-Failure Call"
@@ -1185,6 +1216,17 @@ def client(original_function):
callback_list=callback_list,
function_id=function_id
)
## ASYNC CALLBACKS
if len(litellm.success_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.success_callback):
if inspect.iscoroutinefunction(callback):
litellm._async_success_callback.append(callback)
removed_async_items.append(index)
# Pop the async items from success_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
litellm.success_callback.pop(index)
if add_breadcrumb:
add_breadcrumb(
category="litellm.llm_call",
@@ -1373,7 +1415,6 @@ def client(original_function):
start_time = datetime.datetime.now()
result = None
logging_obj = kwargs.get("litellm_logging_obj", None)
# only set litellm_call_id if its not in kwargs
if "litellm_call_id" not in kwargs:
kwargs["litellm_call_id"] = str(uuid.uuid4())
@@ -1426,8 +1467,8 @@ def client(original_function):
# [OPTIONAL] ADD TO CACHE
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
litellm.cache.add_cache(result, *args, **kwargs)
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
# LOG SUCCESS - handle streaming success logging in the _next_ object
asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
# RETURN RESULT
if isinstance(result, ModelResponse):
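
On a successful call, the async handler is scheduled onto the already-running loop with `asyncio.create_task`, while the sync `success_handler` keeps running in its own thread. A toy sketch of that scheduling split (handler bodies are placeholders):

import asyncio, threading

def sync_handler(result):
    print("sync handler (thread):", result)

async def async_handler(result):
    print("async handler (event loop):", result)

async def finish_call(result):
    # Schedule the coroutine on the running loop and hand the sync handler to a thread,
    # so neither one blocks returning the result to the caller.
    asyncio.create_task(async_handler(result))
    threading.Thread(target=sync_handler, args=(result,)).start()
    await asyncio.sleep(0)  # toy-only: give the scheduled task a chance to run before exit
    return result

asyncio.run(finish_call("ok"))
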
@@ -1465,7 +1506,6 @@ def client(original_function):
logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this!
raise e
# Use httpx to determine if the original function is a coroutine
is_coroutine = inspect.iscoroutinefunction(original_function)
# Return the appropriate wrapper based on the original function type
@@ -5370,6 +5410,8 @@ class CustomStreamWrapper:
processed_chunk = self.chunk_creator(chunk=chunk)
if processed_chunk is None:
continue
## LOGGING
asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,))
return processed_chunk
raise StopAsyncIteration
else: # temporary patch for non-aiohttp async calls
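
For streaming, `CustomStreamWrapper.__anext__` fires the async success handler for each processed chunk without blocking the consumer of the stream. A toy sketch of the same per-chunk fire-and-forget pattern, using a stand-in wrapper class:

import asyncio

async def log_chunk(chunk):
    print("logged:", chunk)

class ToyStreamWrapper:
    def __init__(self, chunks):
        self._chunks = iter(chunks)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            processed_chunk = next(self._chunks)
        except StopIteration:
            raise StopAsyncIteration
        # Fire-and-forget logging; runs when the event loop next gets control.
        asyncio.create_task(log_chunk(processed_chunk))
        return processed_chunk

async def main():
    async for chunk in ToyStreamWrapper(["a", "b", "c"]):
        print("consumed:", chunk)
    await asyncio.sleep(0)  # toy-only: let the pending logging tasks run before exiting

asyncio.run(main())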