fix(handle_jwt.py): enable team-based jwt-auth access

Move auth to check on ‘client_id’ not ‘sub
This commit is contained in:
Krrish Dholakia 2024-03-26 12:25:38 -07:00
parent b4d0a95cff
commit 7d38c62717
4 changed files with 327 additions and 132 deletions

View file

@ -81,57 +81,27 @@ class JWTHandler:
return len(parts) == 3
def is_admin(self, scopes: list) -> bool:
if self.litellm_proxy_roles.proxy_admin in scopes:
if self.litellm_proxy_roles.admin_jwt_scope in scopes:
return True
return False
def get_user_id(self, token: dict, default_value: str) -> str:
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
try:
user_id = token["sub"]
if self.litellm_proxy_roles.team_id_jwt_field is not None:
user_id = token[self.litellm_proxy_roles.team_id_jwt_field]
else:
user_id = None
except KeyError:
user_id = default_value
return user_id
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try:
team_id = token["client_id"]
team_id = token[self.litellm_proxy_roles.team_id_jwt_field]
except KeyError:
team_id = default_value
return team_id
async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
"""
- Check if user id in proxy User Table
- if valid, return LiteLLM_UserTable object with defined limits
- if not, then raise an error
"""
if self.prisma_client is None:
raise Exception(
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
)
# check if in cache
cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id)
if cached_user_obj is not None:
if isinstance(cached_user_obj, dict):
return LiteLLM_UserTable(**cached_user_obj)
elif isinstance(cached_user_obj, LiteLLM_UserTable):
return cached_user_obj
# else, check db
try:
response = await self.prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id}
)
if response is None:
raise Exception
return LiteLLM_UserTable(**response.dict())
except Exception as e:
raise Exception(
f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
)
def get_scopes(self, token: dict) -> list:
try:
if isinstance(token["scope"], str):