feat(proxy_server.py): support cost tracking on user id via JWT-Auth

allows admin to track cost for LiteLLM_UserTable via JWT
This commit is contained in:
Krrish Dholakia 2024-04-11 18:47:46 -07:00
parent e413191493
commit 36ff593c02
4 changed files with 59 additions and 24 deletions

View file

@ -422,12 +422,21 @@ async def user_api_key_auth(
user_api_key_cache=user_api_key_cache,
)
# common checks
# allow request
# get the request body
request_data = await _read_request_body(request=request)
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
user_object = None
user_id = jwt_handler.get_user_id(token=valid_token, default_value=None)
if user_id is not None:
# get the user object
user_object = await get_user_object(
user_id=user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
# save the user object to cache
await user_api_key_cache.async_set_cache(
key=user_id, value=user_object
)
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
end_user_object = None
end_user_id = jwt_handler.get_end_user_id(
token=valid_token, default_value=None
@ -445,7 +454,6 @@ async def user_api_key_auth(
)
global_proxy_spend = None
if litellm.max_budget > 0: # user set proxy max budget
# check cache
global_proxy_spend = await user_api_key_cache.async_get_cache(
@ -480,16 +488,20 @@ async def user_api_key_auth(
)
)
# get the request body
request_data = await _read_request_body(request=request)
# run through common checks
_ = common_checks(
request_body=request_data,
team_object=team_object,
user_object=user_object,
end_user_object=end_user_object,
general_settings=general_settings,
global_proxy_spend=global_proxy_spend,
route=route,
)
# save user object in cache
# save team object in cache
await user_api_key_cache.async_set_cache(
key=team_object.team_id, value=team_object
)
@ -954,6 +966,7 @@ async def user_api_key_auth(
_ = common_checks(
request_body=request_data,
team_object=_team_obj,
user_object=None,
end_user_object=_end_user_object,
general_settings=general_settings,
global_proxy_spend=global_proxy_spend,