forked from phoenix/litellm-mirror
feat(proxy_server.py): support cost tracking on user id via JWT-Auth
allows admin to track cost for LiteLLM_UserTable via JWT
This commit is contained in:
parent
e413191493
commit
36ff593c02
4 changed files with 59 additions and 24 deletions
|
@ -139,7 +139,8 @@ class LiteLLM_JWTAuth(LiteLLMBase):
|
|||
team_allowed_routes: List[
|
||||
Literal["openai_routes", "info_routes", "management_routes"]
|
||||
] = ["openai_routes", "info_routes"]
|
||||
end_user_id_jwt_field: Optional[str] = "sub"
|
||||
user_id_jwt_field: Optional[str] = None
|
||||
end_user_id_jwt_field: Optional[str] = None
|
||||
public_key_ttl: float = 600
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
|
|
|
@ -26,6 +26,7 @@ all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes
|
|||
def common_checks(
|
||||
request_body: dict,
|
||||
team_object: LiteLLM_TeamTable,
|
||||
user_object: Optional[LiteLLM_UserTable],
|
||||
end_user_object: Optional[LiteLLM_EndUserTable],
|
||||
global_proxy_spend: Optional[float],
|
||||
general_settings: dict,
|
||||
|
@ -37,7 +38,8 @@ def common_checks(
|
|||
1. If team is blocked
|
||||
2. If team can call model
|
||||
3. If team is in budget
|
||||
4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||
5. If user passed in (JWT or key.user_id) - is in budget
|
||||
4. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||
5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
|
||||
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
|
||||
"""
|
||||
|
@ -69,14 +71,20 @@ def common_checks(
|
|||
raise Exception(
|
||||
f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
|
||||
)
|
||||
# 4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||
if user_object is not None and user_object.max_budget is not None:
|
||||
user_budget = user_object.max_budget
|
||||
if user_budget > user_object.spend:
|
||||
raise Exception(
|
||||
f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}"
|
||||
)
|
||||
# 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||
if end_user_object is not None and end_user_object.litellm_budget_table is not None:
|
||||
end_user_budget = end_user_object.litellm_budget_table.max_budget
|
||||
if end_user_budget is not None and end_user_object.spend > end_user_budget:
|
||||
raise Exception(
|
||||
f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
|
||||
)
|
||||
# 5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
|
||||
# 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
|
||||
if (
|
||||
general_settings.get("enforce_user_param", None) is not None
|
||||
and general_settings["enforce_user_param"] == True
|
||||
|
@ -85,7 +93,7 @@ def common_checks(
|
|||
raise Exception(
|
||||
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
|
||||
)
|
||||
# 6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
|
||||
# 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
|
||||
if litellm.max_budget > 0 and global_proxy_spend is not None:
|
||||
if global_proxy_spend > litellm.max_budget:
|
||||
raise Exception(
|
||||
|
@ -204,19 +212,24 @@ async def get_end_user_object(
|
|||
return None
|
||||
|
||||
|
||||
async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
|
||||
async def get_user_object(
|
||||
user_id: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
) -> Optional[LiteLLM_UserTable]:
|
||||
"""
|
||||
- Check if user id in proxy User Table
|
||||
- if valid, return LiteLLM_UserTable object with defined limits
|
||||
- if not, then raise an error
|
||||
"""
|
||||
if self.prisma_client is None:
|
||||
raise Exception(
|
||||
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
|
||||
)
|
||||
if prisma_client is None:
|
||||
raise Exception("No db connected")
|
||||
|
||||
if user_id is None:
|
||||
return None
|
||||
|
||||
# check if in cache
|
||||
cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id)
|
||||
cached_user_obj = user_api_key_cache.async_get_cache(key=user_id)
|
||||
if cached_user_obj is not None:
|
||||
if isinstance(cached_user_obj, dict):
|
||||
return LiteLLM_UserTable(**cached_user_obj)
|
||||
|
@ -224,7 +237,7 @@ async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
|
|||
return cached_user_obj
|
||||
# else, check db
|
||||
try:
|
||||
response = await self.prisma_client.db.litellm_usertable.find_unique(
|
||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
|
||||
|
@ -232,10 +245,8 @@ async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
|
|||
raise Exception
|
||||
|
||||
return LiteLLM_UserTable(**response.dict())
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
|
||||
)
|
||||
except Exception as e: # if end-user not in db
|
||||
return None
|
||||
|
||||
|
||||
async def get_team_object(
|
||||
|
|
|
@ -74,6 +74,16 @@ class JWTHandler:
|
|||
team_id = default_value
|
||||
return team_id
|
||||
|
||||
def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
if self.litellm_jwtauth.user_id_jwt_field is not None:
|
||||
user_id = token[self.litellm_jwtauth.user_id_jwt_field]
|
||||
else:
|
||||
user_id = None
|
||||
except KeyError:
|
||||
user_id = default_value
|
||||
return user_id
|
||||
|
||||
def get_scopes(self, token: dict) -> list:
|
||||
try:
|
||||
if isinstance(token["scope"], str):
|
||||
|
|
|
@ -422,12 +422,21 @@ async def user_api_key_auth(
|
|||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
|
||||
# common checks
|
||||
# allow request
|
||||
|
||||
# get the request body
|
||||
request_data = await _read_request_body(request=request)
|
||||
|
||||
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
|
||||
user_object = None
|
||||
user_id = jwt_handler.get_user_id(token=valid_token, default_value=None)
|
||||
if user_id is not None:
|
||||
# get the user object
|
||||
user_object = await get_user_object(
|
||||
user_id=user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
# save the user object to cache
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=user_id, value=user_object
|
||||
)
|
||||
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
|
||||
end_user_object = None
|
||||
end_user_id = jwt_handler.get_end_user_id(
|
||||
token=valid_token, default_value=None
|
||||
|
@ -445,7 +454,6 @@ async def user_api_key_auth(
|
|||
)
|
||||
|
||||
global_proxy_spend = None
|
||||
|
||||
if litellm.max_budget > 0: # user set proxy max budget
|
||||
# check cache
|
||||
global_proxy_spend = await user_api_key_cache.async_get_cache(
|
||||
|
@ -480,16 +488,20 @@ async def user_api_key_auth(
|
|||
)
|
||||
)
|
||||
|
||||
# get the request body
|
||||
request_data = await _read_request_body(request=request)
|
||||
|
||||
# run through common checks
|
||||
_ = common_checks(
|
||||
request_body=request_data,
|
||||
team_object=team_object,
|
||||
user_object=user_object,
|
||||
end_user_object=end_user_object,
|
||||
general_settings=general_settings,
|
||||
global_proxy_spend=global_proxy_spend,
|
||||
route=route,
|
||||
)
|
||||
# save user object in cache
|
||||
# save team object in cache
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=team_object.team_id, value=team_object
|
||||
)
|
||||
|
@ -954,6 +966,7 @@ async def user_api_key_auth(
|
|||
_ = common_checks(
|
||||
request_body=request_data,
|
||||
team_object=_team_obj,
|
||||
user_object=None,
|
||||
end_user_object=_end_user_object,
|
||||
general_settings=general_settings,
|
||||
global_proxy_spend=global_proxy_spend,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue