fix(handle_jwt.py): enable team-based jwt-auth access

Move auth to check on 'client_id' not 'sub'
Krrish Dholakia 2024-03-26 12:25:38 -07:00
parent b4d0a95cff
commit 7d38c62717
4 changed files with 327 additions and 132 deletions
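
For illustration, a decoded JWT under the new scheme carries the team ID in `client_id` (configurable via `team_id_jwt_field`, defined below) and the end-user ID in `sub`. A hypothetical payload:

# hypothetical decoded JWT payload (claim names follow this commit's defaults)
token = {
    "client_id": "team-llm-app",  # team ID, now the primary auth subject
    "sub": "user-1234",           # end-user ID, optional, for cost tracking
    "scope": "litellm_team",      # "litellm_proxy_admin" grants admin access
}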

litellm/proxy/_types.py

@@ -1,4 +1,5 @@
 from pydantic import BaseModel, Extra, Field, root_validator, Json, validator
+from dataclasses import fields
 import enum
 from typing import Optional, List, Union, Dict, Literal, Any
 from datetime import datetime
@@ -37,9 +38,96 @@ class LiteLLMBase(BaseModel):
         protected_namespaces = ()
 
 
+class LiteLLMRoutes(enum.Enum):
+    openai_routes: List = [
+        # chat completions
+        "/openai/deployments/{model}/chat/completions",
+        "/chat/completions",
+        "/v1/chat/completions",
+        # completions
+        "/openai/deployments/{model}/completions",
+        "/completions",
+        "/v1/completions",
+        # embeddings
+        "/openai/deployments/{model}/embeddings",
+        "/embeddings",
+        "/v1/embeddings",
+        # image generation
+        "/images/generations",
+        "/v1/images/generations",
+        # audio transcription
+        "/audio/transcriptions",
+        "/v1/audio/transcriptions",
+        # moderations
+        "/moderations",
+        "/v1/moderations",
+        # models
+        "/models",
+        "/v1/models",
+    ]
+
+    info_routes: List = ["/key/info", "/team/info", "/user/info", "/model/info"]
+
+    management_routes: List = [
+        # key
+        "/key/generate",
+        "/key/update",
+        "/key/delete",
+        "/key/info",
+        # user
+        "/user/new",
+        "/user/update",
+        "/user/delete",
+        "/user/info",
+        # team
+        "/team/new",
+        "/team/update",
+        "/team/delete",
+        "/team/info",
+        # model
+        "/model/new",
+        "/model/update",
+        "/model/delete",
+        "/model/info",
+    ]
+
+
 class LiteLLMProxyRoles(LiteLLMBase):
-    proxy_admin: str = "litellm_proxy_admin"
-    proxy_user: str = "litellm_user"
+    """
+    A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
+
+    Attributes:
+    - admin_jwt_scope: The JWT scope required for proxy admin roles.
+    - admin_allowed_routes: List of allowed routes for proxy admin roles.
+    - team_jwt_scope: The JWT scope required for proxy team roles.
+    - team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
+    - team_allowed_routes: List of allowed routes for proxy team roles.
+    - end_user_id_jwt_field: The field in the JWT token that stores the end-user ID. Default - `sub`. Enables end-user cost tracking; set to `None` to turn it off.
+
+    See `auth_checks.py` for the specific routes.
+    """
+
+    admin_jwt_scope: str = "litellm_proxy_admin"
+    admin_allowed_routes: List[
+        Literal["openai_routes", "info_routes", "management_routes"]
+    ] = ["management_routes"]
+    team_jwt_scope: str = "litellm_team"
+    team_id_jwt_field: str = "client_id"
+    team_allowed_routes: List[
+        Literal["openai_routes", "info_routes", "management_routes"]
+    ] = ["openai_routes", "info_routes"]
+    end_user_id_jwt_field: Optional[str] = "sub"
+
+    def __init__(self, **kwargs: Any) -> None:
+        # get the attribute names for this pydantic model
+        allowed_keys = self.__annotations__.keys()
+        # reject any argument that is not a declared field
+        invalid_keys = set(kwargs.keys()) - allowed_keys
+        if invalid_keys:
+            raise ValueError(
+                f"Invalid arguments provided: {', '.join(invalid_keys)}. Allowed arguments are: {', '.join(allowed_keys)}."
+            )
+        super().__init__(**kwargs)
+
+
 class LiteLLMPromptInjectionParams(LiteLLMBase):
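
For illustration, a minimal sketch of how `LiteLLMProxyRoles` behaves (import path per this commit; the `azp` override and the misspelled key are hypothetical):

from litellm.proxy._types import LiteLLMProxyRoles

# defaults: admins get management routes, teams get openai + info routes
roles = LiteLLMProxyRoles()
assert roles.team_id_jwt_field == "client_id"
assert roles.end_user_id_jwt_field == "sub"

# override which JWT claim carries the team id (hypothetical claim name)
roles = LiteLLMProxyRoles(team_id_jwt_field="azp")

# unknown keys are rejected by the custom __init__
try:
    LiteLLMProxyRoles(team_allowed_route=["openai_routes"])  # misspelled key
except ValueError as e:
    print(e)  # Invalid arguments provided: team_allowed_route. ...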

litellm/proxy/auth/auth_checks.py

@@ -8,15 +8,23 @@ Run checks for:
 2. If user is in budget
 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
 """
-from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable
+from litellm.proxy._types import (
+    LiteLLM_UserTable,
+    LiteLLM_EndUserTable,
+    LiteLLMProxyRoles,
+    LiteLLM_TeamTable,
+    LiteLLMRoutes,
+)
 from typing import Optional, Literal
 from litellm.proxy.utils import PrismaClient
 from litellm.caching import DualCache
 
+all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
+
 
 def common_checks(
     request_body: dict,
-    user_object: LiteLLM_UserTable,
+    team_object: LiteLLM_TeamTable,
     end_user_object: Optional[LiteLLM_EndUserTable],
 ) -> bool:
     """
@@ -30,19 +38,20 @@ def common_checks(
     # 1. If user can call model
     if (
         _model is not None
-        and len(user_object.models) > 0
-        and _model not in user_object.models
+        and len(team_object.models) > 0
+        and _model not in team_object.models
     ):
         raise Exception(
-            f"User={user_object.user_id} not allowed to call model={_model}. Allowed user models = {user_object.models}"
+            f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
         )
-    # 2. If user is in budget
+    # 2. If team is in budget
     if (
-        user_object.max_budget is not None
-        and user_object.spend > user_object.max_budget
+        team_object.max_budget is not None
+        and team_object.spend is not None
+        and team_object.spend > team_object.max_budget
     ):
         raise Exception(
-            f"User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_object.max_budget}"
+            f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
         )
     # 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
     if end_user_object is not None and end_user_object.litellm_budget_table is not None:
@@ -54,52 +63,79 @@ def common_checks(
     return True
 
 
+def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
+    for allowed_route in allowed_routes:
+        if (
+            allowed_route == LiteLLMRoutes.openai_routes.name
+            and user_route in LiteLLMRoutes.openai_routes.value
+        ):
+            return True
+        elif (
+            allowed_route == LiteLLMRoutes.info_routes.name
+            and user_route in LiteLLMRoutes.info_routes.value
+        ):
+            return True
+        elif (
+            allowed_route == LiteLLMRoutes.management_routes.name
+            and user_route in LiteLLMRoutes.management_routes.value
+        ):
+            return True
+        elif allowed_route == user_route:
+            return True
+    return False
+
+
 def allowed_routes_check(
-    user_role: Literal["proxy_admin", "app_owner"],
-    route: str,
-    allowed_routes: Optional[list] = None,
+    user_role: Literal["proxy_admin", "team"],
+    user_route: str,
+    litellm_proxy_roles: LiteLLMProxyRoles,
 ) -> bool:
     """
     Check if user -> not admin - allowed to access these routes
     """
-    openai_routes = [
-        # chat completions
-        "/openai/deployments/{model}/chat/completions",
-        "/chat/completions",
-        "/v1/chat/completions",
-        # completions
-        # embeddings
-        "/openai/deployments/{model}/embeddings",
-        "/embeddings",
-        "/v1/embeddings",
-        # image generation
-        "/images/generations",
-        "/v1/images/generations",
-        # audio transcription
-        "/audio/transcriptions",
-        "/v1/audio/transcriptions",
-        # moderations
-        "/moderations",
-        "/v1/moderations",
-        # models
-        "/models",
-        "/v1/models",
-    ]
-    info_routes = ["/key/info", "/team/info", "/user/info", "/model/info"]
-    default_routes = openai_routes + info_routes
 
     if user_role == "proxy_admin":
-        return True
-    elif user_role == "app_owner":
-        if allowed_routes is None:
-            if route in default_routes:  # check default routes
-                return True
-        elif route in allowed_routes:
-            return True
-        else:
-            return False
+        if litellm_proxy_roles.admin_allowed_routes is None:
+            is_allowed = _allowed_routes_check(
+                user_route=user_route, allowed_routes=["management_routes"]
+            )
+            return is_allowed
+        elif litellm_proxy_roles.admin_allowed_routes is not None:
+            is_allowed = _allowed_routes_check(
+                user_route=user_route,
+                allowed_routes=litellm_proxy_roles.admin_allowed_routes,
+            )
+            return is_allowed
+    elif user_role == "team":
+        if litellm_proxy_roles.team_allowed_routes is None:
+            """
+            By default allow a team to call openai + info routes
+            """
+            is_allowed = _allowed_routes_check(
+                user_route=user_route, allowed_routes=["openai_routes", "info_routes"]
+            )
+            return is_allowed
+        elif litellm_proxy_roles.team_allowed_routes is not None:
+            is_allowed = _allowed_routes_check(
+                user_route=user_route,
+                allowed_routes=litellm_proxy_roles.team_allowed_routes,
+            )
+            return is_allowed
+    return False
+
+
+def get_actual_routes(allowed_routes: list) -> list:
+    actual_routes: list = []
+    for route_name in allowed_routes:
+        try:
+            route_value = LiteLLMRoutes[route_name].value
+            actual_routes = actual_routes + route_value
+        except KeyError:
+            actual_routes.append(route_name)
+    return actual_routes
 
 
 async def get_end_user_object(
     end_user_id: Optional[str],
     prisma_client: Optional[PrismaClient],
@@ -135,3 +171,75 @@ async def get_end_user_object(
         return LiteLLM_EndUserTable(**response.dict())
     except Exception as e:  # if end-user not in db
         return None
 
+
+async def get_user_object(
+    user_id: str,
+    prisma_client: Optional[PrismaClient],
+    user_api_key_cache: DualCache,
+) -> LiteLLM_UserTable:
+    """
+    - Check if user id in proxy User Table
+    - if valid, return LiteLLM_UserTable object with defined limits
+    - if not, then raise an error
+    """
+    if prisma_client is None:
+        raise Exception(
+            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
+        )
+
+    # check if in cache
+    cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
+    if cached_user_obj is not None:
+        if isinstance(cached_user_obj, dict):
+            return LiteLLM_UserTable(**cached_user_obj)
+        elif isinstance(cached_user_obj, LiteLLM_UserTable):
+            return cached_user_obj
+    # else, check db
+    try:
+        response = await prisma_client.db.litellm_usertable.find_unique(
+            where={"user_id": user_id}
+        )
+        if response is None:
+            raise Exception
+        return LiteLLM_UserTable(**response.dict())
+    except Exception as e:
+        raise Exception(
+            f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
+        )
+
+
+async def get_team_object(
+    team_id: str,
+    prisma_client: Optional[PrismaClient],
+    user_api_key_cache: DualCache,
+) -> LiteLLM_TeamTable:
+    """
+    - Check if team id in proxy Team Table
+    - if valid, return LiteLLM_TeamTable object with defined limits
+    - if not, then raise an error
+    """
+    if prisma_client is None:
+        raise Exception(
+            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
+        )
+
+    # check if in cache
+    cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
+    if cached_team_obj is not None:
+        if isinstance(cached_team_obj, dict):
+            return LiteLLM_TeamTable(**cached_team_obj)
+        elif isinstance(cached_team_obj, LiteLLM_TeamTable):
+            return cached_team_obj
+    # else, check db
+    try:
+        response = await prisma_client.db.litellm_teamtable.find_unique(
+            where={"team_id": team_id}
+        )
+        if response is None:
+            raise Exception
+        return LiteLLM_TeamTable(**response.dict())
+    except Exception as e:
+        raise Exception(
+            f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
+        )
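
For illustration, a minimal sketch of how the new route checks resolve route-group names (assumes the defaults above; import paths per this commit):

from litellm.proxy._types import LiteLLMProxyRoles
from litellm.proxy.auth.auth_checks import allowed_routes_check, get_actual_routes

roles = LiteLLMProxyRoles()  # team_allowed_routes = ["openai_routes", "info_routes"]

# a team token may call an openai route ...
assert allowed_routes_check(
    user_role="team", user_route="/chat/completions", litellm_proxy_roles=roles
)
# ... but not a management route
assert not allowed_routes_check(
    user_role="team", user_route="/key/generate", litellm_proxy_roles=roles
)

# expand group names into concrete routes, e.g. for error messages
print(get_actual_routes(allowed_routes=["info_routes"]))
# ['/key/info', '/team/info', '/user/info', '/model/info']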

litellm/proxy/auth/handle_jwt.py

@@ -81,57 +81,27 @@ class JWTHandler:
         return len(parts) == 3
 
     def is_admin(self, scopes: list) -> bool:
-        if self.litellm_proxy_roles.proxy_admin in scopes:
+        if self.litellm_proxy_roles.admin_jwt_scope in scopes:
             return True
         return False
 
-    def get_user_id(self, token: dict, default_value: str) -> str:
+    def get_end_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
         try:
-            user_id = token["sub"]
+            if self.litellm_proxy_roles.end_user_id_jwt_field is not None:
+                user_id = token[self.litellm_proxy_roles.end_user_id_jwt_field]
+            else:
+                user_id = None
         except KeyError:
             user_id = default_value
         return user_id
 
     def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
         try:
-            team_id = token["client_id"]
+            team_id = token[self.litellm_proxy_roles.team_id_jwt_field]
         except KeyError:
             team_id = default_value
         return team_id
 
-    async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
-        """
-        - Check if user id in proxy User Table
-        - if valid, return LiteLLM_UserTable object with defined limits
-        - if not, then raise an error
-        """
-        if self.prisma_client is None:
-            raise Exception(
-                "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
-            )
-
-        # check if in cache
-        cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id)
-        if cached_user_obj is not None:
-            if isinstance(cached_user_obj, dict):
-                return LiteLLM_UserTable(**cached_user_obj)
-            elif isinstance(cached_user_obj, LiteLLM_UserTable):
-                return cached_user_obj
-        # else, check db
-        try:
-            response = await self.prisma_client.db.litellm_usertable.find_unique(
-                where={"user_id": user_id}
-            )
-            if response is None:
-                raise Exception
-            return LiteLLM_UserTable(**response.dict())
-        except Exception as e:
-            raise Exception(
-                f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
-            )
-
     def get_scopes(self, token: dict) -> list:
         try:
             if isinstance(token["scope"], str):
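
For illustration, a sketch of the handler's claim extraction. Constructing `JWTHandler` with no arguments is an assumption; `update_environment` is called as in `proxy_server.py` below, and the token payload is hypothetical:

from litellm.caching import DualCache
from litellm.proxy._types import LiteLLMProxyRoles
from litellm.proxy.auth.handle_jwt import JWTHandler

jwt_handler = JWTHandler()
jwt_handler.update_environment(
    prisma_client=None,  # no DB needed for pure claim extraction
    user_api_key_cache=DualCache(),
    litellm_proxy_roles=LiteLLMProxyRoles(),
)

# hypothetical decoded token from the identity provider
token = {"client_id": "team-llm-app", "sub": "user-1234", "scope": "litellm_team"}

assert jwt_handler.get_team_id(token=token, default_value=None) == "team-llm-app"
assert jwt_handler.get_end_user_id(token=token, default_value=None) == "user-1234"
assert jwt_handler.is_admin(scopes=jwt_handler.get_scopes(token=token)) is False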

litellm/proxy/proxy_server.py

@@ -113,7 +113,10 @@ from litellm.proxy.hooks.prompt_injection_detection import (
 from litellm.proxy.auth.auth_checks import (
     common_checks,
     get_end_user_object,
+    get_team_object,
+    get_user_object,
     allowed_routes_check,
+    get_actual_routes,
 )
 
 try:
@@ -369,71 +372,93 @@ async def user_api_key_auth(
                 scopes = jwt_handler.get_scopes(token=valid_token)
                 # check if admin
                 is_admin = jwt_handler.is_admin(scopes=scopes)
-                # get user id
-                user_id = jwt_handler.get_user_id(
-                    token=valid_token, default_value=litellm_proxy_admin_name
-                )
+                # if admin return
+                if is_admin:
+                    # check allowed admin routes
+                    is_allowed = allowed_routes_check(
+                        user_role="proxy_admin",
+                        user_route=route,
+                        litellm_proxy_roles=jwt_handler.litellm_proxy_roles,
+                    )
+                    if is_allowed:
+                        return UserAPIKeyAuth()
+                    else:
+                        allowed_routes = (
+                            jwt_handler.litellm_proxy_roles.admin_allowed_routes
+                        )
+                        actual_routes = get_actual_routes(allowed_routes=allowed_routes)
+                        raise Exception(
+                            f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
+                        )
+                # get team id
+                team_id = jwt_handler.get_team_id(token=valid_token, default_value=None)
+                if team_id is None:
+                    raise Exception(
+                        f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_proxy_roles.team_id_jwt_field}'"
+                    )
+                # check allowed team routes
+                is_allowed = allowed_routes_check(
+                    user_role="team",
+                    user_route=route,
+                    litellm_proxy_roles=jwt_handler.litellm_proxy_roles,
+                )
+                if is_allowed == False:
+                    allowed_routes = jwt_handler.litellm_proxy_roles.team_allowed_routes
+                    actual_routes = get_actual_routes(allowed_routes=allowed_routes)
+                    raise Exception(
+                        f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
+                    )
-                end_user_object = None
+                # check if team in db
+                team_object = await get_team_object(
+                    team_id=team_id,
+                    prisma_client=prisma_client,
+                    user_api_key_cache=user_api_key_cache,
+                )
+
+                # common checks
+                # allow request
+
                 # get the request body
                 request_data = await _read_request_body(request=request)
-                # get user obj from cache/db -> run for admin too. Ensures, admin client id in db.
-                user_object = await jwt_handler.get_user_object(user_id=user_id)
-                if (
-                    request_data.get("user", None)
-                    and request_data["user"] != user_object.user_id
-                ):
+
+                end_user_object = None
+                end_user_id = jwt_handler.get_end_user_id(
+                    token=valid_token, default_value=None
+                )
+                if end_user_id is not None:
                     # get the end-user object
                     end_user_object = await get_end_user_object(
-                        end_user_id=request_data["user"],
+                        end_user_id=end_user_id,
                         prisma_client=prisma_client,
                         user_api_key_cache=user_api_key_cache,
                     )
                     # save the end-user object to cache
                     await user_api_key_cache.async_set_cache(
-                        key=request_data["user"], value=end_user_object
+                        key=end_user_id, value=end_user_object
                     )
+
                 # run through common checks
                 _ = common_checks(
                     request_body=request_data,
-                    user_object=user_object,
+                    team_object=team_object,
                     end_user_object=end_user_object,
                 )
-                # save user object in cache
+                # save team object in cache
                 await user_api_key_cache.async_set_cache(
-                    key=user_object.user_id, value=user_object
+                    key=team_object.team_id, value=team_object
                 )
-                # if admin return
-                if is_admin:
-                    return UserAPIKeyAuth(
-                        api_key=api_key,
-                        user_role="proxy_admin",
-                        user_id=user_id,
-                    )
-                else:
-                    is_allowed = allowed_routes_check(
-                        user_role="app_owner",
-                        route=route,
-                        allowed_routes=general_settings.get("allowed_routes", None),
-                    )
-                    if is_allowed:
-                        # return UserAPIKeyAuth object
-                        return UserAPIKeyAuth(
-                            api_key=None,
-                            user_id=user_object.user_id,
-                            tpm_limit=user_object.tpm_limit,
-                            rpm_limit=user_object.rpm_limit,
-                            models=user_object.models,
-                            user_role="app_owner",
-                        )
-                    else:
-                        raise HTTPException(
-                            status_code=401,
-                            detail={
-                                "error": f"User={user_object.user_id} not allowed to access this route={route}."
-                            },
-                        )
+                # return UserAPIKeyAuth object
+                return UserAPIKeyAuth(
+                    api_key=None,
+                    team_id=team_object.team_id,
+                    tpm_limit=team_object.tpm_limit,
+                    rpm_limit=team_object.rpm_limit,
+                    models=team_object.models,
+                    user_role="app_owner",
+                )
         #### ELSE ####
         if master_key is None:
             if isinstance(api_key, str):
@@ -2709,12 +2734,16 @@ async def startup_event():
     proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
 
     ## JWT AUTH ##
+    if general_settings.get("litellm_proxy_roles", None) is not None:
+        litellm_proxy_roles = LiteLLMProxyRoles(
+            **general_settings["litellm_proxy_roles"]
+        )
+    else:
+        litellm_proxy_roles = LiteLLMProxyRoles()
     jwt_handler.update_environment(
         prisma_client=prisma_client,
         user_api_key_cache=user_api_key_cache,
-        litellm_proxy_roles=LiteLLMProxyRoles(
-            **general_settings.get("litellm_proxy_roles", {})
-        ),
+        litellm_proxy_roles=litellm_proxy_roles,
     )
 
     if use_background_health_checks:
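
For reference, a minimal sketch of configuring this at startup (the `general_settings` values are hypothetical; keys follow the `LiteLLMProxyRoles` fields above):

# hypothetical general_settings, e.g. parsed from the proxy's config.yaml
general_settings = {
    "litellm_proxy_roles": {
        "admin_jwt_scope": "litellm_proxy_admin",
        "team_id_jwt_field": "client_id",
        "team_allowed_routes": ["openai_routes", "info_routes"],
    }
}

litellm_proxy_roles = LiteLLMProxyRoles(**general_settings["litellm_proxy_roles"])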