diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index af3ea050f8..f85a6a198a 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -3,6 +3,8 @@
 import dotenv, os
 import requests
 import requests
+import inspect
+import asyncio
 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 
@@ -50,16 +52,26 @@ class CustomLogger:
         # Method definition
         try:
             kwargs["log_event_type"] = "post_api_call"
-            callback_func(
-                kwargs,  # kwargs to func
-                response_obj,
-                start_time,
-                end_time,
-            )
+            if inspect.iscoroutinefunction(callback_func):
+                # async callbacks need an event loop; reuse the current
+                # one, or create a fresh loop if it has been closed
+                loop = asyncio.get_event_loop()
+                if loop.is_closed():
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+                loop.run_until_complete(callback_func(kwargs, response_obj, start_time, end_time))
+            else:
+                # sync callbacks can be invoked directly
+                callback_func(
+                    kwargs,  # kwargs to func
+                    response_obj,
+                    start_time,
+                    end_time,
+                )
             print_verbose(
                 f"Custom Logger - final response object: {response_obj}"
             )
-        except:
+        except Exception as e:
             # traceback.print_exc()
             print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
-            pass
+            raise e
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index bb3011b3f3..3881cc80d2 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -146,8 +146,44 @@ def usage_telemetry(
 
 api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
 
+async def track_cost_callback(
+    kwargs,  # kwargs to completion
+    completion_response: litellm.ModelResponse,  # response from completion
+    start_time=None,
+    end_time=None,  # start/end time for completion
+):
+    try:
+        api_key = kwargs["litellm_params"]["metadata"]["api_key"]
+        response_cost = None  # stays None if no cost could be computed
+        # check if it has collected an entire stream response
+        if "complete_streaming_response" in kwargs:
+            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
+            completion_response = kwargs["complete_streaming_response"]
+            input_text = kwargs["messages"]
+            output_text = completion_response["choices"][0]["message"]["content"]
+            response_cost = litellm.completion_cost(
+                model=kwargs["model"],
+                messages=input_text,
+                completion=output_text,
+            )
+            print(f"LiteLLM Proxy: streaming response_cost: {response_cost} for api_key: {api_key}")
+        # for non streaming responses
+        else:
+            # we pass the completion_response obj
+            if kwargs["stream"] != True:
+                response_cost = litellm.completion_cost(completion_response=completion_response)
+                print(f"\n LiteLLM Proxy: regular response_cost: {response_cost} for api_key: {api_key}")
+
+        ########### write costs to DB api_key / cost map
+        if response_cost is not None:
+            await update_verification_token_cost(token=api_key, additional_cost=response_cost)
+
+    except:
+        pass
+
 async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
     global master_key, prisma_client, llm_model_list
+    print("IN AUTH PRISMA CLIENT", prisma_client)
     if master_key is None:
         return
     try:
@@ -275,12 +311,11 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     ## Cost Tracking for master key + auth setup ##
     if master_key is not None:
         if isinstance(litellm.success_callback, list):
-            import utils
             print("setting litellm success callback to track cost")
-            if (utils.track_cost_callback) not in litellm.success_callback: # type: ignore
-                litellm.success_callback.append(utils.track_cost_callback) # type: ignore
+            if (track_cost_callback) not in litellm.success_callback: # type: ignore
+                litellm.success_callback.append(track_cost_callback) # type: ignore
         else:
-            litellm.success_callback = utils.track_cost_callback # type: ignore
+            litellm.success_callback = track_cost_callback # type: ignore
     ### START REDIS QUEUE ###
     use_queue = general_settings.get("use_queue", False)
     celery_setup(use_queue=use_queue)
@@ -386,6 +421,32 @@ async def delete_verification_token(tokens: List[str]):
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
     return deleted_tokens
 
+async def update_verification_token_cost(token: str, additional_cost: float):
+    global prisma_client
+    print("in update verification token")
+    print("prisma client", prisma_client)
+
+    try:
+        if prisma_client:
+            # read the token's current cost, then write back the new total
+            existing_token = await prisma_client.litellm_verificationtoken.find_unique(where={"token": token})
+            print("existing token data", existing_token)
+            if existing_token:
+                old_cost = existing_token.cost or 0.0
+                new_cost = old_cost + additional_cost
+                updated_token = await prisma_client.litellm_verificationtoken.update(
+                    where={"token": token},
+                    data={"cost": new_cost}
+                )
+                return updated_token
+            else:
+                raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Token not found")
+        else:
+            raise Exception("prisma_client not initialized")
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
 
 async def generate_key_cli_task(duration_str):
     task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
@@ -612,6 +673,12 @@ async def completion(request: Request, model: Optional[str] = None):
 async def chat_completion(request: Request, model: Optional[str] = None):
     global general_settings
     try:
+        bearer_api_key = request.headers.get("authorization") or ""
+        cleaned_api_key = None  # stays None if no Bearer token was sent
+        if "Bearer " in bearer_api_key:
+            cleaned_api_key = bearer_api_key[len("Bearer "):]
+            print("cleaned api key", cleaned_api_key)
+
         body = await request.body()
         body_str = body.decode()
         try:
@@ -626,6 +693,9 @@
             or data["model"] # default passed in http request
         )
         data["call_type"] = "chat_completion"
+        if "metadata" not in data:
+            data["metadata"] = {}
+        data["metadata"]["api_key"] = cleaned_api_key
         return litellm_completion(
             **data
         )
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index c61bcfa85d..307357f835 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -1,38 +1,41 @@
-import litellm
-from litellm import ModelResponse
-from proxy_server import llm_model_list
-from typing import Optional
+# import litellm
+# from litellm import ModelResponse
+# from proxy_server import update_verification_token_cost
+# from typing import Optional
+# from fastapi import HTTPException, status
+# import asyncio
 
-def track_cost_callback(
-    kwargs, # kwargs to completion
-    completion_response: ModelResponse, # response from completion
-    start_time = None,
-    end_time = None, # start/end time for completion
-):
-    try:
-        # init logging config
-        print("in custom callback tracking cost", llm_model_list)
-        if "azure" in kwargs["model"]:
-            # for azure cost tracking, we check the provided model list in the config.yaml
-            # we need to map azure/chatgpt-deployment to -> azure/gpt-3.5-turbo
-            pass
-        # check if it has collected an entire stream response
-        if "complete_streaming_response" in kwargs:
-            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
-            completion_response=kwargs["complete_streaming_response"]
-            input_text = kwargs["messages"]
-            output_text = completion_response["choices"][0]["message"]["content"]
-            response_cost = litellm.completion_cost(
-                model = kwargs["model"],
-                messages = input_text,
-                completion=output_text
-            )
-            print("streaming response_cost", response_cost)
-        # for non streaming responses
-        else:
-            # we pass the completion_response obj
-            if kwargs["stream"] != True:
-                response_cost = litellm.completion_cost(completion_response=completion_response)
-                print("regular response_cost", response_cost)
-    except:
-        pass
\ No newline at end of file
+# def track_cost_callback(
+#     kwargs, # kwargs to completion
+#     completion_response: ModelResponse, # response from completion
+#     start_time = None,
+#     end_time = None, # start/end time for completion
+# ):
+#     try:
+#         # init logging config
+#         api_key = kwargs["litellm_params"]["metadata"]["api_key"]
+#         # check if it has collected an entire stream response
+#         if "complete_streaming_response" in kwargs:
+#             # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
+#             completion_response=kwargs["complete_streaming_response"]
+#             input_text = kwargs["messages"]
+#             output_text = completion_response["choices"][0]["message"]["content"]
+#             response_cost = litellm.completion_cost(
+#                 model = kwargs["model"],
+#                 messages = input_text,
+#                 completion=output_text
+#             )
+#             print(f"LiteLLM Proxy: streaming response_cost: {response_cost} for api_key: {api_key}")
+#         # for non streaming responses
+#         else:
+#             # we pass the completion_response obj
+#             if kwargs["stream"] != True:
+#                 response_cost = litellm.completion_cost(completion_response=completion_response)
+#                 print(f"\n LiteLLM Proxy: regular response_cost: {response_cost} for api_key: {api_key}")
+
+#         ########### write costs to DB api_key / cost map
+#         asyncio.run(
+#             update_verification_token_cost(token=api_key, additional_cost=response_cost)
+#         )
+#     except:
+#         pass
\ No newline at end of file
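
A note on the custom_logger.py change: it dispatches each callback based on
inspect.iscoroutinefunction, running coroutine callbacks on an event loop and
plain functions directly. A minimal standalone sketch of the same pattern,
runnable outside litellm (sync_logger, async_logger, and run_callback are
illustrative names, not litellm APIs):

    import asyncio
    import inspect

    def sync_logger(kwargs, response_obj, start_time, end_time):
        print("sync callback:", response_obj)

    async def async_logger(kwargs, response_obj, start_time, end_time):
        print("async callback:", response_obj)

    def run_callback(callback_func, kwargs, response_obj, start_time, end_time):
        if inspect.iscoroutinefunction(callback_func):
            # coroutine callbacks need an event loop; recreate it if closed
            loop = asyncio.get_event_loop()
            if loop.is_closed():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
            loop.run_until_complete(
                callback_func(kwargs, response_obj, start_time, end_time)
            )
        else:
            # plain functions run synchronously
            callback_func(kwargs, response_obj, start_time, end_time)

    run_callback(sync_logger, {}, "response", None, None)
    run_callback(async_logger, {}, "response", None, None)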
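
A note on the key flow in proxy_server.py: chat_completion strips the
"Bearer " prefix from the Authorization header, stores the key in
data["metadata"], and litellm hands it back to the success callback as
kwargs["litellm_params"]["metadata"]["api_key"], which track_cost_callback
uses to attribute spend. A hedged usage sketch against a locally running
proxy (the URL, port, and key below are placeholders, not fixed values):

    import requests

    resp = requests.post(
        "http://0.0.0.0:8000/chat/completions",
        headers={"Authorization": "Bearer sk-1234"},  # placeholder proxy key
        json={
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "hello"}],
        },
    )
    print(resp.json())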
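
A note on update_verification_token_cost: it does a read-modify-write
(find_unique, add, update), so two concurrent requests can read the same old
cost and lose an increment. If Prisma's atomic number operations are
acceptable here, a sketch of a race-free variant (assumes the same
prisma_client instance and float cost column as above):

    async def update_verification_token_cost(token: str, additional_cost: float):
        # atomic increment avoids the lost-update race between concurrent requests
        return await prisma_client.litellm_verificationtoken.update(
            where={"token": token},
            data={"cost": {"increment": additional_cost}},
        )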