import enum
import json
import os
from typing import Optional


class LiteLLMProxyRoles(enum.Enum):
    """Role names used by the proxy.

    ``PROXY_ADMIN``'s value is the scope string ``JWTHandler.is_admin`` looks
    for in a token's scopes.
    """

    PROXY_ADMIN = "litellm_proxy_admin"
    USER = "litellm_user"


# ---------------------------------------------------------------------------
# JWT authentication support for the proxy.
#
# Currently only admin tokens are supported: the JWT must carry
# 'litellm_proxy_admin' in its scope.
# ---------------------------------------------------------------------------


class HTTPHandler:
    """Thin async wrapper around one long-lived ``httpx.AsyncClient``."""

    def __init__(self):
        # Function-scope import: httpx is only required once a handler is built.
        import httpx

        self.client = httpx.AsyncClient()

    async def close(self):
        """Close the underlying client; it cannot send requests afterwards."""
        await self.client.aclose()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Leaving an `async with` block closes the client for good, so only use
        # the context-manager form on handlers that are not shared.
        await self.close()

    async def get(
        self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
    ):
        """Send a GET request and return the httpx response."""
        return await self.client.get(url, params=params, headers=headers)

    async def post(
        self,
        url: str,
        data: Optional[dict] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
    ):
        """Send a POST request and return the httpx response."""
        return await self.client.post(url, data=data, params=params, headers=headers)


class JWTHandler:
    """Validates JWTs against the OpenID provider's published keys.

    Environment variables read:
        OPENID_PUBLIC_KEY_URL: URL of the provider's JWKS document (required).
        JWT_ISSUER: expected ``iss`` claim, passed to ``jwt.decode``.
    """

    def __init__(self) -> None:
        self.http_handler = HTTPHandler()

    def is_jwt(self, token: str):
        """Return True when ``token`` has the three dot-separated JWT segments.

        Non-string input (e.g. ``None`` when no key was sent) is simply not a
        JWT, rather than an AttributeError.
        """
        if not isinstance(token, str):
            return False
        return len(token.split(".")) == 3

    def is_admin(self, scopes: list) -> bool:
        """Return True when the proxy-admin scope is among ``scopes``."""
        return LiteLLMProxyRoles.PROXY_ADMIN.value in scopes

    def get_user_id(self, token: dict, default_value: str) -> str:
        """Return the token's ``sub`` claim, or ``default_value`` if absent."""
        return token.get("sub", default_value)

    def get_scopes(self, token: dict) -> list:
        """Return the scopes carried by a decoded token as a list of strings.

        The ``scope`` claim is normally one space-separated string. A claim
        that is already a list/tuple is accepted as-is, and a missing or null
        claim yields an empty list.
        """
        raw_scopes = token.get("scope")
        if raw_scopes is None:
            return []
        if isinstance(raw_scopes, (list, tuple)):
            return list(raw_scopes)
        return raw_scopes.split()

    async def _get_public_keys(self, keys_url: str) -> list:
        """Fetch the JWKS document at ``keys_url`` and return its ``keys`` list."""
        # Deliberately NOT `async with self.http_handler`: exiting that context
        # closes the shared client, which made every lookup after the first
        # one fail.
        response = await self.http_handler.get(keys_url)
        return response.json()["keys"]

    async def auth_jwt(self, token: str) -> dict:
        """Verify ``token`` and return its decoded payload.

        The signing key is looked up by the token header's ``kid`` in the
        provider's JWKS document, then the token is verified as RS256 with
        audience ``"account"`` and the issuer from ``JWT_ISSUER``.

        Raises:
            Exception: ``OPENID_PUBLIC_KEY_URL`` is unset, the token is expired
                ("Token Expired"), or verification fails ("Validation fails: ...").
            jwt.InvalidKeyError: no published key matches the token's ``kid``.
        """
        keys_url = os.getenv("OPENID_PUBLIC_KEY_URL")
        if not keys_url:
            # Fail with an actionable message instead of passing None to httpx.
            raise Exception(
                "Missing JWT Auth Env Variable: 'OPENID_PUBLIC_KEY_URL' is not set."
            )

        # Function-scope imports: PyJWT is only needed once JWT auth is used.
        import jwt
        from jwt.algorithms import RSAAlgorithm

        keys = await self._get_public_keys(keys_url)

        header = jwt.get_unverified_header(token)
        kid = header["kid"]

        matching_key = next((key for key in keys if key.get("kid") == kid), None)
        if matching_key is None:
            raise jwt.InvalidKeyError(f"No public key found for kid={kid!r}")

        # Only the RSA public-key fields are forwarded to PyJWT.
        jwk = {field: matching_key[field] for field in ("kty", "kid", "n", "e")}
        public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))

        try:
            # NOTE(review): when JWT_ISSUER is unset this passes issuer=None —
            # confirm whether issuer validation should be mandatory.
            return jwt.decode(
                token,
                public_key,  # type: ignore
                algorithms=["RS256"],
                audience="account",
                issuer=os.getenv("JWT_ISSUER"),
            )
        except jwt.ExpiredSignatureError as e:
            # the token is expired; the caller must obtain a fresh one
            raise Exception("Token Expired") from e
        except Exception as e:
            raise Exception(f"Validation fails: {str(e)}") from e
#### TEST ENDPOINTS ####
@router.get("/token/generate", dependencies=[Depends(user_api_key_auth)])
async def token_generate():
    """
    Test endpoint: returns an access token requested with the
    'litellm_proxy_admin' scope.

    Meant for generating admin tokens with specific claims and checking that
    they work for creating keys, etc.

    The OpenID provider is configured from the OPENID_BASE_URL,
    OPENID_CLIENT_ID and OPENID_CLIENT_SECRET environment variables.

    Returns:
        {"token": <access token>}
    """
    # NOTE(review): confirm that the installed fastapi_sso actually exposes
    # `AuthJWTSSO` and `create_access_token()` — not verifiable from here.
    from fastapi_sso import AuthJWTSSO

    # OpenID Provider configuration, read from the environment per request.
    provider_settings = {
        "issuer": os.getenv("OPENID_BASE_URL"),
        "client_id": os.getenv("OPENID_CLIENT_ID"),
        "client_secret": os.getenv("OPENID_CLIENT_SECRET"),
        "scopes": ["litellm_proxy_admin"],
    }
    sso_client = AuthJWTSSO(**provider_settings)

    return {"token": sso_client.create_access_token()}