feat(handle_jwt.py): support authenticating admins into the proxy via jwt's

This commit is contained in:
Krrish Dholakia 2024-03-19 15:00:27 -07:00
parent 4913ad41db
commit 302bab6f1f
3 changed files with 181 additions and 1 deletions

View file

@ -14,6 +14,11 @@ def hash_token(token: str):
return hashed_token
class LiteLLMProxyRoles(enum.Enum):
PROXY_ADMIN = "litellm_proxy_admin"
USER = "litellm_user"
class LiteLLMBase(BaseModel):
"""
Implements default functions, all pydantic objects should have.

View file

@ -0,0 +1,114 @@
"""
Supports using JWT's for authenticating into the proxy.
Currently only supports admin.
JWT token must have 'litellm_proxy_admin' in scope.
"""
import httpx
import jwt
import json
from jwt.algorithms import RSAAlgorithm
import os
from litellm.proxy._types import LiteLLMProxyRoles
from typing import Optional
class HTTPHandler:
def __init__(self):
self.client = httpx.AsyncClient()
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.client.aclose()
async def get(
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
):
response = await self.client.get(url, params=params, headers=headers)
return response
async def post(
self,
url: str,
data: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
):
response = await self.client.post(
url, data=data, params=params, headers=headers
)
return response
class JWTHandler:
def __init__(self) -> None:
self.http_handler = HTTPHandler()
def is_jwt(self, token: str):
parts = token.split(".")
return len(parts) == 3
def is_admin(self, scopes: list) -> bool:
if LiteLLMProxyRoles.PROXY_ADMIN.value in scopes:
return True
return False
def get_user_id(self, token: dict, default_value: str) -> str:
try:
user_id = token["sub"]
except KeyError:
user_id = default_value
return user_id
def get_scopes(self, token: dict) -> list:
try:
# Assuming the scopes are stored in 'scope' claim and are space-separated
scopes = token["scope"].split()
except KeyError:
scopes = []
return scopes
async def auth_jwt(self, token: str) -> dict:
keys_url = os.getenv("OPENID_PUBLIC_KEY_URL")
async with self.http_handler as http:
response = await http.get(keys_url)
keys = response.json()["keys"]
header = jwt.get_unverified_header(token)
kid = header["kid"]
for key in keys:
if key["kid"] == kid:
jwk = {
"kty": key["kty"],
"kid": key["kid"],
"n": key["n"],
"e": key["e"],
}
public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))
try:
# decode the token using the public key
payload = jwt.decode(
token,
public_key, # type: ignore
algorithms=["RS256"],
audience="account",
issuer=os.getenv("JWT_ISSUER"),
)
return payload
except jwt.ExpiredSignatureError:
# the token is expired, do something to refresh it
raise Exception("Token Expired")
except Exception as e:
raise Exception(f"Validation fails: {str(e)}")
raise jwt.InvalidKeyError

View file

@ -106,6 +106,7 @@ from litellm.proxy._types import *
from litellm.caching import DualCache
from litellm.proxy.health_check import perform_health_check
from litellm._logging import verbose_router_logger, verbose_proxy_logger
from litellm.proxy.auth.handle_jwt import JWTHandler
try:
from litellm._version import version
@ -282,6 +283,7 @@ proxy_budget_rescheduler_max_time = 605
proxy_batch_write_at = 60 # in seconds
litellm_master_key_hash = None
disable_spend_logs = False
jwt_handler = JWTHandler()
### INITIALIZE GLOBAL LOGGING OBJECT ###
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
### REDIS QUEUE ###
@ -334,6 +336,45 @@ async def user_api_key_auth(
return UserAPIKeyAuth.model_validate(response)
### LITELLM-DEFINED AUTH FUNCTION ###
#### IF JWT ####
"""
LiteLLM supports using JWTs.
Enable this in proxy config, by setting
```
general_settings:
enable_jwt_auth: true
```
"""
if general_settings.get("enable_jwt_auth", False) == True:
is_jwt = jwt_handler.is_jwt(token=api_key)
verbose_proxy_logger.debug(f"is_jwt: {is_jwt}")
if is_jwt:
# check if valid token
valid_token = await jwt_handler.auth_jwt(token=api_key)
# get scopes
scopes = jwt_handler.get_scopes(token=valid_token)
# check if admin
is_admin = jwt_handler.is_admin(scopes=scopes)
# get user id
user_id = jwt_handler.get_user_id(
token=valid_token, default_value=litellm_proxy_admin_name
)
# if admin return
if is_admin:
_user_api_key_obj = UserAPIKeyAuth(
api_key=api_key,
user_role="proxy_admin",
user_id=user_id,
)
user_api_key_cache.set_cache(
key=hash_token(api_key), value=_user_api_key_obj
)
return _user_api_key_obj
else:
raise Exception("Invalid key error!")
#### ELSE ####
if master_key is None:
if isinstance(api_key, str):
return UserAPIKeyAuth(api_key=api_key)
@ -7531,7 +7572,27 @@ async def get_routes():
return {"routes": routes}
## TEST ENDPOINT
#### TEST ENDPOINTS ####
@router.get("/token/generate", dependencies=[Depends(user_api_key_auth)])
async def token_generate():
"""
Test endpoint. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
"""
# Initialize AuthJWTSSO with your OpenID Provider configuration
from fastapi_sso import AuthJWTSSO
auth_jwt_sso = AuthJWTSSO(
issuer=os.getenv("OPENID_BASE_URL"),
client_id=os.getenv("OPENID_CLIENT_ID"),
client_secret=os.getenv("OPENID_CLIENT_SECRET"),
scopes=["litellm_proxy_admin"],
)
token = auth_jwt_sso.create_access_token()
return {"token": token}
# @router.post("/update_database", dependencies=[Depends(user_api_key_auth)])
# async def update_database_endpoint(
# user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),