fix(proxy_server.py): enable spend tracking for team-based jwt auth

This commit is contained in:
Krrish Dholakia 2024-03-28 20:16:07 -07:00
parent 792999d756
commit c15ba368e7
3 changed files with 25 additions and 5 deletions

View file

@ -293,7 +293,7 @@ litellm_proxy_admin_name = "default_user_id"
ui_access_mode: Literal["admin", "all"] = "all"
proxy_budget_rescheduler_min_time = 597
proxy_budget_rescheduler_max_time = 605
proxy_batch_write_at = 60 # in seconds
proxy_batch_write_at = 10 # in seconds
litellm_master_key_hash = None
disable_spend_logs = False
jwt_handler = JWTHandler()
@ -1161,7 +1161,7 @@ async def _PROXY_track_cost_callback(
verbose_proxy_logger.debug(
f"user_api_key {user_api_key}, prisma_client: {prisma_client}, custom_db_client: {custom_db_client}"
)
if user_api_key is not None:
if user_api_key is not None or user_id is not None or team_id is not None:
## UPDATE DATABASE
await update_database(
token=user_api_key,
@ -1182,7 +1182,9 @@ async def _PROXY_track_cost_callback(
response_cost=response_cost,
)
else:
raise Exception("User API key missing from custom callback.")
raise Exception(
"User API key and team id and user id missing from custom callback."
)
else:
if kwargs["stream"] != True or (
kwargs["stream"] == True and "complete_streaming_response" in kwargs
@ -1223,7 +1225,7 @@ async def update_database(
verbose_proxy_logger.info(
f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}"
)
if isinstance(token, str) and token.startswith("sk-"):
if token is not None and isinstance(token, str) and token.startswith("sk-"):
hashed_token = hash_token(token=token)
else:
hashed_token = token
@ -1337,6 +1339,8 @@ async def update_database(
verbose_proxy_logger.debug(
f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}."
)
if hashed_token is None:
return
if prisma_client is not None:
prisma_client.key_list_transactons[hashed_token] = (
response_cost