fix(handle_jwt.py): enable team-based jwt-auth access

Move auth to check on 'client_id' not 'sub'
This commit is contained in:
Krrish Dholakia 2024-03-26 12:25:38 -07:00
parent b4d0a95cff
commit 7d38c62717
4 changed files with 327 additions and 132 deletions

View file

@ -1,4 +1,5 @@
from pydantic import BaseModel, Extra, Field, root_validator, Json, validator
from dataclasses import fields
import enum
from typing import Optional, List, Union, Dict, Literal, Any
from datetime import datetime
@ -37,9 +38,96 @@ class LiteLLMBase(BaseModel):
protected_namespaces = ()
class LiteLLMRoutes(enum.Enum):
    """Named groups of proxy routes.

    Each member's value is the list of route paths in that group. The member
    names ("openai_routes", "info_routes", "management_routes") are the
    strings used in allowed-routes configuration and route checks.
    """

    openai_routes: List = [ # chat completions
        "/openai/deployments/{model}/chat/completions",
        "/chat/completions",
        "/v1/chat/completions",
        # completions
        "/openai/deployments/{model}/completions",
        "/completions",
        "/v1/completions",
        # embeddings
        "/openai/deployments/{model}/embeddings",
        "/embeddings",
        "/v1/embeddings",
        # image generation
        "/images/generations",
        "/v1/images/generations",
        # audio transcription
        "/audio/transcriptions",
        "/v1/audio/transcriptions",
        # moderations
        "/moderations",
        "/v1/moderations",
        # models
        "/models",
        "/v1/models",
    ]
    # read-only informational endpoints
    info_routes: List = ["/key/info", "/team/info", "/user/info", "/model/info"]
    # admin-level create/update/delete endpoints
    management_routes: List = [ # key
        "/key/generate",
        "/key/update",
        "/key/delete",
        "/key/info",
        # user
        "/user/new",
        "/user/update",
        "/user/delete",
        "/user/info",
        # team
        "/team/new",
        "/team/update",
        "/team/delete",
        "/team/info",
        # model
        "/model/new",
        "/model/update",
        "/model/delete",
        "/model/info",
    ]
class LiteLLMProxyRoles(LiteLLMBase):
    """
    A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.

    Attributes:
    - admin_jwt_scope: The JWT scope required for proxy admin roles.
    - admin_allowed_routes: list of allowed routes for proxy admin roles.
    - team_jwt_scope: The JWT scope required for proxy team roles.
    - team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
    - team_allowed_routes: list of allowed routes for proxy team roles.
    - end_user_id_jwt_field: Default - `sub`. The field in the JWT token that stores the end-user ID. Turn this off by setting to `None`. Enables end-user cost tracking.

    See `auth_checks.py` for the specific routes
    """

    # NOTE(review): pre-existing role-name fields kept so external references
    # to them still resolve — confirm whether any caller still reads these.
    proxy_admin: str = "litellm_proxy_admin"
    proxy_user: str = "litellm_user"

    admin_jwt_scope: str = "litellm_proxy_admin"
    admin_allowed_routes: List[
        Literal["openai_routes", "info_routes", "management_routes"]
    ] = ["management_routes"]
    team_jwt_scope: str = "litellm_team"
    team_id_jwt_field: str = "client_id"
    team_allowed_routes: List[
        Literal["openai_routes", "info_routes", "management_routes"]
    ] = ["openai_routes", "info_routes"]
    end_user_id_jwt_field: Optional[str] = "sub"

    def __init__(self, **kwargs: Any) -> None:
        """Reject unknown config keys with a readable error before pydantic init.

        Raises:
            ValueError: if any kwarg is not a declared attribute of this class.
        """
        # get the attribute names for this Pydantic model
        allowed_keys = self.__annotations__.keys()
        invalid_keys = set(kwargs.keys()) - allowed_keys
        if invalid_keys:
            raise ValueError(
                f"Invalid arguments provided: {', '.join(invalid_keys)}. Allowed arguments are: {', '.join(allowed_keys)}."
            )
        super().__init__(**kwargs)
class LiteLLMPromptInjectionParams(LiteLLMBase):

View file

@ -8,15 +8,23 @@ Run checks for:
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable
from litellm.proxy._types import (
LiteLLM_UserTable,
LiteLLM_EndUserTable,
LiteLLMProxyRoles,
LiteLLM_TeamTable,
LiteLLMRoutes,
)
from typing import Optional, Literal
from litellm.proxy.utils import PrismaClient
from litellm.caching import DualCache
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
def common_checks(
request_body: dict,
user_object: LiteLLM_UserTable,
team_object: LiteLLM_TeamTable,
end_user_object: Optional[LiteLLM_EndUserTable],
) -> bool:
"""
@ -30,19 +38,20 @@ def common_checks(
# 1. If user can call model
if (
_model is not None
and len(user_object.models) > 0
and _model not in user_object.models
and len(team_object.models) > 0
and _model not in team_object.models
):
raise Exception(
f"User={user_object.user_id} not allowed to call model={_model}. Allowed user models = {user_object.models}"
f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
)
# 2. If user is in budget
# 2. If team is in budget
if (
user_object.max_budget is not None
and user_object.spend > user_object.max_budget
team_object.max_budget is not None
and team_object.spend is not None
and team_object.spend > team_object.max_budget
):
raise Exception(
f"User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_object.max_budget}"
f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
)
# 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
if end_user_object is not None and end_user_object.litellm_budget_table is not None:
@ -54,52 +63,79 @@ def common_checks(
return True
def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
    """Return True if `user_route` is permitted by `allowed_routes`.

    Each entry in `allowed_routes` is either the name of a route group
    defined on `LiteLLMRoutes` (e.g. "openai_routes"), in which case
    membership in that group's route list is checked, or a literal route
    string that must equal `user_route` exactly.
    """
    for allowed_route in allowed_routes:
        if allowed_route in LiteLLMRoutes.__members__:
            # group name -> expand to the group's route list
            if user_route in LiteLLMRoutes[allowed_route].value:
                return True
        elif allowed_route == user_route:
            # literal route match
            return True
    return False
def allowed_routes_check(
    user_role: Literal["proxy_admin", "team"],
    user_route: str,
    litellm_proxy_roles: LiteLLMProxyRoles,
) -> bool:
    """Check whether a JWT-authenticated caller may access `user_route`.

    Args:
        user_role: role extracted from the token - "proxy_admin" or "team".
        user_route: the route being requested.
        litellm_proxy_roles: configured role/route permissions.

    Role defaults apply when the corresponding allowed-routes list is not
    configured:
    - proxy_admin -> ["management_routes"]
    - team        -> ["openai_routes", "info_routes"]

    Returns False for any unrecognized role.
    """
    if user_role == "proxy_admin":
        allowed_routes = litellm_proxy_roles.admin_allowed_routes
        if allowed_routes is None:
            # by default an admin may only call management routes
            allowed_routes = ["management_routes"]
        return _allowed_routes_check(
            user_route=user_route, allowed_routes=allowed_routes
        )
    elif user_role == "team":
        allowed_routes = litellm_proxy_roles.team_allowed_routes
        if allowed_routes is None:
            # by default a team may call openai + info routes
            allowed_routes = ["openai_routes", "info_routes"]
        return _allowed_routes_check(
            user_route=user_route, allowed_routes=allowed_routes
        )
    return False
def get_actual_routes(allowed_routes: list) -> list:
    """Expand route-group names into the concrete routes they contain.

    Entries naming a `LiteLLMRoutes` group (e.g. "openai_routes") are
    replaced by that group's route list; anything else is kept as a
    literal route string. Used to build readable error messages.
    """
    actual_routes: list = []
    for route_name in allowed_routes:
        try:
            # group name -> append every route in the group
            actual_routes.extend(LiteLLMRoutes[route_name].value)
        except KeyError:
            # literal route string
            actual_routes.append(route_name)
    return actual_routes
async def get_end_user_object(
end_user_id: Optional[str],
prisma_client: Optional[PrismaClient],
@ -135,3 +171,75 @@ async def get_end_user_object(
return LiteLLM_EndUserTable(**response.dict())
except Exception as e: # if end-user not in db
return None
async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
    """
    - Check if user id in proxy User Table
    - if valid, return LiteLLM_UserTable object with defined limits
    - if not, then raise an error
    """
    if self.prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache
    # NOTE(review): async_get_cache is async (per DualCache's `async_`
    # naming - confirm); without `await` the bare coroutine object is
    # truthy yet matches neither isinstance branch, so the cache never hits.
    cached_user_obj = await self.user_api_key_cache.async_get_cache(key=user_id)
    if cached_user_obj is not None:
        if isinstance(cached_user_obj, dict):
            return LiteLLM_UserTable(**cached_user_obj)
        elif isinstance(cached_user_obj, LiteLLM_UserTable):
            return cached_user_obj

    # else, check db
    try:
        response = await self.prisma_client.db.litellm_usertable.find_unique(
            where={"user_id": user_id}
        )
        if response is None:
            raise Exception
        return LiteLLM_UserTable(**response.dict())
    except Exception:
        raise Exception(
            f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
        )
async def get_team_object(
    team_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
) -> LiteLLM_TeamTable:
    """
    - Check if team id in proxy Team Table
    - if valid, return LiteLLM_TeamTable object with defined limits
    - if not, then raise an error
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache
    # NOTE(review): async_get_cache is async (per DualCache's `async_`
    # naming - confirm); without `await` the bare coroutine object is
    # truthy yet matches neither isinstance branch, so the cache never hits.
    cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
    if cached_team_obj is not None:
        if isinstance(cached_team_obj, dict):
            return LiteLLM_TeamTable(**cached_team_obj)
        elif isinstance(cached_team_obj, LiteLLM_TeamTable):
            return cached_team_obj

    # else, check db
    try:
        response = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": team_id}
        )
        if response is None:
            raise Exception
        return LiteLLM_TeamTable(**response.dict())
    except Exception:
        raise Exception(
            f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
        )

View file

@ -81,57 +81,27 @@ class JWTHandler:
return len(parts) == 3
def is_admin(self, scopes: list) -> bool:
if self.litellm_proxy_roles.proxy_admin in scopes:
if self.litellm_proxy_roles.admin_jwt_scope in scopes:
return True
return False
def get_user_id(self, token: dict, default_value: str) -> str:
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
try:
user_id = token["sub"]
if self.litellm_proxy_roles.team_id_jwt_field is not None:
user_id = token[self.litellm_proxy_roles.team_id_jwt_field]
else:
user_id = None
except KeyError:
user_id = default_value
return user_id
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try:
team_id = token["client_id"]
team_id = token[self.litellm_proxy_roles.team_id_jwt_field]
except KeyError:
team_id = default_value
return team_id
async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
"""
- Check if user id in proxy User Table
- if valid, return LiteLLM_UserTable object with defined limits
- if not, then raise an error
"""
if self.prisma_client is None:
raise Exception(
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
)
# check if in cache
cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id)
if cached_user_obj is not None:
if isinstance(cached_user_obj, dict):
return LiteLLM_UserTable(**cached_user_obj)
elif isinstance(cached_user_obj, LiteLLM_UserTable):
return cached_user_obj
# else, check db
try:
response = await self.prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id}
)
if response is None:
raise Exception
return LiteLLM_UserTable(**response.dict())
except Exception as e:
raise Exception(
f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
)
def get_scopes(self, token: dict) -> list:
try:
if isinstance(token["scope"], str):

View file

@ -113,7 +113,10 @@ from litellm.proxy.hooks.prompt_injection_detection import (
from litellm.proxy.auth.auth_checks import (
common_checks,
get_end_user_object,
get_team_object,
get_user_object,
allowed_routes_check,
get_actual_routes,
)
try:
@ -369,71 +372,93 @@ async def user_api_key_auth(
scopes = jwt_handler.get_scopes(token=valid_token)
# check if admin
is_admin = jwt_handler.is_admin(scopes=scopes)
# get user id
user_id = jwt_handler.get_user_id(
token=valid_token, default_value=litellm_proxy_admin_name
# if admin return
if is_admin:
# check allowed admin routes
is_allowed = allowed_routes_check(
user_role="proxy_admin",
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_proxy_roles,
)
if is_allowed:
return UserAPIKeyAuth()
else:
allowed_routes = (
jwt_handler.litellm_proxy_roles.admin_allowed_routes
)
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
# get team id
team_id = jwt_handler.get_team_id(token=valid_token, default_value=None)
if team_id is None:
raise Exception(
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_proxy_roles.team_id_jwt_field}'"
)
# check allowed team routes
is_allowed = allowed_routes_check(
user_role="team",
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_proxy_roles,
)
if is_allowed == False:
allowed_routes = jwt_handler.litellm_proxy_roles.team_allowed_routes
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
end_user_object = None
# check if team in db
team_object = await get_team_object(
team_id=team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
# common checks
# allow request
# get the request body
request_data = await _read_request_body(request=request)
# get user obj from cache/db -> run for admin too. Ensures, admin client id in db.
user_object = await jwt_handler.get_user_object(user_id=user_id)
if (
request_data.get("user", None)
and request_data["user"] != user_object.user_id
):
end_user_object = None
end_user_id = jwt_handler.get_end_user_id(
token=valid_token, default_value=None
)
if end_user_id is not None:
# get the end-user object
end_user_object = await get_end_user_object(
end_user_id=request_data["user"],
end_user_id=end_user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
# save the end-user object to cache
await user_api_key_cache.async_set_cache(
key=request_data["user"], value=end_user_object
key=end_user_id, value=end_user_object
)
# run through common checks
_ = common_checks(
request_body=request_data,
user_object=user_object,
team_object=team_object,
end_user_object=end_user_object,
)
# save user object in cache
await user_api_key_cache.async_set_cache(
key=user_object.user_id, value=user_object
key=team_object.team_id, value=team_object
)
# if admin return
if is_admin:
return UserAPIKeyAuth(
api_key=api_key,
user_role="proxy_admin",
user_id=user_id,
)
else:
is_allowed = allowed_routes_check(
user_role="app_owner",
route=route,
allowed_routes=general_settings.get("allowed_routes", None),
)
if is_allowed:
# return UserAPIKeyAuth object
return UserAPIKeyAuth(
api_key=None,
user_id=user_object.user_id,
tpm_limit=user_object.tpm_limit,
rpm_limit=user_object.rpm_limit,
models=user_object.models,
team_id=team_object.team_id,
tpm_limit=team_object.tpm_limit,
rpm_limit=team_object.rpm_limit,
models=team_object.models,
user_role="app_owner",
)
else:
raise HTTPException(
status_code=401,
detail={
"error": f"User={user_object.user_id} not allowed to access this route={route}."
},
)
#### ELSE ####
if master_key is None:
if isinstance(api_key, str):
@ -2709,12 +2734,16 @@ async def startup_event():
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
## JWT AUTH ##
if general_settings.get("litellm_proxy_roles", None) is not None:
litellm_proxy_roles = LiteLLMProxyRoles(
**general_settings["litellm_proxy_roles"]
)
else:
litellm_proxy_roles = LiteLLMProxyRoles()
jwt_handler.update_environment(
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
litellm_proxy_roles=LiteLLMProxyRoles(
**general_settings.get("litellm_proxy_roles", {})
),
litellm_proxy_roles=litellm_proxy_roles,
)
if use_background_health_checks: