forked from phoenix/litellm-mirror
feat(handle_jwt.py): support authenticating admins into the proxy via jwt's
This commit is contained in:
parent
4913ad41db
commit
302bab6f1f
3 changed files with 181 additions and 1 deletions
|
@ -14,6 +14,11 @@ def hash_token(token: str):
|
|||
return hashed_token
|
||||
|
||||
|
||||
class LiteLLMProxyRoles(enum.Enum):
|
||||
PROXY_ADMIN = "litellm_proxy_admin"
|
||||
USER = "litellm_user"
|
||||
|
||||
|
||||
class LiteLLMBase(BaseModel):
|
||||
"""
|
||||
Implements default functions, all pydantic objects should have.
|
||||
|
|
114
litellm/proxy/auth/handle_jwt.py
Normal file
114
litellm/proxy/auth/handle_jwt.py
Normal file
|
@ -0,0 +1,114 @@
|
|||
"""
|
||||
Supports using JWT's for authenticating into the proxy.
|
||||
|
||||
Currently only supports admin.
|
||||
|
||||
JWT token must have 'litellm_proxy_admin' in scope.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
import json
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
import os
|
||||
from litellm.proxy._types import LiteLLMProxyRoles
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class HTTPHandler:
|
||||
def __init__(self):
|
||||
self.client = httpx.AsyncClient()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.client.aclose()
|
||||
|
||||
async def get(
|
||||
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
|
||||
):
|
||||
response = await self.client.get(url, params=params, headers=headers)
|
||||
return response
|
||||
|
||||
async def post(
|
||||
self,
|
||||
url: str,
|
||||
data: Optional[dict] = None,
|
||||
params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
):
|
||||
response = await self.client.post(
|
||||
url, data=data, params=params, headers=headers
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class JWTHandler:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.http_handler = HTTPHandler()
|
||||
|
||||
def is_jwt(self, token: str):
|
||||
parts = token.split(".")
|
||||
return len(parts) == 3
|
||||
|
||||
def is_admin(self, scopes: list) -> bool:
|
||||
if LiteLLMProxyRoles.PROXY_ADMIN.value in scopes:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_user_id(self, token: dict, default_value: str) -> str:
|
||||
try:
|
||||
user_id = token["sub"]
|
||||
except KeyError:
|
||||
user_id = default_value
|
||||
return user_id
|
||||
|
||||
def get_scopes(self, token: dict) -> list:
|
||||
try:
|
||||
# Assuming the scopes are stored in 'scope' claim and are space-separated
|
||||
scopes = token["scope"].split()
|
||||
except KeyError:
|
||||
scopes = []
|
||||
return scopes
|
||||
|
||||
async def auth_jwt(self, token: str) -> dict:
|
||||
keys_url = os.getenv("OPENID_PUBLIC_KEY_URL")
|
||||
|
||||
async with self.http_handler as http:
|
||||
response = await http.get(keys_url)
|
||||
|
||||
keys = response.json()["keys"]
|
||||
|
||||
header = jwt.get_unverified_header(token)
|
||||
kid = header["kid"]
|
||||
|
||||
for key in keys:
|
||||
if key["kid"] == kid:
|
||||
jwk = {
|
||||
"kty": key["kty"],
|
||||
"kid": key["kid"],
|
||||
"n": key["n"],
|
||||
"e": key["e"],
|
||||
}
|
||||
public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))
|
||||
|
||||
try:
|
||||
# decode the token using the public key
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
public_key, # type: ignore
|
||||
algorithms=["RS256"],
|
||||
audience="account",
|
||||
issuer=os.getenv("JWT_ISSUER"),
|
||||
)
|
||||
return payload
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
# the token is expired, do something to refresh it
|
||||
raise Exception("Token Expired")
|
||||
except Exception as e:
|
||||
raise Exception(f"Validation fails: {str(e)}")
|
||||
|
||||
raise jwt.InvalidKeyError
|
|
@ -106,6 +106,7 @@ from litellm.proxy._types import *
|
|||
from litellm.caching import DualCache
|
||||
from litellm.proxy.health_check import perform_health_check
|
||||
from litellm._logging import verbose_router_logger, verbose_proxy_logger
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
|
||||
try:
|
||||
from litellm._version import version
|
||||
|
@ -282,6 +283,7 @@ proxy_budget_rescheduler_max_time = 605
|
|||
proxy_batch_write_at = 60 # in seconds
|
||||
litellm_master_key_hash = None
|
||||
disable_spend_logs = False
|
||||
jwt_handler = JWTHandler()
|
||||
### INITIALIZE GLOBAL LOGGING OBJECT ###
|
||||
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
|
||||
### REDIS QUEUE ###
|
||||
|
@ -334,6 +336,45 @@ async def user_api_key_auth(
|
|||
return UserAPIKeyAuth.model_validate(response)
|
||||
|
||||
### LITELLM-DEFINED AUTH FUNCTION ###
|
||||
#### IF JWT ####
|
||||
"""
|
||||
LiteLLM supports using JWTs.
|
||||
|
||||
Enable this in proxy config, by setting
|
||||
```
|
||||
general_settings:
|
||||
enable_jwt_auth: true
|
||||
```
|
||||
"""
|
||||
if general_settings.get("enable_jwt_auth", False) == True:
|
||||
is_jwt = jwt_handler.is_jwt(token=api_key)
|
||||
verbose_proxy_logger.debug(f"is_jwt: {is_jwt}")
|
||||
if is_jwt:
|
||||
# check if valid token
|
||||
valid_token = await jwt_handler.auth_jwt(token=api_key)
|
||||
# get scopes
|
||||
scopes = jwt_handler.get_scopes(token=valid_token)
|
||||
# check if admin
|
||||
is_admin = jwt_handler.is_admin(scopes=scopes)
|
||||
# get user id
|
||||
user_id = jwt_handler.get_user_id(
|
||||
token=valid_token, default_value=litellm_proxy_admin_name
|
||||
)
|
||||
# if admin return
|
||||
if is_admin:
|
||||
_user_api_key_obj = UserAPIKeyAuth(
|
||||
api_key=api_key,
|
||||
user_role="proxy_admin",
|
||||
user_id=user_id,
|
||||
)
|
||||
user_api_key_cache.set_cache(
|
||||
key=hash_token(api_key), value=_user_api_key_obj
|
||||
)
|
||||
|
||||
return _user_api_key_obj
|
||||
else:
|
||||
raise Exception("Invalid key error!")
|
||||
#### ELSE ####
|
||||
if master_key is None:
|
||||
if isinstance(api_key, str):
|
||||
return UserAPIKeyAuth(api_key=api_key)
|
||||
|
@ -7531,7 +7572,27 @@ async def get_routes():
|
|||
return {"routes": routes}
|
||||
|
||||
|
||||
## TEST ENDPOINT
|
||||
#### TEST ENDPOINTS ####
|
||||
@router.get("/token/generate", dependencies=[Depends(user_api_key_auth)])
|
||||
async def token_generate():
|
||||
"""
|
||||
Test endpoint. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
|
||||
"""
|
||||
# Initialize AuthJWTSSO with your OpenID Provider configuration
|
||||
from fastapi_sso import AuthJWTSSO
|
||||
|
||||
auth_jwt_sso = AuthJWTSSO(
|
||||
issuer=os.getenv("OPENID_BASE_URL"),
|
||||
client_id=os.getenv("OPENID_CLIENT_ID"),
|
||||
client_secret=os.getenv("OPENID_CLIENT_SECRET"),
|
||||
scopes=["litellm_proxy_admin"],
|
||||
)
|
||||
|
||||
token = auth_jwt_sso.create_access_token()
|
||||
|
||||
return {"token": token}
|
||||
|
||||
|
||||
# @router.post("/update_database", dependencies=[Depends(user_api_key_auth)])
|
||||
# async def update_database_endpoint(
|
||||
# user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue