feat(handle_jwt.py): support authenticating admins into the proxy via jwt's

This commit is contained in:
Krrish Dholakia 2024-03-19 15:00:27 -07:00
parent 4913ad41db
commit 302bab6f1f
3 changed files with 181 additions and 1 deletions

View file

@ -106,6 +106,7 @@ from litellm.proxy._types import *
from litellm.caching import DualCache
from litellm.proxy.health_check import perform_health_check
from litellm._logging import verbose_router_logger, verbose_proxy_logger
from litellm.proxy.auth.handle_jwt import JWTHandler
try:
from litellm._version import version
@ -282,6 +283,7 @@ proxy_budget_rescheduler_max_time = 605
proxy_batch_write_at = 60 # in seconds
litellm_master_key_hash = None
disable_spend_logs = False
jwt_handler = JWTHandler()
### INITIALIZE GLOBAL LOGGING OBJECT ###
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
### REDIS QUEUE ###
@ -334,6 +336,45 @@ async def user_api_key_auth(
return UserAPIKeyAuth.model_validate(response)
### LITELLM-DEFINED AUTH FUNCTION ###
#### IF JWT ####
"""
LiteLLM supports using JWTs.
Enable this in proxy config, by setting
```
general_settings:
enable_jwt_auth: true
```
"""
if general_settings.get("enable_jwt_auth", False) == True:
is_jwt = jwt_handler.is_jwt(token=api_key)
verbose_proxy_logger.debug(f"is_jwt: {is_jwt}")
if is_jwt:
# check if valid token
valid_token = await jwt_handler.auth_jwt(token=api_key)
# get scopes
scopes = jwt_handler.get_scopes(token=valid_token)
# check if admin
is_admin = jwt_handler.is_admin(scopes=scopes)
# get user id
user_id = jwt_handler.get_user_id(
token=valid_token, default_value=litellm_proxy_admin_name
)
# if admin return
if is_admin:
_user_api_key_obj = UserAPIKeyAuth(
api_key=api_key,
user_role="proxy_admin",
user_id=user_id,
)
user_api_key_cache.set_cache(
key=hash_token(api_key), value=_user_api_key_obj
)
return _user_api_key_obj
else:
raise Exception("Invalid key error!")
#### ELSE ####
if master_key is None:
if isinstance(api_key, str):
return UserAPIKeyAuth(api_key=api_key)
@ -7531,7 +7572,27 @@ async def get_routes():
return {"routes": routes}
## TEST ENDPOINT
#### TEST ENDPOINTS ####
@router.get("/token/generate", dependencies=[Depends(user_api_key_auth)])
async def token_generate():
"""
Test endpoint. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
"""
# Initialize AuthJWTSSO with your OpenID Provider configuration
from fastapi_sso import AuthJWTSSO
auth_jwt_sso = AuthJWTSSO(
issuer=os.getenv("OPENID_BASE_URL"),
client_id=os.getenv("OPENID_CLIENT_ID"),
client_secret=os.getenv("OPENID_CLIENT_SECRET"),
scopes=["litellm_proxy_admin"],
)
token = auth_jwt_sso.create_access_token()
return {"token": token}
# @router.post("/update_database", dependencies=[Depends(user_api_key_auth)])
# async def update_database_endpoint(
# user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),