Merge pull request #2970 from BerriAI/litellm_keys

fix(handle_jwt.py): User cost tracking via JWT Auth
This commit is contained in:
Krish Dholakia 2024-04-11 21:44:15 -07:00 committed by GitHub
commit d89644d46c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 263 additions and 35 deletions

View file

@ -422,12 +422,21 @@ async def user_api_key_auth(
user_api_key_cache=user_api_key_cache,
)
# common checks
# allow request
# get the request body
request_data = await _read_request_body(request=request)
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
user_object = None
user_id = jwt_handler.get_user_id(token=valid_token, default_value=None)
if user_id is not None:
# get the user object
user_object = await get_user_object(
user_id=user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
# save the user object to cache
await user_api_key_cache.async_set_cache(
key=user_id, value=user_object
)
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
end_user_object = None
end_user_id = jwt_handler.get_end_user_id(
token=valid_token, default_value=None
@ -445,7 +454,6 @@ async def user_api_key_auth(
)
global_proxy_spend = None
if litellm.max_budget > 0: # user set proxy max budget
# check cache
global_proxy_spend = await user_api_key_cache.async_get_cache(
@ -480,16 +488,20 @@ async def user_api_key_auth(
)
)
# get the request body
request_data = await _read_request_body(request=request)
# run through common checks
_ = common_checks(
request_body=request_data,
team_object=team_object,
user_object=user_object,
end_user_object=end_user_object,
general_settings=general_settings,
global_proxy_spend=global_proxy_spend,
route=route,
)
# save user object in cache
# save team object in cache
await user_api_key_cache.async_set_cache(
key=team_object.team_id, value=team_object
)
@ -502,6 +514,7 @@ async def user_api_key_auth(
team_rpm_limit=team_object.rpm_limit,
team_models=team_object.models,
user_role="app_owner",
user_id=user_id,
)
#### ELSE ####
if master_key is None:
@ -954,6 +967,7 @@ async def user_api_key_auth(
_ = common_checks(
request_body=request_data,
team_object=_team_obj,
user_object=None,
end_user_object=_end_user_object,
general_settings=general_settings,
global_proxy_spend=global_proxy_spend,
@ -1328,8 +1342,6 @@ async def update_database(
existing_token_obj = await user_api_key_cache.async_get_cache(
key=hashed_token
)
if existing_token_obj is None:
return
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
if existing_user_obj is not None and isinstance(existing_user_obj, dict):
existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
@ -1351,7 +1363,9 @@ async def update_database(
if end_user_id is not None:
prisma_client.end_user_list_transactons[end_user_id] = (
response_cost
+ prisma_client.user_list_transactons.get(end_user_id, 0)
+ prisma_client.end_user_list_transactons.get(
end_user_id, 0
)
)
elif custom_db_client is not None:
for id in user_ids: