feat(utils.py): add async success callbacks for custom functions

commit e0ccb281d8 (parent b90fcbdac4)
8 changed files with 232 additions and 138 deletions
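From the caller's side the feature looks roughly like this, a minimal sketch pieced together from the test added below (the model name, the message, and the assumption that an OpenAI key is configured are illustrative, not part of the diff):

```python
import asyncio
import litellm

# Coroutine callbacks registered on litellm.success_callback are detected with
# inspect.iscoroutinefunction and routed to litellm._async_success_callback,
# where they are awaited instead of being dispatched to a logging thread.
async def log_success(kwargs, completion_obj, start_time, end_time):
    print("async success callback fired")

litellm.success_callback = [log_success]

async def main():
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hi"}],
    )
    print(response)

asyncio.run(main())
```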
@@ -8,6 +8,7 @@ input_callback: List[Union[str, Callable]] = []
 success_callback: List[Union[str, Callable]] = []
 failure_callback: List[Union[str, Callable]] = []
 callbacks: List[Callable] = []
+_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
 pre_call_rules: List[Callable] = []
 post_call_rules: List[Callable] = []
 set_verbose = False
@@ -8,7 +8,7 @@ dotenv.load_dotenv() # Loading env variables using dotenv
 import traceback


-class CustomLogger:
+class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
     # Class variables or attributes
     def __init__(self):
         pass
@@ -29,7 +29,7 @@ class CustomLogger:
         pass


-    #### DEPRECATED ####
+    #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function

     def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
         try:
@@ -63,3 +63,21 @@ class CustomLogger:
             # traceback.print_exc()
             print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
             pass

+    async def async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
+        # Method definition
+        try:
+            kwargs["log_event_type"] = "post_api_call"
+            await callback_func(
+                kwargs, # kwargs to func
+                response_obj,
+                start_time,
+                end_time,
+            )
+            print_verbose(
+                f"Custom Logger - final response object: {response_obj}"
+            )
+        except:
+            # traceback.print_exc()
+            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
+            pass
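A minimal sketch of how the new `async_log_event` hook is driven (this assumes a litellm build that already contains this change; the callback and its arguments are made up for illustration):

```python
import asyncio
from litellm.integrations.custom_logger import CustomLogger

async def my_callback(kwargs, response_obj, start_time, end_time):
    print(f"event type: {kwargs.get('log_event_type')}, response: {response_obj}")

# async_log_event awaits the user-supplied coroutine directly instead of
# scheduling it on a thread, so it must be called from a running event loop.
asyncio.run(CustomLogger().async_log_event(
    kwargs={}, response_obj={"id": "demo"}, start_time=None, end_time=None,
    print_verbose=print, callback_func=my_callback,
))
```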
@@ -272,9 +272,15 @@ api_key_header = APIKeyHeader(name="Authorization", auto_error=False)

 async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
     global master_key, prisma_client, llm_model_list
+    print(f"master_key - {master_key}; api_key - {api_key}")
     if master_key is None:
-        return {
-            "api_key": None
-        }
+        if isinstance(api_key, str):
+            return {
+                "api_key": api_key.replace("Bearer ", "")
+            }
+        else:
+            return {
+                "api_key": api_key
+            }
     try:
         if api_key is None:
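A small aside on the auth change above: when no master key is configured, the proxy now echoes the caller's key back with any `Bearer ` prefix stripped instead of returning `None`. A standalone sketch of that branch (not the proxy's actual code path):

```python
def resolve_api_key(api_key, master_key=None):
    # with no master key configured, pass the caller's key through,
    # stripping an HTTP "Bearer " prefix when present
    if master_key is None:
        if isinstance(api_key, str):
            return {"api_key": api_key.replace("Bearer ", "")}
        return {"api_key": api_key}
    raise NotImplementedError("master-key verification is handled further down")

print(resolve_api_key("Bearer sk-1234"))  # {'api_key': 'sk-1234'}
```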
@@ -382,8 +388,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
         print("Error when loading keys from Azure Key Vault. Ensure you run `pip install azure-identity azure-keyvault-secrets`")

 def cost_tracking():
-    global prisma_client, master_key
-    if prisma_client is not None and master_key is not None:
+    global prisma_client
+    if prisma_client is not None:
         if isinstance(litellm.success_callback, list):
             print("setting litellm success callback to track cost")
             if (track_cost_callback) not in litellm.success_callback: # type: ignore
@@ -391,7 +397,7 @@ def cost_tracking():
     else:
         litellm.success_callback = track_cost_callback # type: ignore

-def track_cost_callback(
+async def track_cost_callback(
     kwargs,                                       # kwargs to completion
     completion_response: litellm.ModelResponse,   # response from completion
     start_time = None,
@@ -420,31 +426,13 @@ def track_cost_callback(
             response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
             print("regular response_cost", response_cost)
         user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+        print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
         if user_api_key and prisma_client:
-            # asyncio.run(update_prisma_database(user_api_key, response_cost))
-            # Create new event loop for async function execution in the new thread
-            new_loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(new_loop)
-            try:
-                # Run the async function using the newly created event loop
-                existing_spend_obj = new_loop.run_until_complete(prisma_client.get_data(token=user_api_key))
-                if existing_spend_obj is None:
-                    existing_spend = 0
-                else:
-                    existing_spend = existing_spend_obj.spend
-                # Calculate the new cost by adding the existing cost and response_cost
-                new_spend = existing_spend + response_cost
-                print(f"new cost: {new_spend}")
-                # Update the cost column for the given token
-                new_loop.run_until_complete(prisma_client.update_data(token=user_api_key, data={"spend": new_spend}))
-                print(f"Prisma database updated for token {user_api_key}. New cost: {new_spend}")
-            except Exception as e:
-                print(f"error in creating async loop - {str(e)}")
+            await update_prisma_database(token=user_api_key, response_cost=response_cost)
     except Exception as e:
         print(f"error in tracking cost callback - {str(e)}")

 async def update_prisma_database(token, response_cost):
     try:
         print(f"Enters prisma db call, token: {token}")
         # Fetch the existing cost for the given token
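Why the deletion above is safe: `track_cost_callback` is now a coroutine, so it can await the Prisma update directly; the removed code had to build a fresh event loop because it ran as a synchronous callback on a logging thread. A standalone sketch of the difference (the database call is a stand-in, not the real Prisma client):

```python
import asyncio

async def update_db(token: str, response_cost: float) -> None:
    await asyncio.sleep(0)  # stand-in for prisma_client.update_data(...)
    print(f"spend for {token} increased by {response_cost}")

async def track_cost_callback(kwargs, completion_response, start_time=None, end_time=None):
    # old approach, needed when this was a sync callback running on a thread:
    #   loop = asyncio.new_event_loop(); loop.run_until_complete(update_db(...))
    # new approach: the callback itself is async, so just await
    await update_db(token="sk-example", response_cost=0.0012)

asyncio.run(track_cost_callback({}, None))
```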
@@ -460,8 +448,6 @@ async def update_prisma_database(token, response_cost):
         print(f"new cost: {new_spend}")
         # Update the cost column for the given token
         await prisma_client.update_data(token=token, data={"spend": new_spend})
-        print(f"Prisma database updated for token {token}. New cost: {new_spend}")
-
     except Exception as e:
         print(f"Error updating Prisma database: {traceback.format_exc()}")
         pass
@@ -648,7 +634,7 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
     except Exception as e:
         traceback.print_exc()
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
-    return {"token": new_verification_token.token, "expires": new_verification_token.expires, "user_id": user_id}
+    return {"token": token, "expires": new_verification_token.expires, "user_id": user_id}

 async def delete_verification_token(tokens: List):
     global prisma_client
@@ -876,6 +876,7 @@ class Router:

             self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}")
             if "azure" in model_name:
+                self.print_verbose(f"Initializing Azure OpenAI Client for {model_name}, {str(api_base)}, {api_key}")
                 if api_version is None:
                     api_version = "2023-07-01-preview"
                 if "gateway.ai.cloudflare.com" in api_base:
@@ -913,6 +914,7 @@ class Router:
                     max_retries=max_retries
                 )
             else:
+                self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}")
                 model["async_client"] = openai.AsyncOpenAI(
                     api_key=api_key,
                     base_url=api_base,
@@ -1,5 +1,5 @@
 ### What this tests ####
-import sys, os, time
+import sys, os, time, inspect, asyncio
 import pytest
 sys.path.insert(0, os.path.abspath('../..'))

@@ -7,6 +7,7 @@ from litellm import completion, embedding
 import litellm
 from litellm.integrations.custom_logger import CustomLogger

+async_success = False
 class MyCustomHandler(CustomLogger):
     success: bool = False
     failure: bool = False
@@ -28,24 +29,29 @@ class MyCustomHandler(CustomLogger):
         print(f"On Failure")
         self.failure = True

-# def test_chat_openai():
-#     try:
-#         customHandler = MyCustomHandler()
-#         litellm.callbacks = [customHandler]
-#         response = completion(model="gpt-3.5-turbo",
-#                               messages=[{
-#                                   "role": "user",
-#                                   "content": "Hi 👋 - i'm openai"
-#                               }],
-#                               stream=True)
-#         time.sleep(1)
-#         assert customHandler.success == True
-#     except Exception as e:
-#         pytest.fail(f"An error occurred - {str(e)}")
-#     pass
-
-# test_chat_openai()
+async def async_test_logging_fn(kwargs, completion_obj, start_time, end_time):
+    global async_success
+    print(f"ON ASYNC LOGGING")
+    async_success = True
+
+@pytest.mark.asyncio
+async def test_chat_openai():
+    try:
+        # litellm.set_verbose = True
+        litellm.success_callback = [async_test_logging_fn]
+        response = await litellm.acompletion(model="gpt-3.5-turbo",
+                              messages=[{
+                                  "role": "user",
+                                  "content": "Hi 👋 - i'm openai"
+                              }],
+                              stream=True)
+        async for chunk in response:
+            continue
+        assert async_success == True
+    except Exception as e:
+        print(e)
+        pytest.fail(f"An error occurred - {str(e)}")

 def test_completion_azure_stream_moderation_failure():
     try:
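One assumption worth flagging for the rewritten test: `@pytest.mark.asyncio` only runs if the pytest-asyncio plugin is installed, which the diff itself does not show. The core pattern, stripped of the network call:

```python
import pytest

async_success = False

async def async_test_logging_fn(kwargs, completion_obj, start_time, end_time):
    # mirrors the callback above: flip a module-level flag when invoked
    global async_success
    async_success = True

@pytest.mark.asyncio  # requires the pytest-asyncio plugin (assumed)
async def test_async_callback_fires():
    await async_test_logging_fn({}, None, None, None)
    assert async_success == True
```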
@@ -71,76 +77,3 @@ def test_completion_azure_stream_moderation_failure():
         assert customHandler.failure == True
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-
-# test_completion_azure_stream_moderation_failure()
-
-
-# def custom_callback(
-#         kwargs,
-#         completion_response,
-#         start_time,
-#         end_time,
-# ):
-#     print(
-#         "in custom callback func"
-#     )
-#     print("kwargs", kwargs)
-#     print(completion_response)
-#     print(start_time)
-#     print(end_time)
-#     if "complete_streaming_response" in kwargs:
-#         print("\n\n complete response\n\n")
-#         complete_streaming_response = kwargs["complete_streaming_response"]
-#         print(kwargs["complete_streaming_response"])
-#         usage = complete_streaming_response["usage"]
-#         print("usage", usage)
-# def send_slack_alert(
-#         kwargs,
-#         completion_response,
-#         start_time,
-#         end_time,
-# ):
-#     print(
-#         "in custom slack callback func"
-#     )
-#     import requests
-#     import json
-
-#     # Define the Slack webhook URL
-#     slack_webhook_url = os.environ['SLACK_WEBHOOK_URL'] # "https://hooks.slack.com/services/<>/<>/<>"
-
-#     # Define the text payload, send data available in litellm custom_callbacks
-#     text_payload = f"""LiteLLM Logging: kwargs: {str(kwargs)}\n\n, response: {str(completion_response)}\n\n, start time{str(start_time)} end time: {str(end_time)}
-#     """
-#     payload = {
-#         "text": text_payload
-#     }
-
-#     # Set the headers
-#     headers = {
-#         "Content-type": "application/json"
-#     }
-
-#     # Make the POST request
-#     response = requests.post(slack_webhook_url, json=payload, headers=headers)
-
-#     # Check the response status
-#     if response.status_code == 200:
-#         print("Message sent successfully to Slack!")
-#     else:
-#         print(f"Failed to send message to Slack. Status code: {response.status_code}")
-#         print(response.json())
-
-# def get_transformed_inputs(
-#     kwargs,
-# ):
-#     params_to_model = kwargs["additional_args"]["complete_input_dict"]
-#     print("params to model", params_to_model)
-
-# litellm.success_callback = [custom_callback, send_slack_alert]
-# litellm.failure_callback = [send_slack_alert]
-
-# litellm.set_verbose = False
-
-# # litellm.input_callback = [get_transformed_inputs]
@@ -1,27 +1,138 @@
 # #### What this tests ####
 # # This tests the cost tracking function works with consecutive calls (~10 consecutive calls)

-# import sys, os
+# import sys, os, asyncio
 # import traceback
 # import pytest
 # sys.path.insert(
 #     0, os.path.abspath("../..")
 # ) # Adds the parent directory to the system path
+# import dotenv
+# dotenv.load_dotenv()
 # import litellm
+# from fastapi.testclient import TestClient
+# from fastapi import FastAPI
+# from litellm.proxy.proxy_server import router, save_worker_config, startup_event  # Replace with the actual module where your FastAPI router is defined
+# filepath = os.path.dirname(os.path.abspath(__file__))
+# config_fp = f"{filepath}/test_config.yaml"
+# save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=True, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
+# app = FastAPI()
+# app.include_router(router)  # Include your router in the test app
+# @app.on_event("startup")
+# async def wrapper_startup_event():
+#     await startup_event()

-# async def test_proxy_cost_tracking():
+# # Here you create a fixture that will be used by your tests
+# # Make sure the fixture returns TestClient(app)
+# @pytest.fixture(autouse=True)
+# def client():
+#     with TestClient(app) as client:
+#         yield client
+
+# @pytest.mark.asyncio
+# async def test_proxy_cost_tracking(client):
 #     """
-#     Get expected cost.
+#     Get min cost.
 #     Create new key.
 #     Run 10 parallel calls.
 #     Check cost for key at the end.
-#     assert it's = expected cost.
+#     assert it's > min cost.
 #     """
 #     model = "gpt-3.5-turbo"
 #     messages = [{"role": "user", "content": "Hey, how's it going?"}]
-#     number_of_calls = 10
-#     expected_cost = litellm.completion_cost(model=model, messages=messages) * number_of_calls
-#     async def litellm_acompletion():
+#     number_of_calls = 1
+#     min_cost = litellm.completion_cost(model=model, messages=messages) * number_of_calls
+#     try:
+#         ### CREATE NEW KEY ###
+#         test_data = {
+#             "models": ["azure-model"],
+#         }
+#         # Your bearer token
+#         token = os.getenv("PROXY_MASTER_KEY")
+
+#         headers = {
+#             "Authorization": f"Bearer {token}"
+#         }
+#         create_new_key = client.post("/key/generate", json=test_data, headers=headers)
+#         key = create_new_key.json()["key"]
+#         print(f"received key: {key}")
+#         ### MAKE PARALLEL CALLS ###
+#         async def test_chat_completions():
+#             # Your test data
+#             test_data = {
+#                 "model": "azure-model",
+#                 "messages": messages
+#             }
+
+#             tmp_headers = {
+#                 "Authorization": f"Bearer {key}"
+#             }
+
+#             response = client.post("/v1/chat/completions", json=test_data, headers=tmp_headers)
+
+#             assert response.status_code == 200
+#             result = response.json()
+#             print(f"Received response: {result}")
+#         tasks = [test_chat_completions() for _ in range(number_of_calls)]
+#         chat_completions = await asyncio.gather(*tasks)
+#         ### CHECK SPEND ###
+#         get_key_spend = client.get(f"/key/info?key={key}", headers=headers)
+
+#         assert get_key_spend.json()["info"]["spend"] > min_cost
+#         # print(f"chat_completions: {chat_completions}")
+#     # except Exception as e:
+#     #     pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
+
+# #### JUST TEST LOCAL PROXY SERVER
+
+# import requests, os
+# from concurrent.futures import ThreadPoolExecutor
+# import dotenv
+# dotenv.load_dotenv()
+
+# api_url = "http://0.0.0.0:8000/chat/completions"
+
+# def make_api_call(api_url):
+#     # Your test data
+#     test_data = {
+#         "model": "azure-model",
+#         "messages": [
+#             {
+#                 "role": "user",
+#                 "content": "hi"
+#             },
+#         ],
+#         "max_tokens": 10,
+#     }
+#     # Your bearer token
+#     token = os.getenv("PROXY_MASTER_KEY")
+
+#     headers = {
+#         "Authorization": f"Bearer {token}"
+#     }
+#     print("testing proxy server")
+#     response = requests.post(api_url, json=test_data, headers=headers)
+#     return response.json()
+
+# # Number of parallel API calls
+# num_parallel_calls = 3
+
+# # List to store results
+# results = []
+
+# # Create a ThreadPoolExecutor
+# with ThreadPoolExecutor() as executor:
+#     # Submit the API calls concurrently
+#     futures = [executor.submit(make_api_call, api_url) for _ in range(num_parallel_calls)]
+
+#     # Gather the results as they become available
+#     for future in futures:
+#         try:
+#             result = future.result()
+#             results.append(result)
+#         except Exception as e:
+#             print(f"Error: {e}")
+
+# # Print the results
+# for idx, result in enumerate(results, start=1):
+#     print(f"Result {idx}: {result}")
@@ -59,6 +59,7 @@ def test_add_new_key(client):
         print(f"response: {response.text}")
         assert response.status_code == 200
         result = response.json()
+        assert result["key"].startswith("sk-")
         print(f"Received response: {result}")
     except Exception as e:
         pytest.fail("LiteLLM Proxy test failed. Exception", e)
@@ -742,11 +742,7 @@ class Logging:
             )
             pass

-    def success_handler(self, result=None, start_time=None, end_time=None, **kwargs):
-        print_verbose(
-                f"Logging Details LiteLLM-Success Call"
-            )
+    def _success_handler_helper_fn(self, result=None, start_time=None, end_time=None):
         try:
             if start_time is None:
                 start_time = self.start_time
@@ -776,6 +772,18 @@ class Logging:
                     float_diff = float(time_diff)
                     litellm._current_cost += litellm.completion_cost(model=self.model, prompt="", completion=result["content"], total_time=float_diff)

+            return start_time, end_time, result, complete_streaming_response
+        except:
+            pass
+
+    def success_handler(self, result=None, start_time=None, end_time=None, **kwargs):
+        print_verbose(
+                f"Logging Details LiteLLM-Success Call"
+            )
+        try:
+            start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
+            print_verbose(f"success callbacks: {litellm.success_callback}")
+
             for callback in litellm.success_callback:
                 try:
                     if callback == "lite_debugger":
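The refactor above follows a simple pattern: the shared bookkeeping moves into a private helper so the existing sync handler and the new async handler below operate on identical values. A minimal sketch of that shape (assumed names, not the full litellm Logging class):

```python
import datetime

class Logging:
    def _success_handler_helper_fn(self, result=None, start_time=None, end_time=None):
        # normalize inputs once, for both the sync and async paths
        start_time = start_time or datetime.datetime.now()
        end_time = end_time or datetime.datetime.now()
        return start_time, end_time, result, None  # last slot: complete_streaming_response

    def success_handler(self, result=None, start_time=None, end_time=None):
        start_time, end_time, result, _ = self._success_handler_helper_fn(result, start_time, end_time)
        print(f"sync success at {end_time}")

    async def async_success_handler(self, result=None, start_time=None, end_time=None):
        start_time, end_time, result, _ = self._success_handler_helper_fn(result, start_time, end_time)
        print(f"async success at {end_time}")
```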
@@ -969,6 +977,29 @@ class Logging:
                     )
                     pass

+    async def async_success_handler(self, result=None, start_time=None, end_time=None, **kwargs):
+        """
+        Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
+        """
+        start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
+        print_verbose(f"success callbacks: {litellm.success_callback}")
+
+        for callback in litellm._async_success_callback:
+            try:
+                if callable(callback): # custom logger functions
+                    await customLogger.async_log_event(
+                        kwargs=self.model_call_details,
+                        response_obj=result,
+                        start_time=start_time,
+                        end_time=end_time,
+                        print_verbose=print_verbose,
+                        callback_func=callback
+                    )
+            except:
+                print_verbose(
+                    f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
+                )
+
     def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
         print_verbose(
             f"Logging Details LiteLLM-Failure Call"
@@ -1185,6 +1216,17 @@ def client(original_function):
                     callback_list=callback_list,
                     function_id=function_id
                 )
+                ## ASYNC CALLBACKS
+                if len(litellm.success_callback) > 0:
+                    removed_async_items = []
+                    for index, callback in enumerate(litellm.success_callback):
+                        if inspect.iscoroutinefunction(callback):
+                            litellm._async_success_callback.append(callback)
+                            removed_async_items.append(index)
+
+                    # Pop the async items from success_callback in reverse order to avoid index issues
+                    for index in reversed(removed_async_items):
+                        litellm.success_callback.pop(index)
                 if add_breadcrumb:
                     add_breadcrumb(
                         category="litellm.llm_call",
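The routing above is the heart of the feature: anything awaitable that the user put on `litellm.success_callback` is moved to the internal `litellm._async_success_callback` list. The same logic as a standalone sketch with plain lists:

```python
import inspect

def sync_logger(kwargs, response, start, end):
    print("sync")

async def async_logger(kwargs, response, start, end):
    print("async")

success_callback = [sync_logger, async_logger]
_async_success_callback = []

removed_async_items = []
for index, callback in enumerate(success_callback):
    if inspect.iscoroutinefunction(callback):
        _async_success_callback.append(callback)
        removed_async_items.append(index)

# pop in reverse order so earlier indices stay valid
for index in reversed(removed_async_items):
    success_callback.pop(index)

print([f.__name__ for f in success_callback])         # ['sync_logger']
print([f.__name__ for f in _async_success_callback])  # ['async_logger']
```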
@@ -1373,7 +1415,6 @@ def client(original_function):
             start_time = datetime.datetime.now()
             result = None
             logging_obj = kwargs.get("litellm_logging_obj", None)
-
             # only set litellm_call_id if its not in kwargs
             if "litellm_call_id" not in kwargs:
                 kwargs["litellm_call_id"] = str(uuid.uuid4())
@@ -1426,8 +1467,8 @@ def client(original_function):
             # [OPTIONAL] ADD TO CACHE
             if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
                 litellm.cache.add_cache(result, *args, **kwargs)
-            # LOG SUCCESS - handle streaming success logging in the _next_ object
+            # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
+            asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
             threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
             # RETURN RESULT
             if isinstance(result, ModelResponse):
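Note the use of `asyncio.create_task` here: the async success handler is scheduled on the already-running loop and never awaited, so logging cannot delay the response, which is the same fire-and-forget role the `threading.Thread` plays for the sync callbacks. A minimal illustration:

```python
import asyncio

async def async_success_handler(result, start_time=None, end_time=None):
    await asyncio.sleep(0.1)  # stand-in for awaiting the user's async callbacks
    print(f"logged: {result}")

async def acompletion_like():
    result = "model response"
    asyncio.create_task(async_success_handler(result))  # schedule, do not await
    return result

async def main():
    print(await acompletion_like())
    await asyncio.sleep(0.2)  # keep the loop alive so the demo task can finish

asyncio.run(main())
```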
@@ -1465,7 +1506,6 @@ def client(original_function):
             logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this!
             raise e

-    # Use httpx to determine if the original function is a coroutine
     is_coroutine = inspect.iscoroutinefunction(original_function)

     # Return the appropriate wrapper based on the original function type
@@ -5370,6 +5410,8 @@ class CustomStreamWrapper:
                 processed_chunk = self.chunk_creator(chunk=chunk)
                 if processed_chunk is None:
                     continue
+                ## LOGGING
+                asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,))
                 return processed_chunk
             raise StopAsyncIteration
         else: # temporary patch for non-aiohttp async calls
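For streams, the same handler is scheduled from the async iterator as each chunk is produced, so async callbacks observe chunks as they arrive. A rough analogue (assumed shape; the real wrapper also assembles the complete streaming response elsewhere):

```python
import asyncio

async def log_chunk(chunk):
    print(f"logged chunk: {chunk}")

async def stream_with_logging(chunks):
    for chunk in chunks:
        asyncio.create_task(log_chunk(chunk))  # non-blocking, mirrors the hunk above
        yield chunk

async def main():
    async for chunk in stream_with_logging(["Hel", "lo"]):
        print(f"got {chunk}")
    await asyncio.sleep(0.05)  # let the logging tasks run before the loop closes

asyncio.run(main())
```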