forked from phoenix/litellm-mirror
Merge pull request #2970 from BerriAI/litellm_keys
fix(handle_jwt.py): User cost tracking via JWT Auth
This commit is contained in:
commit
d89644d46c
7 changed files with 263 additions and 35 deletions
|
@ -422,12 +422,21 @@ async def user_api_key_auth(
|
|||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
|
||||
# common checks
|
||||
# allow request
|
||||
|
||||
# get the request body
|
||||
request_data = await _read_request_body(request=request)
|
||||
|
||||
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
|
||||
user_object = None
|
||||
user_id = jwt_handler.get_user_id(token=valid_token, default_value=None)
|
||||
if user_id is not None:
|
||||
# get the user object
|
||||
user_object = await get_user_object(
|
||||
user_id=user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
# save the user object to cache
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=user_id, value=user_object
|
||||
)
|
||||
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
|
||||
end_user_object = None
|
||||
end_user_id = jwt_handler.get_end_user_id(
|
||||
token=valid_token, default_value=None
|
||||
|
@ -445,7 +454,6 @@ async def user_api_key_auth(
|
|||
)
|
||||
|
||||
global_proxy_spend = None
|
||||
|
||||
if litellm.max_budget > 0: # user set proxy max budget
|
||||
# check cache
|
||||
global_proxy_spend = await user_api_key_cache.async_get_cache(
|
||||
|
@ -480,16 +488,20 @@ async def user_api_key_auth(
|
|||
)
|
||||
)
|
||||
|
||||
# get the request body
|
||||
request_data = await _read_request_body(request=request)
|
||||
|
||||
# run through common checks
|
||||
_ = common_checks(
|
||||
request_body=request_data,
|
||||
team_object=team_object,
|
||||
user_object=user_object,
|
||||
end_user_object=end_user_object,
|
||||
general_settings=general_settings,
|
||||
global_proxy_spend=global_proxy_spend,
|
||||
route=route,
|
||||
)
|
||||
# save user object in cache
|
||||
# save team object in cache
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=team_object.team_id, value=team_object
|
||||
)
|
||||
|
@ -502,6 +514,7 @@ async def user_api_key_auth(
|
|||
team_rpm_limit=team_object.rpm_limit,
|
||||
team_models=team_object.models,
|
||||
user_role="app_owner",
|
||||
user_id=user_id,
|
||||
)
|
||||
#### ELSE ####
|
||||
if master_key is None:
|
||||
|
@ -954,6 +967,7 @@ async def user_api_key_auth(
|
|||
_ = common_checks(
|
||||
request_body=request_data,
|
||||
team_object=_team_obj,
|
||||
user_object=None,
|
||||
end_user_object=_end_user_object,
|
||||
general_settings=general_settings,
|
||||
global_proxy_spend=global_proxy_spend,
|
||||
|
@ -1328,8 +1342,6 @@ async def update_database(
|
|||
existing_token_obj = await user_api_key_cache.async_get_cache(
|
||||
key=hashed_token
|
||||
)
|
||||
if existing_token_obj is None:
|
||||
return
|
||||
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||
if existing_user_obj is not None and isinstance(existing_user_obj, dict):
|
||||
existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
|
||||
|
@ -1351,7 +1363,9 @@ async def update_database(
|
|||
if end_user_id is not None:
|
||||
prisma_client.end_user_list_transactons[end_user_id] = (
|
||||
response_cost
|
||||
+ prisma_client.user_list_transactons.get(end_user_id, 0)
|
||||
+ prisma_client.end_user_list_transactons.get(
|
||||
end_user_id, 0
|
||||
)
|
||||
)
|
||||
elif custom_db_client is not None:
|
||||
for id in user_ids:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue