diff --git a/litellm/__init__.py b/litellm/__init__.py
index b494268ad5..b9cf85a55e 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -8,6 +8,7 @@ input_callback: List[Union[str, Callable]] = []
 success_callback: List[Union[str, Callable]] = []
 failure_callback: List[Union[str, Callable]] = []
 callbacks: List[Callable] = []
+_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
 pre_call_rules: List[Callable] = []
 post_call_rules: List[Callable] = []
 set_verbose = False
diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index af3ea050f8..e502439a95 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -8,7 +8,7 @@ dotenv.load_dotenv() # Loading env variables using dotenv
 import traceback


-class CustomLogger:
+class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
     # Class variables or attributes
     def __init__(self):
         pass
@@ -29,7 +29,7 @@ class CustomLogger:
         pass


-    #### DEPRECATED ####
+    #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function

    def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
        try:
@@ -63,3 +63,21 @@ class CustomLogger:
             # traceback.print_exc()
             print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
             pass
+
+    async def async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
+        # Method definition
+        try:
+            kwargs["log_event_type"] = "post_api_call"
+            await callback_func(
+                kwargs, # kwargs to func
+                response_obj,
+                start_time,
+                end_time,
+            )
+            print_verbose(
+                f"Custom Logger - final response object: {response_obj}"
+            )
+        except:
+            # traceback.print_exc()
+            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
+            pass
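Reviewer note: the hunk above gives `CustomLogger` an awaitable logging path. A minimal sketch of the coroutine contract it assumes — `my_async_callback` and the driver are illustrative stand-ins, not part of this patch:

```python
# Sketch of the callback signature async_log_event awaits; assumes litellm is installed.
import asyncio, datetime
from litellm.integrations.custom_logger import CustomLogger

async def my_async_callback(kwargs, response_obj, start_time, end_time):
    # async-safe place for I/O (DB writes, webhook posts, ...)
    print(f"{kwargs['log_event_type']}: {response_obj} took {end_time - start_time}")

async def main():
    logger = CustomLogger()
    now = datetime.datetime.now()
    await logger.async_log_event(
        kwargs={}, response_obj={"id": "demo"}, start_time=now,
        end_time=now, print_verbose=print, callback_func=my_async_callback,
    )

asyncio.run(main())
```

Because the callback is awaited on the caller's event loop, it never needs the thread-plus-private-event-loop workaround that the proxy change below deletes.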
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index ae3ec5298e..8e9ddc9fa2 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -272,10 +272,16 @@
 api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
 async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
     global master_key, prisma_client, llm_model_list
+    print(f"master_key - {master_key}; api_key - {api_key}")
     if master_key is None:
-        return {
-            "api_key": None
-        }
+        if isinstance(api_key, str):
+            return {
+                "api_key": api_key.replace("Bearer ", "")
+            }
+        else:
+            return {
+                "api_key": api_key
+            }
     try:
         if api_key is None:
             raise Exception("No api key passed in.")
@@ -382,8 +388,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
             print("Error when loading keys from Azure Key Vault. Ensure you run `pip install azure-identity azure-keyvault-secrets`")

 def cost_tracking():
-    global prisma_client, master_key
-    if prisma_client is not None and master_key is not None:
+    global prisma_client
+    if prisma_client is not None:
         if isinstance(litellm.success_callback, list):
             print("setting litellm success callback to track cost")
             if (track_cost_callback) not in litellm.success_callback: # type: ignore
@@ -391,7 +397,7 @@ def cost_tracking():
         else:
             litellm.success_callback = track_cost_callback # type: ignore

-def track_cost_callback(
+async def track_cost_callback(
    kwargs,                                       # kwargs to completion
    completion_response: litellm.ModelResponse,   # response from completion
    start_time = None,
@@ -420,31 +426,13 @@ async def track_cost_callback(
             response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
             print("regular response_cost", response_cost)
         user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+        print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
         if user_api_key and prisma_client:
-            # asyncio.run(update_prisma_database(user_api_key, response_cost))
-            # Create new event loop for async function execution in the new thread
-            new_loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(new_loop)
-            try:
-                # Run the async function using the newly created event loop
-                existing_spend_obj = new_loop.run_until_complete(prisma_client.get_data(token=user_api_key))
-                if existing_spend_obj is None:
-                    existing_spend = 0
-                else:
-                    existing_spend = existing_spend_obj.spend
-                # Calculate the new cost by adding the existing cost and response_cost
-                new_spend = existing_spend + response_cost
-                print(f"new cost: {new_spend}")
-                # Update the cost column for the given token
-                new_loop.run_until_complete(prisma_client.update_data(token=user_api_key, data={"spend": new_spend}))
-                print(f"Prisma database updated for token {user_api_key}. New cost: {new_spend}")
-            except Exception as e:
-                print(f"error in creating async loop - {str(e)}")
+            await update_prisma_database(token=user_api_key, response_cost=response_cost)
     except Exception as e:
         print(f"error in tracking cost callback - {str(e)}")

 async def update_prisma_database(token, response_cost):
-
     try:
         print(f"Enters prisma db call, token: {token}")
         # Fetch the existing cost for the given token
@@ -460,8 +448,6 @@ async def update_prisma_database(token, response_cost):
         print(f"new cost: {new_spend}")
         # Update the cost column for the given token
         await prisma_client.update_data(token=token, data={"spend": new_spend})
-        print(f"Prisma database updated for token {token}. New cost: {new_spend}")
-
     except Exception as e:
         print(f"Error updating Prisma database: {traceback.format_exc()}")
         pass
@@ -648,7 +634,7 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
     except Exception as e:
         traceback.print_exc()
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
-    return {"token": new_verification_token.token, "expires": new_verification_token.expires, "user_id": user_id}
+    return {"token": token, "expires": new_verification_token.expires, "user_id": user_id}

 async def delete_verification_token(tokens: List):
     global prisma_client
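Reviewer note: with `track_cost_callback` now a coroutine, the per-call spend update is a plain awaited read-increment-write instead of a hand-rolled event loop in a logging thread. A self-contained sketch of that pattern, with a dict standing in for the Prisma-backed table (a stub, not `prisma_client`'s real API):

```python
# In-memory sketch of what update_prisma_database does per call; the dict is a
# hypothetical stand-in for the database table.
import asyncio

_spend_table = {}  # token -> accumulated spend

async def get_data(token):
    return _spend_table.get(token)

async def update_data(token, data):
    _spend_table[token] = data["spend"]

async def update_spend(token, response_cost):
    existing_spend = await get_data(token) or 0  # first call for a token starts at 0
    await update_data(token, {"spend": existing_spend + response_cost})

async def main():
    for _ in range(3):
        await update_spend("sk-demo", 0.25)
    print(_spend_table)  # {'sk-demo': 0.75}

asyncio.run(main())
```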
diff --git a/litellm/router.py b/litellm/router.py
index 5ce0d409bb..5bf06760e6 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -876,6 +876,7 @@ class Router:
             self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}")
             if "azure" in model_name:
+                self.print_verbose(f"Initializing Azure OpenAI Client for {model_name}, {str(api_base)}, {api_key}")
                 if api_version is None:
                     api_version = "2023-07-01-preview"
                 if "gateway.ai.cloudflare.com" in api_base:
@@ -913,6 +914,7 @@ class Router:
                     max_retries=max_retries
                 )
             else:
+                self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}")
                 model["async_client"] = openai.AsyncOpenAI(
                     api_key=api_key,
                     base_url=api_base,
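Reviewer note: these two hunks only add `print_verbose` debug lines, one per branch of the client-init fork. For orientation, a hedged sketch of the fork being instrumented — the real `Router` also handles Cloudflare gateways, timeouts, and retries, so this only shows the decision the logs trace:

```python
# Illustrative only: "azure" in the model name selects an Azure client (with the
# same api_version default as the diff), anything else a plain AsyncOpenAI client.
from typing import Optional
import openai

def init_async_client(model_name: str, api_base: str, api_key: str,
                      api_version: Optional[str] = None):
    if "azure" in model_name:
        if api_version is None:
            api_version = "2023-07-01-preview"  # default used in the diff
        return openai.AsyncAzureOpenAI(
            api_key=api_key, azure_endpoint=api_base, api_version=api_version
        )
    return openai.AsyncOpenAI(api_key=api_key, base_url=api_base)
```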
diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py
index 7e134bd261..f88bc6868d 100644
--- a/litellm/tests/test_custom_logger.py
+++ b/litellm/tests/test_custom_logger.py
@@ -1,5 +1,5 @@
 ### What this tests ####
-import sys, os, time
+import sys, os, time, inspect, asyncio
 import pytest
 sys.path.insert(0, os.path.abspath('../..'))

@@ -7,6 +7,7 @@ from litellm import completion, embedding
 import litellm
 from litellm.integrations.custom_logger import CustomLogger

+async_success = False
 class MyCustomHandler(CustomLogger):
     success: bool = False
     failure: bool = False
@@ -28,24 +29,29 @@ class MyCustomHandler(CustomLogger):
         print(f"On Failure")
         self.failure = True

-# def test_chat_openai():
-#     try:
-#         customHandler = MyCustomHandler()
-#         litellm.callbacks = [customHandler]
-#         response = completion(model="gpt-3.5-turbo",
-#                               messages=[{
-#                                   "role": "user",
-#                                   "content": "Hi 👋 - i'm openai"
-#                               }],
-#                               stream=True)
-#         time.sleep(1)
-#         assert customHandler.success == True
-#     except Exception as e:
-#         pytest.fail(f"An error occurred - {str(e)}")
-#         pass
+async def async_test_logging_fn(kwargs, completion_obj, start_time, end_time):
+    global async_success
+    print(f"ON ASYNC LOGGING")
+    async_success = True

-# test_chat_openai()
+@pytest.mark.asyncio
+async def test_chat_openai():
+    try:
+        # litellm.set_verbose = True
+        litellm.success_callback = [async_test_logging_fn]
+        response = await litellm.acompletion(model="gpt-3.5-turbo",
+                              messages=[{
+                                  "role": "user",
+                                  "content": "Hi 👋 - i'm openai"
+                              }],
+                              stream=True)
+        async for chunk in response:
+            continue
+        assert async_success == True
+    except Exception as e:
+        print(e)
+        pytest.fail(f"An error occurred - {str(e)}")

 def test_completion_azure_stream_moderation_failure():
     try:
@@ -71,76 +77,3 @@ def test_completion_azure_stream_moderation_failure():
         assert customHandler.failure == True
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-
-# test_completion_azure_stream_moderation_failure()
-
-
-# def custom_callback(
-#     kwargs,
-#     completion_response,
-#     start_time,
-#     end_time,
-# ):
-#     print(
-#         "in custom callback func"
-#     )
-#     print("kwargs", kwargs)
-#     print(completion_response)
-#     print(start_time)
-#     print(end_time)
-#     if "complete_streaming_response" in kwargs:
-#         print("\n\n complete response\n\n")
-#         complete_streaming_response = kwargs["complete_streaming_response"]
-#         print(kwargs["complete_streaming_response"])
-#         usage = complete_streaming_response["usage"]
-#         print("usage", usage)
-# def send_slack_alert(
-#     kwargs,
-#     completion_response,
-#     start_time,
-#     end_time,
-# ):
-#     print(
-#         "in custom slack callback func"
-#     )
-#     import requests
-#     import json
-
-#     # Define the Slack webhook URL
-#     slack_webhook_url = os.environ['SLACK_WEBHOOK_URL'] # "https://hooks.slack.com/services/<>/<>/<>"
-
-#     # Define the text payload, send data available in litellm custom_callbacks
-#     text_payload = f"""LiteLLM Logging: kwargs: {str(kwargs)}\n\n, response: {str(completion_response)}\n\n, start time{str(start_time)} end time: {str(end_time)}
-#     """
-#     payload = {
-#         "text": text_payload
-#     }
-
-#     # Set the headers
-#     headers = {
-#         "Content-type": "application/json"
-#     }
-
-#     # Make the POST request
-#     response = requests.post(slack_webhook_url, json=payload, headers=headers)
-
-#     # Check the response status
-#     if response.status_code == 200:
-#         print("Message sent successfully to Slack!")
-#     else:
-#         print(f"Failed to send message to Slack. Status code: {response.status_code}")
-#         print(response.json())
-
-# def get_transformed_inputs(
-#     kwargs,
-# ):
-#     params_to_model = kwargs["additional_args"]["complete_input_dict"]
-#     print("params to model", params_to_model)
-
-# litellm.success_callback = [custom_callback, send_slack_alert]
-# litellm.failure_callback = [send_slack_alert]
-
-
-# litellm.set_verbose = False
-
-# # litellm.input_callback = [get_transformed_inputs]
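Reviewer note: this test drains the stream and then immediately asserts `async_success`. Since the logging side runs via `asyncio.create_task` (see the utils.py hunks below), the task is only scheduled, not guaranteed to have run. A standalone illustration of that timing hazard and the usual remedy of yielding to the loop once:

```python
# Demonstrates why a fire-and-forget create_task may not have run by assert time.
import asyncio

flag = False

async def set_flag():
    global flag
    flag = True

async def main():
    asyncio.create_task(set_flag())  # only schedules the task, like async_success_handler
    assert flag is False             # hasn't run yet
    await asyncio.sleep(0)           # yield once so the pending task can run
    assert flag is True

asyncio.run(main())
```

If this assertion ever proves flaky in CI, an `await asyncio.sleep(...)` before it (an editorial suggestion, not part of the patch) would give pending logging tasks a chance to complete.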
diff --git a/litellm/tests/test_proxy_server_cost.py b/litellm/tests/test_proxy_server_cost.py
index 7688e58995..b127e72e3b 100644
--- a/litellm/tests/test_proxy_server_cost.py
+++ b/litellm/tests/test_proxy_server_cost.py
@@ -1,27 +1,138 @@
 # #### What this tests ####
 # # This tests the cost tracking function works with consecutive calls (~10 consecutive calls)
-# import sys, os
+# import sys, os, asyncio
 # import traceback
 # import pytest
 # sys.path.insert(
 #     0, os.path.abspath("../..")
 # )  # Adds the parent directory to the system path
+# import dotenv
+# dotenv.load_dotenv()
 # import litellm
+# from fastapi.testclient import TestClient
+# from fastapi import FastAPI
+# from litellm.proxy.proxy_server import router, save_worker_config, startup_event  # Replace with the actual module where your FastAPI router is defined
+# filepath = os.path.dirname(os.path.abspath(__file__))
+# config_fp = f"{filepath}/test_config.yaml"
+# save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=True, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
+# app = FastAPI()
+# app.include_router(router)  # Include your router in the test app
+# @app.on_event("startup")
+# async def wrapper_startup_event():
+#     await startup_event()

-# async def test_proxy_cost_tracking():
+# # Here you create a fixture that will be used by your tests
+# # Make sure the fixture returns TestClient(app)
+# @pytest.fixture(autouse=True)
+# def client():
+#     with TestClient(app) as client:
+#         yield client
+
+# @pytest.mark.asyncio
+# async def test_proxy_cost_tracking(client):
 #     """
-#     Get expected cost.
+#     Get min cost.
 #     Create new key.
 #     Run 10 parallel calls.
 #     Check cost for key at the end.
-#     assert it's = expected cost.
+#     assert it's > min cost.
 #     """
 #     model = "gpt-3.5-turbo"
 #     messages = [{"role": "user", "content": "Hey, how's it going?"}]
-#     number_of_calls = 10
-#     expected_cost = litellm.completion_cost(model=model, messages=messages) * number_of_calls
-#     async def litellm_acompletion():
+#     number_of_calls = 1
+#     min_cost = litellm.completion_cost(model=model, messages=messages) * number_of_calls
+#     try:
+#         ### CREATE NEW KEY ###
+#         test_data = {
+#             "models": ["azure-model"],
+#         }
+#         # Your bearer token
+#         token = os.getenv("PROXY_MASTER_KEY")
+#         headers = {
+#             "Authorization": f"Bearer {token}"
+#         }
+#         create_new_key = client.post("/key/generate", json=test_data, headers=headers)
+#         key = create_new_key.json()["key"]
+#         print(f"received key: {key}")
+#         ### MAKE PARALLEL CALLS ###
+#         async def test_chat_completions():
+#             # Your test data
+#             test_data = {
+#                 "model": "azure-model",
+#                 "messages": messages
+#             }
+#             tmp_headers = {
+#                 "Authorization": f"Bearer {key}"
+#             }
+#             response = client.post("/v1/chat/completions", json=test_data, headers=tmp_headers)
+
+#             assert response.status_code == 200
+#             result = response.json()
+#             print(f"Received response: {result}")
+#         tasks = [test_chat_completions() for _ in range(number_of_calls)]
+#         chat_completions = await asyncio.gather(*tasks)
+#         ### CHECK SPEND ###
+#         get_key_spend = client.get(f"/key/info?key={key}", headers=headers)
+
+#         assert get_key_spend.json()["info"]["spend"] > min_cost
+#         # print(f"chat_completions: {chat_completions}")
+#     # except Exception as e:
+#     #     pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
+
+# #### JUST TEST LOCAL PROXY SERVER
+
+# import requests, os
+# from concurrent.futures import ThreadPoolExecutor
+# import dotenv
+# dotenv.load_dotenv()
+
+# api_url = "http://0.0.0.0:8000/chat/completions"
+
+# def make_api_call(api_url):
+#     # Your test data
+#     test_data = {
+#         "model": "azure-model",
+#         "messages": [
+#             {
+#                 "role": "user",
+#                 "content": "hi"
+#             },
+#         ],
+#         "max_tokens": 10,
+#     }
+#     # Your bearer token
+#     token = os.getenv("PROXY_MASTER_KEY")
+
+#     headers = {
+#         "Authorization": f"Bearer {token}"
+#     }
+#     print("testing proxy server")
+#     response = requests.post(api_url, json=test_data, headers=headers)
+#     return response.json()
+
+# # Number of parallel API calls
+# num_parallel_calls = 3
+
+# # List to store results
+# results = []
+
+# # Create a ThreadPoolExecutor
+# with ThreadPoolExecutor() as executor:
+#     # Submit the API calls concurrently
+#     futures = [executor.submit(make_api_call, api_url) for _ in range(num_parallel_calls)]
+
+#     # Gather the results as they become available
+#     for future in futures:
+#         try:
+#             result = future.result()
+#             results.append(result)
+#         except Exception as e:
+#             print(f"Error: {e}")
+
+# # Print the results
+# for idx, result in enumerate(results, start=1):
+#     print(f"Result {idx}: {result}")
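Reviewer note: the test's assertion moved from "spend == expected cost" to "spend > min cost", which reads as the right direction: `completion_cost` over the prompt alone prices only the input tokens, so it is a floor that any real charge (input plus completion tokens) should exceed. A sketch of that floor, assuming `litellm` is installed:

```python
# min_cost prices only the prompt side, so recorded spend should land above it.
import litellm

model = "gpt-3.5-turbo"
messages = [{"role": "user", "content": "Hey, how's it going?"}]
min_cost = litellm.completion_cost(model=model, messages=messages)
print(f"a successful call should record spend > {min_cost}")
```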
Exception", e) diff --git a/litellm/utils.py b/litellm/utils.py index d3a9b8bb07..892fc010c1 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -741,13 +741,9 @@ class Logging: f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}" ) pass - - def success_handler(self, result=None, start_time=None, end_time=None, **kwargs): - print_verbose( - f"Logging Details LiteLLM-Success Call" - ) - try: + def _success_handler_helper_fn(self, result=None, start_time=None, end_time=None): + try: if start_time is None: start_time = self.start_time if end_time is None: @@ -776,6 +772,18 @@ class Logging: float_diff = float(time_diff) litellm._current_cost += litellm.completion_cost(model=self.model, prompt="", completion=result["content"], total_time=float_diff) + return start_time, end_time, result, complete_streaming_response + except: + pass + + def success_handler(self, result=None, start_time=None, end_time=None, **kwargs): + print_verbose( + f"Logging Details LiteLLM-Success Call" + ) + try: + start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result) + print_verbose(f"success callbacks: {litellm.success_callback}") + for callback in litellm.success_callback: try: if callback == "lite_debugger": @@ -969,6 +977,29 @@ class Logging: ) pass + async def async_success_handler(self, result=None, start_time=None, end_time=None, **kwargs): + """ + Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. + """ + start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result) + print_verbose(f"success callbacks: {litellm.success_callback}") + + for callback in litellm._async_success_callback: + try: + if callable(callback): # custom logger functions + await customLogger.async_log_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + callback_func=callback + ) + except: + print_verbose( + f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}" + ) + def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None): print_verbose( f"Logging Details LiteLLM-Failure Call" @@ -1185,6 +1216,17 @@ def client(original_function): callback_list=callback_list, function_id=function_id ) + ## ASYNC CALLBACKS + if len(litellm.success_callback) > 0: + removed_async_items = [] + for index, callback in enumerate(litellm.success_callback): + if inspect.iscoroutinefunction(callback): + litellm._async_success_callback.append(callback) + removed_async_items.append(index) + + # Pop the async items from success_callback in reverse order to avoid index issues + for index in reversed(removed_async_items): + litellm.success_callback.pop(index) if add_breadcrumb: add_breadcrumb( category="litellm.llm_call", @@ -1373,7 +1415,6 @@ def client(original_function): start_time = datetime.datetime.now() result = None logging_obj = kwargs.get("litellm_logging_obj", None) - # only set litellm_call_id if its not in kwargs if "litellm_call_id" not in kwargs: kwargs["litellm_call_id"] = str(uuid.uuid4()) @@ -1426,8 +1467,8 @@ def client(original_function): # [OPTIONAL] ADD TO CACHE if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object 
@@ -1373,7 +1415,6 @@ def client(original_function):
         start_time = datetime.datetime.now()
         result = None
         logging_obj = kwargs.get("litellm_logging_obj", None)
-
         # only set litellm_call_id if its not in kwargs
         if "litellm_call_id" not in kwargs:
             kwargs["litellm_call_id"] = str(uuid.uuid4())
@@ -1426,8 +1467,8 @@ def client(original_function):
             # [OPTIONAL] ADD TO CACHE
             if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
                 litellm.cache.add_cache(result, *args, **kwargs)
-
-            # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
+            # LOG SUCCESS - handle streaming success logging in the _next_ object
+            asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
             threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
             # RETURN RESULT
             if isinstance(result, ModelResponse):
@@ -1465,7 +1506,6 @@ def client(original_function):
             logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this!
             raise e
-
     # Use httpx to determine if the original function is a coroutine
     is_coroutine = inspect.iscoroutinefunction(original_function)
     # Return the appropriate wrapper based on the original function type
@@ -5370,6 +5410,8 @@ class CustomStreamWrapper:
                     processed_chunk = self.chunk_creator(chunk=chunk)
                     if processed_chunk is None:
                         continue
+                    ## LOGGING
+                    asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,))
                     return processed_chunk
                 raise StopAsyncIteration
             else: # temporary patch for non-aiohttp async calls
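Reviewer note: after these hunks, the success path dispatches twice — async callbacks are scheduled on the running event loop with `asyncio.create_task`, while sync callbacks keep running on a worker thread so they cannot block the loop. A hedged, self-contained sketch of that dual dispatch (handler names are illustrative):

```python
# Dual dispatch: sync handler off-loop via a thread, async handler on the loop.
import asyncio, threading

def sync_handler(result):
    print(f"[thread] logged {result}")

async def async_handler(result):
    print(f"[loop] logged {result}")

async def on_success(result):
    asyncio.create_task(async_handler(result))                     # non-blocking schedule
    threading.Thread(target=sync_handler, args=(result,)).start()  # off the event loop
    await asyncio.sleep(0)  # yield so the scheduled task runs before we return

asyncio.run(on_success("chunk-1"))
```

The same fire-and-forget scheduling appears per-chunk in `CustomStreamWrapper.__anext__`, which is why the streaming test earlier can only assert the flag after the loop has had a chance to run the pending tasks.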