forked from phoenix/litellm-mirror
feat(handle_jwt.py): support authenticating admins into the proxy via jwt's
This commit is contained in:
parent
4913ad41db
commit
302bab6f1f
3 changed files with 181 additions and 1 deletions
|
@ -14,6 +14,11 @@ def hash_token(token: str):
|
||||||
return hashed_token
|
return hashed_token
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLMProxyRoles(enum.Enum):
|
||||||
|
PROXY_ADMIN = "litellm_proxy_admin"
|
||||||
|
USER = "litellm_user"
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMBase(BaseModel):
|
class LiteLLMBase(BaseModel):
|
||||||
"""
|
"""
|
||||||
Implements default functions, all pydantic objects should have.
|
Implements default functions, all pydantic objects should have.
|
||||||
|
|
114
litellm/proxy/auth/handle_jwt.py
Normal file
114
litellm/proxy/auth/handle_jwt.py
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
"""
|
||||||
|
Supports using JWT's for authenticating into the proxy.
|
||||||
|
|
||||||
|
Currently only supports admin.
|
||||||
|
|
||||||
|
JWT token must have 'litellm_proxy_admin' in scope.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import jwt
|
||||||
|
import json
|
||||||
|
from jwt.algorithms import RSAAlgorithm
|
||||||
|
import os
|
||||||
|
from litellm.proxy._types import LiteLLMProxyRoles
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPHandler:
|
||||||
|
def __init__(self):
|
||||||
|
self.client = httpx.AsyncClient()
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
await self.client.aclose()
|
||||||
|
|
||||||
|
async def get(
|
||||||
|
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
|
||||||
|
):
|
||||||
|
response = await self.client.get(url, params=params, headers=headers)
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def post(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
data: Optional[dict] = None,
|
||||||
|
params: Optional[dict] = None,
|
||||||
|
headers: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
response = await self.client.post(
|
||||||
|
url, data=data, params=params, headers=headers
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class JWTHandler:
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.http_handler = HTTPHandler()
|
||||||
|
|
||||||
|
def is_jwt(self, token: str):
|
||||||
|
parts = token.split(".")
|
||||||
|
return len(parts) == 3
|
||||||
|
|
||||||
|
def is_admin(self, scopes: list) -> bool:
|
||||||
|
if LiteLLMProxyRoles.PROXY_ADMIN.value in scopes:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_user_id(self, token: dict, default_value: str) -> str:
|
||||||
|
try:
|
||||||
|
user_id = token["sub"]
|
||||||
|
except KeyError:
|
||||||
|
user_id = default_value
|
||||||
|
return user_id
|
||||||
|
|
||||||
|
def get_scopes(self, token: dict) -> list:
|
||||||
|
try:
|
||||||
|
# Assuming the scopes are stored in 'scope' claim and are space-separated
|
||||||
|
scopes = token["scope"].split()
|
||||||
|
except KeyError:
|
||||||
|
scopes = []
|
||||||
|
return scopes
|
||||||
|
|
||||||
|
async def auth_jwt(self, token: str) -> dict:
|
||||||
|
keys_url = os.getenv("OPENID_PUBLIC_KEY_URL")
|
||||||
|
|
||||||
|
async with self.http_handler as http:
|
||||||
|
response = await http.get(keys_url)
|
||||||
|
|
||||||
|
keys = response.json()["keys"]
|
||||||
|
|
||||||
|
header = jwt.get_unverified_header(token)
|
||||||
|
kid = header["kid"]
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
if key["kid"] == kid:
|
||||||
|
jwk = {
|
||||||
|
"kty": key["kty"],
|
||||||
|
"kid": key["kid"],
|
||||||
|
"n": key["n"],
|
||||||
|
"e": key["e"],
|
||||||
|
}
|
||||||
|
public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))
|
||||||
|
|
||||||
|
try:
|
||||||
|
# decode the token using the public key
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
public_key, # type: ignore
|
||||||
|
algorithms=["RS256"],
|
||||||
|
audience="account",
|
||||||
|
issuer=os.getenv("JWT_ISSUER"),
|
||||||
|
)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
# the token is expired, do something to refresh it
|
||||||
|
raise Exception("Token Expired")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Validation fails: {str(e)}")
|
||||||
|
|
||||||
|
raise jwt.InvalidKeyError
|
|
@ -106,6 +106,7 @@ from litellm.proxy._types import *
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.proxy.health_check import perform_health_check
|
from litellm.proxy.health_check import perform_health_check
|
||||||
from litellm._logging import verbose_router_logger, verbose_proxy_logger
|
from litellm._logging import verbose_router_logger, verbose_proxy_logger
|
||||||
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from litellm._version import version
|
from litellm._version import version
|
||||||
|
@ -282,6 +283,7 @@ proxy_budget_rescheduler_max_time = 605
|
||||||
proxy_batch_write_at = 60 # in seconds
|
proxy_batch_write_at = 60 # in seconds
|
||||||
litellm_master_key_hash = None
|
litellm_master_key_hash = None
|
||||||
disable_spend_logs = False
|
disable_spend_logs = False
|
||||||
|
jwt_handler = JWTHandler()
|
||||||
### INITIALIZE GLOBAL LOGGING OBJECT ###
|
### INITIALIZE GLOBAL LOGGING OBJECT ###
|
||||||
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
|
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
|
||||||
### REDIS QUEUE ###
|
### REDIS QUEUE ###
|
||||||
|
@ -334,6 +336,45 @@ async def user_api_key_auth(
|
||||||
return UserAPIKeyAuth.model_validate(response)
|
return UserAPIKeyAuth.model_validate(response)
|
||||||
|
|
||||||
### LITELLM-DEFINED AUTH FUNCTION ###
|
### LITELLM-DEFINED AUTH FUNCTION ###
|
||||||
|
#### IF JWT ####
|
||||||
|
"""
|
||||||
|
LiteLLM supports using JWTs.
|
||||||
|
|
||||||
|
Enable this in proxy config, by setting
|
||||||
|
```
|
||||||
|
general_settings:
|
||||||
|
enable_jwt_auth: true
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if general_settings.get("enable_jwt_auth", False) == True:
|
||||||
|
is_jwt = jwt_handler.is_jwt(token=api_key)
|
||||||
|
verbose_proxy_logger.debug(f"is_jwt: {is_jwt}")
|
||||||
|
if is_jwt:
|
||||||
|
# check if valid token
|
||||||
|
valid_token = await jwt_handler.auth_jwt(token=api_key)
|
||||||
|
# get scopes
|
||||||
|
scopes = jwt_handler.get_scopes(token=valid_token)
|
||||||
|
# check if admin
|
||||||
|
is_admin = jwt_handler.is_admin(scopes=scopes)
|
||||||
|
# get user id
|
||||||
|
user_id = jwt_handler.get_user_id(
|
||||||
|
token=valid_token, default_value=litellm_proxy_admin_name
|
||||||
|
)
|
||||||
|
# if admin return
|
||||||
|
if is_admin:
|
||||||
|
_user_api_key_obj = UserAPIKeyAuth(
|
||||||
|
api_key=api_key,
|
||||||
|
user_role="proxy_admin",
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
user_api_key_cache.set_cache(
|
||||||
|
key=hash_token(api_key), value=_user_api_key_obj
|
||||||
|
)
|
||||||
|
|
||||||
|
return _user_api_key_obj
|
||||||
|
else:
|
||||||
|
raise Exception("Invalid key error!")
|
||||||
|
#### ELSE ####
|
||||||
if master_key is None:
|
if master_key is None:
|
||||||
if isinstance(api_key, str):
|
if isinstance(api_key, str):
|
||||||
return UserAPIKeyAuth(api_key=api_key)
|
return UserAPIKeyAuth(api_key=api_key)
|
||||||
|
@ -7531,7 +7572,27 @@ async def get_routes():
|
||||||
return {"routes": routes}
|
return {"routes": routes}
|
||||||
|
|
||||||
|
|
||||||
## TEST ENDPOINT
|
#### TEST ENDPOINTS ####
|
||||||
|
@router.get("/token/generate", dependencies=[Depends(user_api_key_auth)])
|
||||||
|
async def token_generate():
|
||||||
|
"""
|
||||||
|
Test endpoint. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
|
||||||
|
"""
|
||||||
|
# Initialize AuthJWTSSO with your OpenID Provider configuration
|
||||||
|
from fastapi_sso import AuthJWTSSO
|
||||||
|
|
||||||
|
auth_jwt_sso = AuthJWTSSO(
|
||||||
|
issuer=os.getenv("OPENID_BASE_URL"),
|
||||||
|
client_id=os.getenv("OPENID_CLIENT_ID"),
|
||||||
|
client_secret=os.getenv("OPENID_CLIENT_SECRET"),
|
||||||
|
scopes=["litellm_proxy_admin"],
|
||||||
|
)
|
||||||
|
|
||||||
|
token = auth_jwt_sso.create_access_token()
|
||||||
|
|
||||||
|
return {"token": token}
|
||||||
|
|
||||||
|
|
||||||
# @router.post("/update_database", dependencies=[Depends(user_api_key_auth)])
|
# @router.post("/update_database", dependencies=[Depends(user_api_key_auth)])
|
||||||
# async def update_database_endpoint(
|
# async def update_database_endpoint(
|
||||||
# user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
# user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue