add cost tracking

ishaan-jaff 2023-11-24 14:38:05 -08:00
parent 41c6388185
commit ff7ed23db3
3 changed files with 134 additions and 48 deletions

View file

@@ -3,6 +3,8 @@
 import dotenv, os
 import requests
 import requests
+import inspect
+import asyncio
 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
@@ -50,6 +52,16 @@ class CustomLogger:
         # Method definition
         try:
             kwargs["log_event_type"] = "post_api_call"
+            if inspect.iscoroutinefunction(callback_func):
+                # If it's async, use asyncio to run it
+                loop = asyncio.get_event_loop()
+                if loop.is_closed():
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+                loop.run_until_complete(callback_func(kwargs, response_obj, start_time, end_time))
+            else:
+                # If it's not async, run it synchronously
-            callback_func(
-                kwargs,  # kwargs to func
-                response_obj,
+                callback_func(
+                    kwargs,  # kwargs to func
+                    response_obj,
@@ -59,7 +71,8 @@ class CustomLogger:
             print_verbose(
                 f"Custom Logger - final response object: {response_obj}"
             )
-        except:
+        except Exception as e:
+            raise e
             # traceback.print_exc()
             print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
             pass
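
The hunk above dispatches a user-supplied callback differently depending on whether it is a coroutine function. A minimal standalone sketch of the same pattern (the helper name run_callback is hypothetical), assuming the callback takes the same four arguments; note that loop.run_until_complete and asyncio.run both raise RuntimeError if an event loop is already running in the current thread:

import asyncio
import inspect

def run_callback(callback_func, kwargs, response_obj, start_time, end_time):
    if inspect.iscoroutinefunction(callback_func):
        # Coroutine function: drive it to completion on a fresh event loop
        asyncio.run(callback_func(kwargs, response_obj, start_time, end_time))
    else:
        # Plain function: invoke it synchronously
        callback_func(kwargs, response_obj, start_time, end_time)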

View file

@@ -146,8 +146,44 @@ def usage_telemetry(
 api_key_header = APIKeyHeader(name="Authorization", auto_error=False)

+async def track_cost_callback(
+    kwargs,                                      # kwargs to completion
+    completion_response: litellm.ModelResponse,  # response from completion
+    start_time = None,
+    end_time = None,                             # start/end time for completion
+):
+    try:
+        # init logging config
+        print("in cost track callback")
+        api_key = kwargs["litellm_params"]["metadata"]["api_key"]
+        # check if it has collected an entire stream response
+        if "complete_streaming_response" in kwargs:
+            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
+            completion_response = kwargs["complete_streaming_response"]
+            input_text = kwargs["messages"]
+            output_text = completion_response["choices"][0]["message"]["content"]
+            response_cost = litellm.completion_cost(
+                model = kwargs["model"],
+                messages = input_text,
+                completion = output_text
+            )
+            print(f"LiteLLM Proxy: streaming response_cost: {response_cost} for api_key: {api_key}")
+        # for non streaming responses
+        else:
+            # we pass the completion_response obj
+            if kwargs["stream"] != True:
+                response_cost = litellm.completion_cost(completion_response=completion_response)
+                print(f"\n LiteLLM Proxy: regular response_cost: {response_cost} for api_key: {api_key}")
+        ########### write costs to DB api_key / cost map
+        await update_verification_token_cost(token=api_key, additional_cost=response_cost)
+    except:
+        pass
+
 async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
     global master_key, prisma_client, llm_model_list
+    print("IN AUTH PRISMA CLIENT", prisma_client)
     if master_key is None:
         return
     try:
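
litellm.completion_cost can price a request either from a complete response object or from raw input/output text, which is why the streaming branch above passes the messages and the assembled output separately once the full stream has been collected. A sketch of both call styles, assuming a model whose pricing litellm knows and a configured provider key:

import litellm

# Non-streaming: price directly from the ModelResponse object
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
cost = litellm.completion_cost(completion_response=response)

# Streaming: price from the prompt messages plus the assembled output text
cost = litellm.completion_cost(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    completion="Hello! How can I help?",
)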
@@ -275,12 +311,11 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     ## Cost Tracking for master key + auth setup ##
     if master_key is not None:
         if isinstance(litellm.success_callback, list):
-            import utils
             print("setting litellm success callback to track cost")
-            if (utils.track_cost_callback) not in litellm.success_callback:  # type: ignore
-                litellm.success_callback.append(utils.track_cost_callback)  # type: ignore
+            if (track_cost_callback) not in litellm.success_callback:  # type: ignore
+                litellm.success_callback.append(track_cost_callback)  # type: ignore
         else:
-            litellm.success_callback = utils.track_cost_callback  # type: ignore
+            litellm.success_callback = track_cost_callback  # type: ignore
     ### START REDIS QUEUE ###
     use_queue = general_settings.get("use_queue", False)
     celery_setup(use_queue=use_queue)
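
Appending a callable to litellm.success_callback is how litellm runs custom code after each successful call, which is what the config hook above relies on. A minimal registration sketch outside the proxy, with a simplified callback using the same four-argument signature:

import litellm

def log_cost(kwargs, completion_response, start_time, end_time):
    # Runs after every successful completion call
    cost = litellm.completion_cost(completion_response=completion_response)
    print(f"response cost: {cost}")

litellm.success_callback = [log_cost]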
@@ -386,6 +421,32 @@ async def delete_verification_token(tokens: List[str]):
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
     return deleted_tokens

+async def update_verification_token_cost(token: str, additional_cost: float):
+    global prisma_client
+    print("in update verification token")
+    print("prisma client", prisma_client)
+    try:
+        if prisma_client:
+            # look up the existing row for this token via the Prisma client
+            existing_token = await prisma_client.litellm_verificationtoken.find_unique(where={"token": token})
+            print("existing token data", existing_token)
+            if existing_token:
+                old_cost = existing_token.get("cost", 0.0)
+                new_cost = old_cost + additional_cost
+                updated_token = await prisma_client.litellm_verificationtoken.update(
+                    where={"token": token},
+                    data={"cost": new_cost}
+                )
+                return updated_token
+            else:
+                raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Token not found")
+        else:
+            raise Exception("prisma_client is not initialized")
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)

 async def generate_key_cli_task(duration_str):
     task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
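
Because update_verification_token_cost reads the current cost and then writes back the sum, two requests that finish at the same time can lose an update. A hedged alternative sketch using Prisma Client Python's atomic number operations, assuming the cost column is a Float on the same litellm_verificationtoken model (increment_token_cost is a hypothetical name):

async def increment_token_cost(prisma_client, token: str, additional_cost: float):
    # "increment" performs the addition inside the database in a single
    # statement, avoiding the read-then-write race of the version above
    return await prisma_client.litellm_verificationtoken.update(
        where={"token": token},
        data={"cost": {"increment": additional_cost}},
    )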
@@ -612,6 +673,12 @@ async def completion(request: Request, model: Optional[str] = None):
 async def chat_completion(request: Request, model: Optional[str] = None):
     global general_settings
     try:
+        bearer_api_key = request.headers.get("authorization")
+        print("bearer key", bearer_api_key)
+        if "Bearer " in bearer_api_key:
+            cleaned_api_key = bearer_api_key[len("Bearer "):]
+            print("cleaned api key", cleaned_api_key)
         body = await request.body()
         body_str = body.decode()
         try:
@@ -626,6 +693,9 @@ async def chat_completion(request: Request, model: Optional[str] = None):
             or data["model"]  # default passed in http request
         )
         data["call_type"] = "chat_completion"
+        if "metadata" not in data:
+            data["metadata"] = {}
+        data["metadata"]["api_key"] = cleaned_api_key
         return litellm_completion(
             **data
         )
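
Note that the inline header handling above raises a TypeError when no Authorization header is sent ("Bearer " in None), and it misses a lowercase "bearer" scheme. A small helper sketch for the same parsing (get_bearer_token is hypothetical, not part of this commit):

def get_bearer_token(auth_header):
    # Return the token from an "Authorization: Bearer <token>" header,
    # or None if the header is missing or not a Bearer credential
    if not auth_header:
        return None
    scheme, _, token = auth_header.partition(" ")
    if scheme.lower() != "bearer" or not token:
        return None
    return token.strip()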

View file

@@ -1,38 +1,41 @@
-import litellm
-from litellm import ModelResponse
-from proxy_server import llm_model_list
-from typing import Optional
-
-def track_cost_callback(
-    kwargs,                              # kwargs to completion
-    completion_response: ModelResponse,  # response from completion
-    start_time = None,
-    end_time = None,                     # start/end time for completion
-):
-    try:
-        # init logging config
-        print("in custom callback tracking cost", llm_model_list)
-        if "azure" in kwargs["model"]:
-            # for azure cost tracking, we check the provided model list in the config.yaml
-            # we need to map azure/chatgpt-deployment to -> azure/gpt-3.5-turbo
-            pass
-        # check if it has collected an entire stream response
-        if "complete_streaming_response" in kwargs:
-            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
-            completion_response = kwargs["complete_streaming_response"]
-            input_text = kwargs["messages"]
-            output_text = completion_response["choices"][0]["message"]["content"]
-            response_cost = litellm.completion_cost(
-                model = kwargs["model"],
-                messages = input_text,
-                completion = output_text
-            )
-            print("streaming response_cost", response_cost)
-        # for non streaming responses
-        else:
-            # we pass the completion_response obj
-            if kwargs["stream"] != True:
-                response_cost = litellm.completion_cost(completion_response=completion_response)
-                print("regular response_cost", response_cost)
-    except:
-        pass
+# import litellm
+# from litellm import ModelResponse
+# from proxy_server import update_verification_token_cost
+# from typing import Optional
+# from fastapi import HTTPException, status
+# import asyncio
+
+# def track_cost_callback(
+#     kwargs,                              # kwargs to completion
+#     completion_response: ModelResponse,  # response from completion
+#     start_time = None,
+#     end_time = None,                     # start/end time for completion
+# ):
+#     try:
+#         # init logging config
+#         api_key = kwargs["litellm_params"]["metadata"]["api_key"]
+#         # check if it has collected an entire stream response
+#         if "complete_streaming_response" in kwargs:
+#             # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
+#             completion_response = kwargs["complete_streaming_response"]
+#             input_text = kwargs["messages"]
+#             output_text = completion_response["choices"][0]["message"]["content"]
+#             response_cost = litellm.completion_cost(
+#                 model = kwargs["model"],
+#                 messages = input_text,
+#                 completion = output_text
+#             )
+#             print(f"LiteLLM Proxy: streaming response_cost: {response_cost} for api_key: {api_key}")
+#         # for non streaming responses
+#         else:
+#             # we pass the completion_response obj
+#             if kwargs["stream"] != True:
+#                 response_cost = litellm.completion_cost(completion_response=completion_response)
+#                 print(f"\n LiteLLM Proxy: regular response_cost: {response_cost} for api_key: {api_key}")
+
+#         ########### write costs to DB api_key / cost map
+#         asyncio.run(
+#             update_verification_token_cost(token=api_key, additional_cost=response_cost)
+#         )
+#     except:
+#         pass
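
The commented-out version above would have bridged the synchronous success callback into the async DB helper with asyncio.run, which cannot be called from a thread whose event loop is already running; that restriction is one plausible reason this commit rewrites the callback as a native coroutine in proxy_server.py instead. A standalone sketch of that sync-to-async bridge (both names hypothetical):

import asyncio

async def write_cost(token: str, additional_cost: float):
    # Stand-in for the async DB update
    print(f"adding {additional_cost} to {token}")

def sync_callback(token: str, cost: float):
    # Bridge: run the coroutine to completion from synchronous code;
    # raises RuntimeError inside an already-running event loop
    asyncio.run(write_cost(token, cost))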