feat(router.py): enable filtering model group by 'allowed_model_region'

This commit is contained in:
Krrish Dholakia 2024-05-08 22:10:17 -07:00
parent db666b01e5
commit 3d18897d69
11 changed files with 417 additions and 35 deletions

View file

@ -472,10 +472,6 @@ async def user_api_key_auth(
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
# save the end-user object to cache
await user_api_key_cache.async_set_cache(
key=end_user_id, value=end_user_object
)
global_proxy_spend = None
if litellm.max_budget > 0: # user set proxy max budget
@ -957,13 +953,16 @@ async def user_api_key_auth(
_end_user_object = None
if "user" in request_data:
_id = "end_user_id:{}".format(request_data["user"])
_end_user_object = await user_api_key_cache.async_get_cache(key=_id)
if _end_user_object is not None:
_end_user_object = LiteLLM_EndUserTable(**_end_user_object)
_end_user_object = await get_end_user_object(
end_user_id=request_data["user"],
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
global_proxy_spend = None
if litellm.max_budget > 0: # user set proxy max budget
if (
litellm.max_budget > 0 and prisma_client is not None
): # user set proxy max budget
# check cache
global_proxy_spend = await user_api_key_cache.async_get_cache(
key="{}:spend".format(litellm_proxy_admin_name)
@ -1016,6 +1015,12 @@ async def user_api_key_auth(
)
valid_token_dict = _get_pydantic_json_dict(valid_token)
valid_token_dict.pop("token", None)
if _end_user_object is not None:
valid_token_dict["allowed_model_region"] = (
_end_user_object.allowed_model_region
)
"""
asyncio create task to update the user api key cache with the user db table as well
@ -1040,10 +1045,7 @@ async def user_api_key_auth(
# check if user can access this route
query_params = request.query_params
key = query_params.get("key")
if (
key is not None
and prisma_client.hash_token(token=key) != api_key
):
if key is not None and hash_token(token=key) != api_key:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="user not allowed to access this key's info",
@ -1096,6 +1098,7 @@ async def user_api_key_auth(
# sso/login, ui/login, /key functions and /user functions
# this will never be allowed to call /chat/completions
token_team = getattr(valid_token, "team_id", None)
if token_team is not None and token_team == "litellm-dashboard":
# this token is only used for managing the ui
allowed_routes = [
@ -3617,6 +3620,10 @@ async def chat_completion(
**data,
} # add the team-specific configs to the completion call
### END-USER SPECIFIC PARAMS ###
if user_api_key_dict.allowed_model_region is not None:
data["allowed_model_region"] = user_api_key_dict.allowed_model_region
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
# override with user settings, these are params passed via cli
if user_temperature: