mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(handle_jwt.py): enable team-based jwt-auth access
Move auth to check on ‘client_id’ not ‘sub
This commit is contained in:
parent
b4d0a95cff
commit
7d38c62717
4 changed files with 327 additions and 132 deletions
|
@ -1,4 +1,5 @@
|
||||||
from pydantic import BaseModel, Extra, Field, root_validator, Json, validator
|
from pydantic import BaseModel, Extra, Field, root_validator, Json, validator
|
||||||
|
from dataclasses import fields
|
||||||
import enum
|
import enum
|
||||||
from typing import Optional, List, Union, Dict, Literal, Any
|
from typing import Optional, List, Union, Dict, Literal, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -37,9 +38,96 @@ class LiteLLMBase(BaseModel):
|
||||||
protected_namespaces = ()
|
protected_namespaces = ()
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLMRoutes(enum.Enum):
|
||||||
|
openai_routes: List = [ # chat completions
|
||||||
|
"/openai/deployments/{model}/chat/completions",
|
||||||
|
"/chat/completions",
|
||||||
|
"/v1/chat/completions",
|
||||||
|
# completions
|
||||||
|
"/openai/deployments/{model}/completions",
|
||||||
|
"/completions",
|
||||||
|
"/v1/completions",
|
||||||
|
# embeddings
|
||||||
|
"/openai/deployments/{model}/embeddings",
|
||||||
|
"/embeddings",
|
||||||
|
"/v1/embeddings",
|
||||||
|
# image generation
|
||||||
|
"/images/generations",
|
||||||
|
"/v1/images/generations",
|
||||||
|
# audio transcription
|
||||||
|
"/audio/transcriptions",
|
||||||
|
"/v1/audio/transcriptions",
|
||||||
|
# moderations
|
||||||
|
"/moderations",
|
||||||
|
"/v1/moderations",
|
||||||
|
# models
|
||||||
|
"/models",
|
||||||
|
"/v1/models",
|
||||||
|
]
|
||||||
|
|
||||||
|
info_routes: List = ["/key/info", "/team/info", "/user/info", "/model/info"]
|
||||||
|
|
||||||
|
management_routes: List = [ # key
|
||||||
|
"/key/generate",
|
||||||
|
"/key/update",
|
||||||
|
"/key/delete",
|
||||||
|
"/key/info",
|
||||||
|
# user
|
||||||
|
"/user/new",
|
||||||
|
"/user/update",
|
||||||
|
"/user/delete",
|
||||||
|
"/user/info",
|
||||||
|
# team
|
||||||
|
"/team/new",
|
||||||
|
"/team/update",
|
||||||
|
"/team/delete",
|
||||||
|
"/team/info",
|
||||||
|
# model
|
||||||
|
"/model/new",
|
||||||
|
"/model/update",
|
||||||
|
"/model/delete",
|
||||||
|
"/model/info",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMProxyRoles(LiteLLMBase):
|
class LiteLLMProxyRoles(LiteLLMBase):
|
||||||
proxy_admin: str = "litellm_proxy_admin"
|
"""
|
||||||
proxy_user: str = "litellm_user"
|
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
- admin_jwt_scope: The JWT scope required for proxy admin roles.
|
||||||
|
- admin_allowed_routes: list of allowed routes for proxy admin roles.
|
||||||
|
- team_jwt_scope: The JWT scope required for proxy team roles.
|
||||||
|
- team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
|
||||||
|
- team_allowed_routes: list of allowed routes for proxy team roles.
|
||||||
|
- end_user_id_jwt_field: Default - `sub`. The field in the JWT token that stores the end-user ID. Turn this off by setting to `None`. Enables end-user cost tracking.
|
||||||
|
|
||||||
|
See `auth_checks.py` for the specific routes
|
||||||
|
"""
|
||||||
|
|
||||||
|
admin_jwt_scope: str = "litellm_proxy_admin"
|
||||||
|
admin_allowed_routes: List[
|
||||||
|
Literal["openai_routes", "info_routes", "management_routes"]
|
||||||
|
] = ["management_routes"]
|
||||||
|
team_jwt_scope: str = "litellm_team"
|
||||||
|
team_id_jwt_field: str = "client_id"
|
||||||
|
team_allowed_routes: List[
|
||||||
|
Literal["openai_routes", "info_routes", "management_routes"]
|
||||||
|
] = ["openai_routes", "info_routes"]
|
||||||
|
end_user_id_jwt_field: Optional[str] = "sub"
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
# get the attribute names for this Pydantic model
|
||||||
|
allowed_keys = self.__annotations__.keys()
|
||||||
|
|
||||||
|
invalid_keys = set(kwargs.keys()) - allowed_keys
|
||||||
|
|
||||||
|
if invalid_keys:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid arguments provided: {', '.join(invalid_keys)}. Allowed arguments are: {', '.join(allowed_keys)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMPromptInjectionParams(LiteLLMBase):
|
class LiteLLMPromptInjectionParams(LiteLLMBase):
|
||||||
|
|
|
@ -8,15 +8,23 @@ Run checks for:
|
||||||
2. If user is in budget
|
2. If user is in budget
|
||||||
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
"""
|
"""
|
||||||
from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable
|
from litellm.proxy._types import (
|
||||||
|
LiteLLM_UserTable,
|
||||||
|
LiteLLM_EndUserTable,
|
||||||
|
LiteLLMProxyRoles,
|
||||||
|
LiteLLM_TeamTable,
|
||||||
|
LiteLLMRoutes,
|
||||||
|
)
|
||||||
from typing import Optional, Literal
|
from typing import Optional, Literal
|
||||||
from litellm.proxy.utils import PrismaClient
|
from litellm.proxy.utils import PrismaClient
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
|
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
||||||
|
|
||||||
|
|
||||||
def common_checks(
|
def common_checks(
|
||||||
request_body: dict,
|
request_body: dict,
|
||||||
user_object: LiteLLM_UserTable,
|
team_object: LiteLLM_TeamTable,
|
||||||
end_user_object: Optional[LiteLLM_EndUserTable],
|
end_user_object: Optional[LiteLLM_EndUserTable],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -30,19 +38,20 @@ def common_checks(
|
||||||
# 1. If user can call model
|
# 1. If user can call model
|
||||||
if (
|
if (
|
||||||
_model is not None
|
_model is not None
|
||||||
and len(user_object.models) > 0
|
and len(team_object.models) > 0
|
||||||
and _model not in user_object.models
|
and _model not in team_object.models
|
||||||
):
|
):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"User={user_object.user_id} not allowed to call model={_model}. Allowed user models = {user_object.models}"
|
f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
|
||||||
)
|
)
|
||||||
# 2. If user is in budget
|
# 2. If team is in budget
|
||||||
if (
|
if (
|
||||||
user_object.max_budget is not None
|
team_object.max_budget is not None
|
||||||
and user_object.spend > user_object.max_budget
|
and team_object.spend is not None
|
||||||
|
and team_object.spend > team_object.max_budget
|
||||||
):
|
):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_object.max_budget}"
|
f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
|
||||||
)
|
)
|
||||||
# 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
# 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
if end_user_object is not None and end_user_object.litellm_budget_table is not None:
|
if end_user_object is not None and end_user_object.litellm_budget_table is not None:
|
||||||
|
@ -54,52 +63,79 @@ def common_checks(
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
|
||||||
|
for allowed_route in allowed_routes:
|
||||||
|
if (
|
||||||
|
allowed_route == LiteLLMRoutes.openai_routes.name
|
||||||
|
and user_route in LiteLLMRoutes.openai_routes.value
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
elif (
|
||||||
|
allowed_route == LiteLLMRoutes.info_routes.name
|
||||||
|
and user_route in LiteLLMRoutes.info_routes.value
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
elif (
|
||||||
|
allowed_route == LiteLLMRoutes.management_routes.name
|
||||||
|
and user_route in LiteLLMRoutes.management_routes.value
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
elif allowed_route == user_route:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def allowed_routes_check(
|
def allowed_routes_check(
|
||||||
user_role: Literal["proxy_admin", "app_owner"],
|
user_role: Literal["proxy_admin", "team"],
|
||||||
route: str,
|
user_route: str,
|
||||||
allowed_routes: Optional[list] = None,
|
litellm_proxy_roles: LiteLLMProxyRoles,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if user -> not admin - allowed to access these routes
|
Check if user -> not admin - allowed to access these routes
|
||||||
"""
|
"""
|
||||||
openai_routes = [
|
|
||||||
# chat completions
|
|
||||||
"/openai/deployments/{model}/chat/completions",
|
|
||||||
"/chat/completions",
|
|
||||||
"/v1/chat/completions",
|
|
||||||
# completions
|
|
||||||
# embeddings
|
|
||||||
"/openai/deployments/{model}/embeddings",
|
|
||||||
"/embeddings",
|
|
||||||
"/v1/embeddings",
|
|
||||||
# image generation
|
|
||||||
"/images/generations",
|
|
||||||
"/v1/images/generations",
|
|
||||||
# audio transcription
|
|
||||||
"/audio/transcriptions",
|
|
||||||
"/v1/audio/transcriptions",
|
|
||||||
# moderations
|
|
||||||
"/moderations",
|
|
||||||
"/v1/moderations",
|
|
||||||
# models
|
|
||||||
"/models",
|
|
||||||
"/v1/models",
|
|
||||||
]
|
|
||||||
info_routes = ["/key/info", "/team/info", "/user/info", "/model/info"]
|
|
||||||
default_routes = openai_routes + info_routes
|
|
||||||
if user_role == "proxy_admin":
|
if user_role == "proxy_admin":
|
||||||
return True
|
if litellm_proxy_roles.admin_allowed_routes is None:
|
||||||
elif user_role == "app_owner":
|
is_allowed = _allowed_routes_check(
|
||||||
if allowed_routes is None:
|
user_route=user_route, allowed_routes=["management_routes"]
|
||||||
if route in default_routes: # check default routes
|
)
|
||||||
return True
|
return is_allowed
|
||||||
elif route in allowed_routes:
|
elif litellm_proxy_roles.admin_allowed_routes is not None:
|
||||||
return True
|
is_allowed = _allowed_routes_check(
|
||||||
else:
|
user_route=user_route,
|
||||||
return False
|
allowed_routes=litellm_proxy_roles.admin_allowed_routes,
|
||||||
|
)
|
||||||
|
return is_allowed
|
||||||
|
|
||||||
|
elif user_role == "team":
|
||||||
|
if litellm_proxy_roles.team_allowed_routes is None:
|
||||||
|
"""
|
||||||
|
By default allow a team to call openai + info routes
|
||||||
|
"""
|
||||||
|
is_allowed = _allowed_routes_check(
|
||||||
|
user_route=user_route, allowed_routes=["openai_routes", "info_routes"]
|
||||||
|
)
|
||||||
|
return is_allowed
|
||||||
|
elif litellm_proxy_roles.team_allowed_routes is not None:
|
||||||
|
is_allowed = _allowed_routes_check(
|
||||||
|
user_route=user_route,
|
||||||
|
allowed_routes=litellm_proxy_roles.team_allowed_routes,
|
||||||
|
)
|
||||||
|
return is_allowed
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_actual_routes(allowed_routes: list) -> list:
|
||||||
|
actual_routes: list = []
|
||||||
|
for route_name in allowed_routes:
|
||||||
|
try:
|
||||||
|
route_value = LiteLLMRoutes[route_name].value
|
||||||
|
actual_routes = actual_routes + route_value
|
||||||
|
except KeyError:
|
||||||
|
actual_routes.append(route_name)
|
||||||
|
return actual_routes
|
||||||
|
|
||||||
|
|
||||||
async def get_end_user_object(
|
async def get_end_user_object(
|
||||||
end_user_id: Optional[str],
|
end_user_id: Optional[str],
|
||||||
prisma_client: Optional[PrismaClient],
|
prisma_client: Optional[PrismaClient],
|
||||||
|
@ -135,3 +171,75 @@ async def get_end_user_object(
|
||||||
return LiteLLM_EndUserTable(**response.dict())
|
return LiteLLM_EndUserTable(**response.dict())
|
||||||
except Exception as e: # if end-user not in db
|
except Exception as e: # if end-user not in db
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
|
||||||
|
"""
|
||||||
|
- Check if user id in proxy User Table
|
||||||
|
- if valid, return LiteLLM_UserTable object with defined limits
|
||||||
|
- if not, then raise an error
|
||||||
|
"""
|
||||||
|
if self.prisma_client is None:
|
||||||
|
raise Exception(
|
||||||
|
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if in cache
|
||||||
|
cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id)
|
||||||
|
if cached_user_obj is not None:
|
||||||
|
if isinstance(cached_user_obj, dict):
|
||||||
|
return LiteLLM_UserTable(**cached_user_obj)
|
||||||
|
elif isinstance(cached_user_obj, LiteLLM_UserTable):
|
||||||
|
return cached_user_obj
|
||||||
|
# else, check db
|
||||||
|
try:
|
||||||
|
response = await self.prisma_client.db.litellm_usertable.find_unique(
|
||||||
|
where={"user_id": user_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response is None:
|
||||||
|
raise Exception
|
||||||
|
|
||||||
|
return LiteLLM_UserTable(**response.dict())
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(
|
||||||
|
f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_team_object(
|
||||||
|
team_id: str,
|
||||||
|
prisma_client: Optional[PrismaClient],
|
||||||
|
user_api_key_cache: DualCache,
|
||||||
|
) -> LiteLLM_TeamTable:
|
||||||
|
"""
|
||||||
|
- Check if team id in proxy Team Table
|
||||||
|
- if valid, return LiteLLM_TeamTable object with defined limits
|
||||||
|
- if not, then raise an error
|
||||||
|
"""
|
||||||
|
if prisma_client is None:
|
||||||
|
raise Exception(
|
||||||
|
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if in cache
|
||||||
|
cached_team_obj = user_api_key_cache.async_get_cache(key=team_id)
|
||||||
|
if cached_team_obj is not None:
|
||||||
|
if isinstance(cached_team_obj, dict):
|
||||||
|
return LiteLLM_TeamTable(**cached_team_obj)
|
||||||
|
elif isinstance(cached_team_obj, LiteLLM_TeamTable):
|
||||||
|
return cached_team_obj
|
||||||
|
# else, check db
|
||||||
|
try:
|
||||||
|
response = await prisma_client.db.litellm_teamtable.find_unique(
|
||||||
|
where={"team_id": team_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response is None:
|
||||||
|
raise Exception
|
||||||
|
|
||||||
|
return LiteLLM_TeamTable(**response.dict())
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(
|
||||||
|
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
||||||
|
)
|
||||||
|
|
|
@ -81,57 +81,27 @@ class JWTHandler:
|
||||||
return len(parts) == 3
|
return len(parts) == 3
|
||||||
|
|
||||||
def is_admin(self, scopes: list) -> bool:
|
def is_admin(self, scopes: list) -> bool:
|
||||||
if self.litellm_proxy_roles.proxy_admin in scopes:
|
if self.litellm_proxy_roles.admin_jwt_scope in scopes:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_user_id(self, token: dict, default_value: str) -> str:
|
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
|
||||||
try:
|
try:
|
||||||
user_id = token["sub"]
|
if self.litellm_proxy_roles.team_id_jwt_field is not None:
|
||||||
|
user_id = token[self.litellm_proxy_roles.team_id_jwt_field]
|
||||||
|
else:
|
||||||
|
user_id = None
|
||||||
except KeyError:
|
except KeyError:
|
||||||
user_id = default_value
|
user_id = default_value
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
team_id = token["client_id"]
|
team_id = token[self.litellm_proxy_roles.team_id_jwt_field]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
team_id = default_value
|
team_id = default_value
|
||||||
return team_id
|
return team_id
|
||||||
|
|
||||||
async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
|
|
||||||
"""
|
|
||||||
- Check if user id in proxy User Table
|
|
||||||
- if valid, return LiteLLM_UserTable object with defined limits
|
|
||||||
- if not, then raise an error
|
|
||||||
"""
|
|
||||||
if self.prisma_client is None:
|
|
||||||
raise Exception(
|
|
||||||
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
|
|
||||||
)
|
|
||||||
|
|
||||||
# check if in cache
|
|
||||||
cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id)
|
|
||||||
if cached_user_obj is not None:
|
|
||||||
if isinstance(cached_user_obj, dict):
|
|
||||||
return LiteLLM_UserTable(**cached_user_obj)
|
|
||||||
elif isinstance(cached_user_obj, LiteLLM_UserTable):
|
|
||||||
return cached_user_obj
|
|
||||||
# else, check db
|
|
||||||
try:
|
|
||||||
response = await self.prisma_client.db.litellm_usertable.find_unique(
|
|
||||||
where={"user_id": user_id}
|
|
||||||
)
|
|
||||||
|
|
||||||
if response is None:
|
|
||||||
raise Exception
|
|
||||||
|
|
||||||
return LiteLLM_UserTable(**response.dict())
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception(
|
|
||||||
f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_scopes(self, token: dict) -> list:
|
def get_scopes(self, token: dict) -> list:
|
||||||
try:
|
try:
|
||||||
if isinstance(token["scope"], str):
|
if isinstance(token["scope"], str):
|
||||||
|
|
|
@ -113,7 +113,10 @@ from litellm.proxy.hooks.prompt_injection_detection import (
|
||||||
from litellm.proxy.auth.auth_checks import (
|
from litellm.proxy.auth.auth_checks import (
|
||||||
common_checks,
|
common_checks,
|
||||||
get_end_user_object,
|
get_end_user_object,
|
||||||
|
get_team_object,
|
||||||
|
get_user_object,
|
||||||
allowed_routes_check,
|
allowed_routes_check,
|
||||||
|
get_actual_routes,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -369,71 +372,93 @@ async def user_api_key_auth(
|
||||||
scopes = jwt_handler.get_scopes(token=valid_token)
|
scopes = jwt_handler.get_scopes(token=valid_token)
|
||||||
# check if admin
|
# check if admin
|
||||||
is_admin = jwt_handler.is_admin(scopes=scopes)
|
is_admin = jwt_handler.is_admin(scopes=scopes)
|
||||||
# get user id
|
# if admin return
|
||||||
user_id = jwt_handler.get_user_id(
|
if is_admin:
|
||||||
token=valid_token, default_value=litellm_proxy_admin_name
|
# check allowed admin routes
|
||||||
|
is_allowed = allowed_routes_check(
|
||||||
|
user_role="proxy_admin",
|
||||||
|
user_route=route,
|
||||||
|
litellm_proxy_roles=jwt_handler.litellm_proxy_roles,
|
||||||
|
)
|
||||||
|
if is_allowed:
|
||||||
|
return UserAPIKeyAuth()
|
||||||
|
else:
|
||||||
|
allowed_routes = (
|
||||||
|
jwt_handler.litellm_proxy_roles.admin_allowed_routes
|
||||||
|
)
|
||||||
|
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||||
|
raise Exception(
|
||||||
|
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||||
|
)
|
||||||
|
# get team id
|
||||||
|
team_id = jwt_handler.get_team_id(token=valid_token, default_value=None)
|
||||||
|
|
||||||
|
if team_id is None:
|
||||||
|
raise Exception(
|
||||||
|
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_proxy_roles.team_id_jwt_field}'"
|
||||||
|
)
|
||||||
|
# check allowed team routes
|
||||||
|
is_allowed = allowed_routes_check(
|
||||||
|
user_role="team",
|
||||||
|
user_route=route,
|
||||||
|
litellm_proxy_roles=jwt_handler.litellm_proxy_roles,
|
||||||
|
)
|
||||||
|
if is_allowed == False:
|
||||||
|
allowed_routes = jwt_handler.litellm_proxy_roles.team_allowed_routes
|
||||||
|
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||||
|
raise Exception(
|
||||||
|
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if team in db
|
||||||
|
team_object = await get_team_object(
|
||||||
|
team_id=team_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
end_user_object = None
|
# common checks
|
||||||
|
# allow request
|
||||||
|
|
||||||
# get the request body
|
# get the request body
|
||||||
request_data = await _read_request_body(request=request)
|
request_data = await _read_request_body(request=request)
|
||||||
# get user obj from cache/db -> run for admin too. Ensures, admin client id in db.
|
|
||||||
user_object = await jwt_handler.get_user_object(user_id=user_id)
|
end_user_object = None
|
||||||
if (
|
end_user_id = jwt_handler.get_end_user_id(
|
||||||
request_data.get("user", None)
|
token=valid_token, default_value=None
|
||||||
and request_data["user"] != user_object.user_id
|
)
|
||||||
):
|
if end_user_id is not None:
|
||||||
# get the end-user object
|
# get the end-user object
|
||||||
end_user_object = await get_end_user_object(
|
end_user_object = await get_end_user_object(
|
||||||
end_user_id=request_data["user"],
|
end_user_id=end_user_id,
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
user_api_key_cache=user_api_key_cache,
|
user_api_key_cache=user_api_key_cache,
|
||||||
)
|
)
|
||||||
# save the end-user object to cache
|
# save the end-user object to cache
|
||||||
await user_api_key_cache.async_set_cache(
|
await user_api_key_cache.async_set_cache(
|
||||||
key=request_data["user"], value=end_user_object
|
key=end_user_id, value=end_user_object
|
||||||
)
|
)
|
||||||
|
|
||||||
# run through common checks
|
# run through common checks
|
||||||
_ = common_checks(
|
_ = common_checks(
|
||||||
request_body=request_data,
|
request_body=request_data,
|
||||||
user_object=user_object,
|
team_object=team_object,
|
||||||
end_user_object=end_user_object,
|
end_user_object=end_user_object,
|
||||||
)
|
)
|
||||||
# save user object in cache
|
# save user object in cache
|
||||||
await user_api_key_cache.async_set_cache(
|
await user_api_key_cache.async_set_cache(
|
||||||
key=user_object.user_id, value=user_object
|
key=team_object.team_id, value=team_object
|
||||||
|
)
|
||||||
|
|
||||||
|
# return UserAPIKeyAuth object
|
||||||
|
return UserAPIKeyAuth(
|
||||||
|
api_key=None,
|
||||||
|
team_id=team_object.team_id,
|
||||||
|
tpm_limit=team_object.tpm_limit,
|
||||||
|
rpm_limit=team_object.rpm_limit,
|
||||||
|
models=team_object.models,
|
||||||
|
user_role="app_owner",
|
||||||
)
|
)
|
||||||
# if admin return
|
|
||||||
if is_admin:
|
|
||||||
return UserAPIKeyAuth(
|
|
||||||
api_key=api_key,
|
|
||||||
user_role="proxy_admin",
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
is_allowed = allowed_routes_check(
|
|
||||||
user_role="app_owner",
|
|
||||||
route=route,
|
|
||||||
allowed_routes=general_settings.get("allowed_routes", None),
|
|
||||||
)
|
|
||||||
if is_allowed:
|
|
||||||
# return UserAPIKeyAuth object
|
|
||||||
return UserAPIKeyAuth(
|
|
||||||
api_key=None,
|
|
||||||
user_id=user_object.user_id,
|
|
||||||
tpm_limit=user_object.tpm_limit,
|
|
||||||
rpm_limit=user_object.rpm_limit,
|
|
||||||
models=user_object.models,
|
|
||||||
user_role="app_owner",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail={
|
|
||||||
"error": f"User={user_object.user_id} not allowed to access this route={route}."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
#### ELSE ####
|
#### ELSE ####
|
||||||
if master_key is None:
|
if master_key is None:
|
||||||
if isinstance(api_key, str):
|
if isinstance(api_key, str):
|
||||||
|
@ -2709,12 +2734,16 @@ async def startup_event():
|
||||||
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
||||||
|
|
||||||
## JWT AUTH ##
|
## JWT AUTH ##
|
||||||
|
if general_settings.get("litellm_proxy_roles", None) is not None:
|
||||||
|
litellm_proxy_roles = LiteLLMProxyRoles(
|
||||||
|
**general_settings["litellm_proxy_roles"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
litellm_proxy_roles = LiteLLMProxyRoles()
|
||||||
jwt_handler.update_environment(
|
jwt_handler.update_environment(
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
user_api_key_cache=user_api_key_cache,
|
user_api_key_cache=user_api_key_cache,
|
||||||
litellm_proxy_roles=LiteLLMProxyRoles(
|
litellm_proxy_roles=litellm_proxy_roles,
|
||||||
**general_settings.get("litellm_proxy_roles", {})
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_background_health_checks:
|
if use_background_health_checks:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue