Merge pull request #2704 from BerriAI/litellm_jwt_auth_improvements_3

fix(handle_jwt.py): enable team-based jwt-auth access
This commit is contained in:
Krish Dholakia 2024-03-26 16:06:56 -07:00 committed by GitHub
commit 0ab708e6f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 358 additions and 145 deletions

View file

@ -113,7 +113,10 @@ from litellm.proxy.hooks.prompt_injection_detection import (
from litellm.proxy.auth.auth_checks import (
common_checks,
get_end_user_object,
get_team_object,
get_user_object,
allowed_routes_check,
get_actual_routes,
)
try:
@ -359,71 +362,99 @@ async def user_api_key_auth(
scopes = jwt_handler.get_scopes(token=valid_token)
# check if admin
is_admin = jwt_handler.is_admin(scopes=scopes)
# get user id
user_id = jwt_handler.get_user_id(
token=valid_token, default_value=litellm_proxy_admin_name
# if admin return
if is_admin:
# check allowed admin routes
is_allowed = allowed_routes_check(
user_role="proxy_admin",
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed:
return UserAPIKeyAuth()
else:
allowed_routes = (
jwt_handler.litellm_jwtauth.admin_allowed_routes
)
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
# check if team in scopes
is_team = jwt_handler.is_team(scopes=scopes)
if is_team == False:
raise Exception(
f"Missing both Admin and Team scopes from token. Either is required. Admin Scope={jwt_handler.litellm_jwtauth.admin_jwt_scope}, Team Scope={jwt_handler.litellm_jwtauth.team_jwt_scope}"
)
# get team id
team_id = jwt_handler.get_team_id(token=valid_token, default_value=None)
if team_id is None:
raise Exception(
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
)
# check allowed team routes
is_allowed = allowed_routes_check(
user_role="team",
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed == False:
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
# check if team in db
team_object = await get_team_object(
team_id=team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
end_user_object = None
# common checks
# allow request
# get the request body
request_data = await _read_request_body(request=request)
# get user obj from cache/db -> run for admin too. Ensures, admin client id in db.
user_object = await jwt_handler.get_user_object(user_id=user_id)
if (
request_data.get("user", None)
and request_data["user"] != user_object.user_id
):
end_user_object = None
end_user_id = jwt_handler.get_end_user_id(
token=valid_token, default_value=None
)
if end_user_id is not None:
# get the end-user object
end_user_object = await get_end_user_object(
end_user_id=request_data["user"],
end_user_id=end_user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
# save the end-user object to cache
await user_api_key_cache.async_set_cache(
key=request_data["user"], value=end_user_object
key=end_user_id, value=end_user_object
)
# run through common checks
_ = common_checks(
request_body=request_data,
user_object=user_object,
team_object=team_object,
end_user_object=end_user_object,
)
# save user object in cache
await user_api_key_cache.async_set_cache(
key=user_object.user_id, value=user_object
key=team_object.team_id, value=team_object
)
# return UserAPIKeyAuth object
return UserAPIKeyAuth(
api_key=None,
team_id=team_object.team_id,
tpm_limit=team_object.tpm_limit,
rpm_limit=team_object.rpm_limit,
models=team_object.models,
user_role="app_owner",
)
# if admin return
if is_admin:
return UserAPIKeyAuth(
api_key=api_key,
user_role="proxy_admin",
user_id=user_id,
)
else:
is_allowed = allowed_routes_check(
user_role="app_owner",
route=route,
allowed_routes=general_settings.get("allowed_routes", None),
)
if is_allowed:
# return UserAPIKeyAuth object
return UserAPIKeyAuth(
api_key=None,
user_id=user_object.user_id,
tpm_limit=user_object.tpm_limit,
rpm_limit=user_object.rpm_limit,
models=user_object.models,
user_role="app_owner",
)
else:
raise HTTPException(
status_code=401,
detail={
"error": f"User={user_object.user_id} not allowed to access this route={route}."
},
)
#### ELSE ####
if master_key is None:
if isinstance(api_key, str):
@ -2698,12 +2729,14 @@ async def startup_event():
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
## JWT AUTH ##
if general_settings.get("litellm_jwtauth", None) is not None:
litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"])
else:
litellm_jwtauth = LiteLLM_JWTAuth()
jwt_handler.update_environment(
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
litellm_proxy_roles=LiteLLMProxyRoles(
**general_settings.get("litellm_proxy_roles", {})
),
litellm_jwtauth=litellm_jwtauth,
)
if use_background_health_checks: