Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

Commit: ff7ed23db3 ("add cost tracking")
Parent: 41c6388185
3 changed files with 134 additions and 48 deletions
@@ -3,6 +3,8 @@
 import dotenv, os
 import requests
 import requests
+import inspect
+import asyncio
 
 dotenv.load_dotenv() # Loading env variables using dotenv
 import traceback
@@ -50,6 +52,16 @@ class CustomLogger:
         # Method definition
         try:
             kwargs["log_event_type"] = "post_api_call"
-            callback_func(
-                kwargs, # kwargs to func
-                response_obj,
+            if inspect.iscoroutinefunction(callback_func):
+                # If it's async, use asyncio to run it
+
+                loop = asyncio.get_event_loop()
+                if loop.is_closed():
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+                loop.run_until_complete(callback_func(kwargs, response_obj, start_time, end_time))
+            else:
+                # If it's not async, run it synchronously
+                callback_func(
+                    kwargs, # kwargs to func
+                    response_obj,
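The new branch dispatches on whether callback_func is a coroutine function and drives it on an event loop; plain callables keep the old synchronous path. Below is a minimal standalone sketch of the same dispatch, outside the CustomLogger class; the callback names are illustrative, not from the commit.

import asyncio
import inspect

def run_callback(callback_func, kwargs, response_obj, start_time, end_time):
    if inspect.iscoroutinefunction(callback_func):
        # Async callback: reuse the current loop, or create one if it was closed.
        # (asyncio.run() is the simpler modern alternative when no loop is running.)
        loop = asyncio.get_event_loop()
        if loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        loop.run_until_complete(callback_func(kwargs, response_obj, start_time, end_time))
    else:
        # Plain function: call it directly.
        callback_func(kwargs, response_obj, start_time, end_time)

async def async_logger(kwargs, response_obj, start_time, end_time):
    print("async logger saw:", response_obj)

def sync_logger(kwargs, response_obj, start_time, end_time):
    print("sync logger saw:", response_obj)

run_callback(sync_logger, {}, {"ok": True}, None, None)
run_callback(async_logger, {}, {"ok": True}, None, None)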
@@ -59,7 +71,8 @@ class CustomLogger:
             print_verbose(
                 f"Custom Logger - final response object: {response_obj}"
             )
-        except:
+        except Exception as e:
+            raise e
             # traceback.print_exc()
             print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
             pass
@@ -146,8 +146,44 @@ def usage_telemetry(
 
 api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
 
+async def track_cost_callback(
+    kwargs,                                        # kwargs to completion
+    completion_response: litellm.ModelResponse,    # response from completion
+    start_time = None,
+    end_time = None,                               # start/end time for completion
+):
+    try:
+        # init logging config
+        print("in cost track callback")
+        api_key = kwargs["litellm_params"]["metadata"]["api_key"]
+        # check if it has collected an entire stream response
+        if "complete_streaming_response" in kwargs:
+            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
+            completion_response=kwargs["complete_streaming_response"]
+            input_text = kwargs["messages"]
+            output_text = completion_response["choices"][0]["message"]["content"]
+            response_cost = litellm.completion_cost(
+                model = kwargs["model"],
+                messages = input_text,
+                completion=output_text
+            )
+            print(f"LiteLLM Proxy: streaming response_cost: {response_cost} for api_key: {api_key}")
+        # for non streaming responses
+        else:
+            # we pass the completion_response obj
+            if kwargs["stream"] != True:
+                response_cost = litellm.completion_cost(completion_response=completion_response)
+                print(f"\n LiteLLM Proxy: regular response_cost: {response_cost} for api_key: {api_key}")
+
+        ########### write costs to DB api_key / cost map
+        await update_verification_token_cost(token=api_key, additional_cost=response_cost)
+    except:
+        pass
+
 async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
     global master_key, prisma_client, llm_model_list
+    print("IN AUTH PRISMA CLIENT", prisma_client)
     if master_key is None:
         return
     try:
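track_cost_callback leans on litellm.completion_cost in two forms: for a reassembled stream it passes the model, the input messages, and the joined output text; for regular calls it passes the response object directly. A hedged usage sketch of the streaming form follows; it assumes litellm is installed and the model is in its pricing map, and the texts are made up.

import litellm

# Streaming path: estimate cost from input messages plus the joined output text.
messages = [{"role": "user", "content": "Hey, how's it going?"}]
output_text = "Doing well, thanks for asking!"

cost = litellm.completion_cost(
    model="gpt-3.5-turbo",   # example model; any model in litellm's cost map works
    messages=messages,
    completion=output_text,
)
print(f"estimated streaming cost: ${cost:.8f}")

# Non-streaming path (shape only): litellm.completion_cost(completion_response=response),
# where `response` is the object returned by litellm.completion().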
@@ -275,12 +311,11 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         ## Cost Tracking for master key + auth setup ##
         if master_key is not None:
             if isinstance(litellm.success_callback, list):
-                import utils
                 print("setting litellm success callback to track cost")
-                if (utils.track_cost_callback) not in litellm.success_callback: # type: ignore
-                    litellm.success_callback.append(utils.track_cost_callback) # type: ignore
+                if (track_cost_callback) not in litellm.success_callback: # type: ignore
+                    litellm.success_callback.append(track_cost_callback) # type: ignore
             else:
-                litellm.success_callback = utils.track_cost_callback # type: ignore
+                litellm.success_callback = track_cost_callback # type: ignore
         ### START REDIS QUEUE ###
         use_queue = general_settings.get("use_queue", False)
         celery_setup(use_queue=use_queue)
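The hunk above drops the utils import and registers the track_cost_callback defined in proxy_server.py itself. For reference, here is a minimal sketch of how litellm's success-callback list is typically extended; the callback body is illustrative, not from the commit.

import litellm

def my_cost_logger(kwargs, completion_response, start_time, end_time):
    # litellm calls success callbacks with these four positional arguments.
    print("logged call for model:", kwargs.get("model"))

# success_callback is normally a list; append instead of overwriting so that
# other configured loggers keep working.
if isinstance(litellm.success_callback, list):
    if my_cost_logger not in litellm.success_callback:
        litellm.success_callback.append(my_cost_logger)
else:
    litellm.success_callback = [my_cost_logger]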
@@ -386,6 +421,32 @@ async def delete_verification_token(tokens: List[str]):
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
     return deleted_tokens
 
+async def update_verification_token_cost(token: str, additional_cost: float):
+    global prisma_client
+    print("in update verification token")
+    print("prisma client", prisma_client)
+
+    try:
+        if prisma_client:
+            # Assuming 'db' is your Prisma Client instance
+            existing_token = await prisma_client.litellm_verificationtoken.find_unique(where={"token": token})
+            print("existing token data", existing_token)
+            if existing_token:
+                old_cost = existing_token.get("cost", 0.0)
+                new_cost = old_cost + additional_cost
+                updated_token = await prisma_client.litellm_verificationtoken.update(
+                    where={"token": token},
+                    data={"cost": new_cost}
+                )
+                return updated_token
+            else:
+                raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Token not found")
+        else:
+            raise Exception
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
 async def generate_key_cli_task(duration_str):
     task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
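update_verification_token_cost is a read-modify-write: look up the token row, add the request's cost to the stored total, write the row back. A minimal sketch of the same accumulation with an in-memory dict standing in for the Prisma-backed litellm_verificationtoken table; the store and function names here are hypothetical.

import asyncio
from typing import Dict

# Hypothetical stand-in for the DB table: token -> row.
token_store: Dict[str, dict] = {"sk-test-123": {"token": "sk-test-123", "cost": 0.0}}

async def accumulate_token_cost(token: str, additional_cost: float) -> dict:
    row = token_store.get(token)
    if row is None:
        raise KeyError("Token not found")
    # New total = previous total + cost of the latest request.
    row["cost"] = row.get("cost", 0.0) + additional_cost
    return row

print(asyncio.run(accumulate_token_cost("sk-test-123", 0.00042)))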
@@ -612,6 +673,12 @@ async def completion(request: Request, model: Optional[str] = None):
 async def chat_completion(request: Request, model: Optional[str] = None):
     global general_settings
     try:
+        bearer_api_key = request.headers.get("authorization")
+        print("beaerer key", bearer_api_key)
+        if "Bearer " in bearer_api_key:
+            cleaned_api_key = bearer_api_key[len("Bearer "):]
+            print("cleaned appi key", cleaned_api_key)
+
         body = await request.body()
         body_str = body.decode()
         try:
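The added lines read the raw Authorization header and strip the "Bearer " prefix so the key can travel with the request. The same extraction on its own, with a made-up header value:

def extract_api_key(authorization_header: str) -> str:
    # "Bearer sk-1234" -> "sk-1234"; anything else is returned unchanged.
    if authorization_header and "Bearer " in authorization_header:
        return authorization_header[len("Bearer "):]
    return authorization_header

print(extract_api_key("Bearer sk-1234"))  # sk-1234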
@@ -626,6 +693,9 @@ async def chat_completion(request: Request, model: Optional[str] = None):
             or data["model"] # default passed in http request
         )
         data["call_type"] = "chat_completion"
+        if "metadata" not in data:
+            data["metadata"] = {}
+        data["metadata"] = {"api_key": cleaned_api_key}
         return litellm_completion(
             **data
         )
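Writing the cleaned key into data["metadata"] is what lets track_cost_callback read it back later from kwargs["litellm_params"]["metadata"]["api_key"]. A hedged sketch of that round trip follows: the model and key are placeholders, actually running it needs valid provider credentials, and the metadata pass-through into callback kwargs is litellm's standard behavior rather than something added by this commit.

import litellm

def read_metadata_callback(kwargs, completion_response, start_time, end_time):
    # metadata passed to completion() is surfaced back to callbacks under litellm_params.
    metadata = (kwargs.get("litellm_params") or {}).get("metadata") or {}
    print("api_key seen by callback:", metadata.get("api_key"))

litellm.success_callback = [read_metadata_callback]

litellm.completion(
    model="gpt-3.5-turbo",                        # placeholder model
    messages=[{"role": "user", "content": "hi"}],
    metadata={"api_key": "sk-proxy-user-key"},    # same shape the proxy attaches
)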
@@ -1,38 +1,41 @@
-import litellm
-from litellm import ModelResponse
-from proxy_server import llm_model_list
-from typing import Optional
-
-def track_cost_callback(
-    kwargs,                                        # kwargs to completion
-    completion_response: ModelResponse,            # response from completion
-    start_time = None,
-    end_time = None,                               # start/end time for completion
-):
-    try:
-        # init logging config
-        print("in custom callback tracking cost", llm_model_list)
-        if "azure" in kwargs["model"]:
-            # for azure cost tracking, we check the provided model list in the config.yaml
-            # we need to map azure/chatgpt-deployment to -> azure/gpt-3.5-turbo
-            pass
-        # check if it has collected an entire stream response
-        if "complete_streaming_response" in kwargs:
-            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
-            completion_response=kwargs["complete_streaming_response"]
-            input_text = kwargs["messages"]
-            output_text = completion_response["choices"][0]["message"]["content"]
-            response_cost = litellm.completion_cost(
-                model = kwargs["model"],
-                messages = input_text,
-                completion=output_text
-            )
-            print("streaming response_cost", response_cost)
-        # for non streaming responses
-        else:
-            # we pass the completion_response obj
-            if kwargs["stream"] != True:
-                response_cost = litellm.completion_cost(completion_response=completion_response)
-                print("regular response_cost", response_cost)
-    except:
-        pass
+# import litellm
+# from litellm import ModelResponse
+# from proxy_server import update_verification_token_cost
+# from typing import Optional
+# from fastapi import HTTPException, status
+# import asyncio
+
+# def track_cost_callback(
+#     kwargs,                                        # kwargs to completion
+#     completion_response: ModelResponse,            # response from completion
+#     start_time = None,
+#     end_time = None,                               # start/end time for completion
+# ):
+#     try:
+#         # init logging config
+#         api_key = kwargs["litellm_params"]["metadata"]["api_key"]
+#         # check if it has collected an entire stream response
+#         if "complete_streaming_response" in kwargs:
+#             # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
+#             completion_response=kwargs["complete_streaming_response"]
+#             input_text = kwargs["messages"]
+#             output_text = completion_response["choices"][0]["message"]["content"]
+#             response_cost = litellm.completion_cost(
+#                 model = kwargs["model"],
+#                 messages = input_text,
+#                 completion=output_text
+#             )
+#             print(f"LiteLLM Proxy: streaming response_cost: {response_cost} for api_key: {api_key}")
+#         # for non streaming responses
+#         else:
+#             # we pass the completion_response obj
+#             if kwargs["stream"] != True:
+#                 response_cost = litellm.completion_cost(completion_response=completion_response)
+#                 print(f"\n LiteLLM Proxy: regular response_cost: {response_cost} for api_key: {api_key}")
+
+#         ########### write costs to DB api_key / cost map
+#         asyncio.run(
+#             update_verification_token_cost(token=api_key, additional_cost=response_cost)
+#         )
+#     except:
+#         pass
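The commented-out utils.py version bridged from a synchronous success callback into the async DB helper with asyncio.run; the new callback in proxy_server.py is async itself, so no bridge is needed. A minimal sketch of the sync-to-async bridge the old code used, with illustrative function names:

import asyncio

async def write_cost_to_db(token: str, additional_cost: float) -> None:
    # Placeholder for an async DB write (e.g. the Prisma update in proxy_server.py).
    print(f"recording {additional_cost} for {token}")

def sync_success_callback(kwargs, completion_response, start_time, end_time):
    # asyncio.run() spins up a fresh event loop per call and raises if a loop is
    # already running in this thread, which is the main constraint of this bridge.
    asyncio.run(write_cost_to_db("sk-test", 0.0001))

sync_success_callback({}, None, None, None)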