forked from phoenix/litellm-mirror
Merge pull request #2704 from BerriAI/litellm_jwt_auth_improvements_3
fix(handle_jwt.py): enable team-based jwt-auth access
This commit is contained in:
commit
0ab708e6f1
7 changed files with 358 additions and 145 deletions
|
@ -113,7 +113,10 @@ from litellm.proxy.hooks.prompt_injection_detection import (
|
|||
from litellm.proxy.auth.auth_checks import (
|
||||
common_checks,
|
||||
get_end_user_object,
|
||||
get_team_object,
|
||||
get_user_object,
|
||||
allowed_routes_check,
|
||||
get_actual_routes,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -359,71 +362,99 @@ async def user_api_key_auth(
|
|||
scopes = jwt_handler.get_scopes(token=valid_token)
|
||||
# check if admin
|
||||
is_admin = jwt_handler.is_admin(scopes=scopes)
|
||||
# get user id
|
||||
user_id = jwt_handler.get_user_id(
|
||||
token=valid_token, default_value=litellm_proxy_admin_name
|
||||
# if admin return
|
||||
if is_admin:
|
||||
# check allowed admin routes
|
||||
is_allowed = allowed_routes_check(
|
||||
user_role="proxy_admin",
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
)
|
||||
if is_allowed:
|
||||
return UserAPIKeyAuth()
|
||||
else:
|
||||
allowed_routes = (
|
||||
jwt_handler.litellm_jwtauth.admin_allowed_routes
|
||||
)
|
||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||
raise Exception(
|
||||
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||
)
|
||||
# check if team in scopes
|
||||
is_team = jwt_handler.is_team(scopes=scopes)
|
||||
if is_team == False:
|
||||
raise Exception(
|
||||
f"Missing both Admin and Team scopes from token. Either is required. Admin Scope={jwt_handler.litellm_jwtauth.admin_jwt_scope}, Team Scope={jwt_handler.litellm_jwtauth.team_jwt_scope}"
|
||||
)
|
||||
# get team id
|
||||
team_id = jwt_handler.get_team_id(token=valid_token, default_value=None)
|
||||
|
||||
if team_id is None:
|
||||
raise Exception(
|
||||
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
||||
)
|
||||
# check allowed team routes
|
||||
is_allowed = allowed_routes_check(
|
||||
user_role="team",
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
)
|
||||
if is_allowed == False:
|
||||
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes
|
||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||
raise Exception(
|
||||
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||
)
|
||||
|
||||
# check if team in db
|
||||
team_object = await get_team_object(
|
||||
team_id=team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
|
||||
end_user_object = None
|
||||
# common checks
|
||||
# allow request
|
||||
|
||||
# get the request body
|
||||
request_data = await _read_request_body(request=request)
|
||||
# get user obj from cache/db -> run for admin too. Ensures, admin client id in db.
|
||||
user_object = await jwt_handler.get_user_object(user_id=user_id)
|
||||
if (
|
||||
request_data.get("user", None)
|
||||
and request_data["user"] != user_object.user_id
|
||||
):
|
||||
|
||||
end_user_object = None
|
||||
end_user_id = jwt_handler.get_end_user_id(
|
||||
token=valid_token, default_value=None
|
||||
)
|
||||
if end_user_id is not None:
|
||||
# get the end-user object
|
||||
end_user_object = await get_end_user_object(
|
||||
end_user_id=request_data["user"],
|
||||
end_user_id=end_user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
# save the end-user object to cache
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=request_data["user"], value=end_user_object
|
||||
key=end_user_id, value=end_user_object
|
||||
)
|
||||
|
||||
# run through common checks
|
||||
_ = common_checks(
|
||||
request_body=request_data,
|
||||
user_object=user_object,
|
||||
team_object=team_object,
|
||||
end_user_object=end_user_object,
|
||||
)
|
||||
# save user object in cache
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=user_object.user_id, value=user_object
|
||||
key=team_object.team_id, value=team_object
|
||||
)
|
||||
|
||||
# return UserAPIKeyAuth object
|
||||
return UserAPIKeyAuth(
|
||||
api_key=None,
|
||||
team_id=team_object.team_id,
|
||||
tpm_limit=team_object.tpm_limit,
|
||||
rpm_limit=team_object.rpm_limit,
|
||||
models=team_object.models,
|
||||
user_role="app_owner",
|
||||
)
|
||||
# if admin return
|
||||
if is_admin:
|
||||
return UserAPIKeyAuth(
|
||||
api_key=api_key,
|
||||
user_role="proxy_admin",
|
||||
user_id=user_id,
|
||||
)
|
||||
else:
|
||||
is_allowed = allowed_routes_check(
|
||||
user_role="app_owner",
|
||||
route=route,
|
||||
allowed_routes=general_settings.get("allowed_routes", None),
|
||||
)
|
||||
if is_allowed:
|
||||
# return UserAPIKeyAuth object
|
||||
return UserAPIKeyAuth(
|
||||
api_key=None,
|
||||
user_id=user_object.user_id,
|
||||
tpm_limit=user_object.tpm_limit,
|
||||
rpm_limit=user_object.rpm_limit,
|
||||
models=user_object.models,
|
||||
user_role="app_owner",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": f"User={user_object.user_id} not allowed to access this route={route}."
|
||||
},
|
||||
)
|
||||
#### ELSE ####
|
||||
if master_key is None:
|
||||
if isinstance(api_key, str):
|
||||
|
@ -2698,12 +2729,14 @@ async def startup_event():
|
|||
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
||||
|
||||
## JWT AUTH ##
|
||||
if general_settings.get("litellm_jwtauth", None) is not None:
|
||||
litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"])
|
||||
else:
|
||||
litellm_jwtauth = LiteLLM_JWTAuth()
|
||||
jwt_handler.update_environment(
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
litellm_proxy_roles=LiteLLMProxyRoles(
|
||||
**general_settings.get("litellm_proxy_roles", {})
|
||||
),
|
||||
litellm_jwtauth=litellm_jwtauth,
|
||||
)
|
||||
|
||||
if use_background_health_checks:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue