diff --git a/docs/my-website/docs/proxy/token_auth.md b/docs/my-website/docs/proxy/token_auth.md index 14631a2d9..5f1812757 100644 --- a/docs/my-website/docs/proxy/token_auth.md +++ b/docs/my-website/docs/proxy/token_auth.md @@ -107,4 +107,38 @@ general_settings: master_key: sk-1234 enable_jwt_auth: True allowed_routes: ["/chat/completions", "/embeddings"] +``` + +## Advanced - Set Accepted JWT Scope Names + +Change the string in the JWT 'scope' claim that LiteLLM checks to decide whether a user has admin access. + +```yaml +general_settings: + master_key: sk-1234 + enable_jwt_auth: True + litellm_proxy_roles: + proxy_admin: "litellm-proxy-admin" +``` + +### Allowed LiteLLM scopes + +```python +class LiteLLMProxyRoles(LiteLLMBase): + proxy_admin: str = "litellm_proxy_admin" + proxy_user: str = "litellm_user" # 👈 Not implemented yet, for JWT-Auth. +``` + +### JWT Scopes + +Here's what scopes on JWT-Auth tokens look like: + +**Can be a list** +``` +scope: ["litellm-proxy-admin",...] +``` + +**Can be a space-separated string** +``` +scope: "litellm-proxy-admin ..." ``` \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index d23049056..d4e5834f2 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -14,11 +14,6 @@ def hash_token(token: str): return hashed_token -class LiteLLMProxyRoles(enum.Enum): - PROXY_ADMIN = "litellm_proxy_admin" - USER = "litellm_user" - - class LiteLLMBase(BaseModel): """ Implements default functions, all pydantic objects should have. 
@@ -42,6 +37,11 @@ class LiteLLMBase(BaseModel): protected_namespaces = () +class LiteLLMProxyRoles(LiteLLMBase): + proxy_admin: str = "litellm_proxy_admin" + proxy_user: str = "litellm_user" + + class LiteLLMPromptInjectionParams(LiteLLMBase): heuristics_check: bool = False vector_db_check: bool = False diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index ae0ef85f8..b636c8813 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -67,17 +67,21 @@ class JWTHandler: self.http_handler = HTTPHandler() def update_environment( - self, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache + self, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + litellm_proxy_roles: LiteLLMProxyRoles, ) -> None: self.prisma_client = prisma_client self.user_api_key_cache = user_api_key_cache + self.litellm_proxy_roles = litellm_proxy_roles def is_jwt(self, token: str): parts = token.split(".") return len(parts) == 3 def is_admin(self, scopes: list) -> bool: - if LiteLLMProxyRoles.PROXY_ADMIN.value in scopes: + if self.litellm_proxy_roles.proxy_admin in scopes: return True return False @@ -90,7 +94,7 @@ class JWTHandler: def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: try: - team_id = token["azp"] + team_id = token["client_id"] except KeyError: team_id = default_value return team_id @@ -130,58 +134,94 @@ class JWTHandler: def get_scopes(self, token: dict) -> list: try: - # Assuming the scopes are stored in 'scope' claim and are space-separated - scopes = token["scope"].split() + if isinstance(token["scope"], str): + # Assuming the scopes are stored in 'scope' claim and are space-separated + scopes = token["scope"].split() + elif isinstance(token["scope"], list): + scopes = token["scope"] + else: + raise Exception( + f"Unmapped scope type - {type(token['scope'])}. Supported types - list, str." 
+ ) except KeyError: scopes = [] return scopes - async def auth_jwt(self, token: str) -> dict: - from jwt.algorithms import RSAAlgorithm - + async def get_public_key(self, kid: Optional[str]) -> dict: keys_url = os.getenv("JWT_PUBLIC_KEY_URL") if keys_url is None: raise Exception("Missing JWT Public Key URL from environment.") - response = await self.http_handler.get(keys_url) + cached_keys = await self.user_api_key_cache.async_get_cache( + "litellm_jwt_auth_keys" + ) + if cached_keys is None: + response = await self.http_handler.get(keys_url) - keys = response.json()["keys"] + keys = response.json()["keys"] + + await self.user_api_key_cache.async_set_cache( + key="litellm_jwt_auth_keys", value=keys, ttl=600 # cache for 10 mins + ) + else: + keys = cached_keys + + public_key: Optional[dict] = None + + if len(keys) == 1: + public_key = keys[0] + elif len(keys) > 1: + for key in keys: + if kid is not None and key["kid"] == kid: + public_key = key + + if public_key is None: + raise Exception( + f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}" + ) + + return public_key + + async def auth_jwt(self, token: str) -> dict: + from jwt.algorithms import RSAAlgorithm header = jwt.get_unverified_header(token) verbose_proxy_logger.debug("header: %s", header) - if "kid" in header: - kid = header["kid"] - else: - raise Exception(f"Expected 'kid' in header. 
header={header}.") + kid = header.get("kid", None) - for key in keys: - if key["kid"] == kid: - jwk = { - "kty": key["kty"], - "kid": key["kid"], - "n": key["n"], - "e": key["e"], - } - public_key = RSAAlgorithm.from_jwk(json.dumps(jwk)) + public_key = await self.get_public_key(kid=kid) - try: - # decode the token using the public key - payload = jwt.decode( - token, - public_key, # type: ignore - algorithms=["RS256"], - audience="account", - ) - return payload + if public_key is not None and isinstance(public_key, dict): + jwk = {} + if "kty" in public_key: + jwk["kty"] = public_key["kty"] + if "kid" in public_key: + jwk["kid"] = public_key["kid"] + if "n" in public_key: + jwk["n"] = public_key["n"] + if "e" in public_key: + jwk["e"] = public_key["e"] - except jwt.ExpiredSignatureError: - # the token is expired, do something to refresh it - raise Exception("Token Expired") - except Exception as e: - raise Exception(f"Validation fails: {str(e)}") + public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk)) + + try: + # decode the token using the public key + payload = jwt.decode( + token, + public_key_rsa, # type: ignore + algorithms=["RS256"], + options={"verify_aud": False}, + ) + return payload + + except jwt.ExpiredSignatureError: + # the token is expired, do something to refresh it + raise Exception("Token Expired") + except Exception as e: + raise Exception(f"Validation fails: {str(e)}") raise Exception("Invalid JWT Submitted") diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index c24586d76..44aed9fe4 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2710,7 +2710,11 @@ async def startup_event(): ## JWT AUTH ## jwt_handler.update_environment( - prisma_client=prisma_client, user_api_key_cache=user_api_key_cache + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + litellm_proxy_roles=LiteLLMProxyRoles( + **general_settings.get("litellm_proxy_roles", {}) + ), ) if 
use_background_health_checks: diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py new file mode 100644 index 000000000..57c7e5c62 --- /dev/null +++ b/litellm/tests/test_jwt.py @@ -0,0 +1,179 @@ +#### What this tests #### +# Unit tests for JWT-Auth + +import sys, os, asyncio, time, random +import traceback +from dotenv import load_dotenv + +load_dotenv() +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest +from litellm.proxy._types import LiteLLMProxyRoles +from litellm.proxy.auth.handle_jwt import JWTHandler +from litellm.caching import DualCache +from datetime import datetime, timedelta + +public_key = { + "kty": "RSA", + "e": "AQAB", + "n": "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ", + "alg": "RS256", +} + + +def test_load_config_with_custom_role_names(): + config = { + "general_settings": { + "litellm_proxy_roles": {"proxy_admin": "litellm-proxy-admin"} + } + } + + proxy_roles = LiteLLMProxyRoles( + **config.get("general_settings", {}).get("litellm_proxy_roles", {}) + ) + + print(f"proxy_roles: {proxy_roles}") + + assert proxy_roles.proxy_admin == "litellm-proxy-admin" + + +# test_load_config_with_custom_role_names() + + +@pytest.mark.asyncio +async def test_token_single_public_key(): + import jwt + + jwt_handler = JWTHandler() + + backend_keys = { + "keys": [ + { + "kty": "RSA", + "use": "sig", + "e": "AQAB", + "n": 
"qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ", + "alg": "RS256", + } + ] + } + + # set cache + cache = DualCache() + + await cache.async_set_cache(key="litellm_jwt_auth_keys", value=backend_keys["keys"]) + + jwt_handler.user_api_key_cache = cache + + public_key = await jwt_handler.get_public_key(kid=None) + + assert public_key is not None + assert isinstance(public_key, dict) + assert ( + public_key["n"] + == "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ" + ) + + +@pytest.mark.asyncio +async def test_valid_invalid_token(): + """ + Tests + - valid token + - invalid token + """ + import jwt, json + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.hazmat.backends import default_backend + + # Generate a private / public key pair using RSA algorithm + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + # Get private key in PEM format + private_key = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + # Get public key in PEM format + public_key = key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + public_key_obj = serialization.load_pem_public_key( + 
public_key, backend=default_backend() + ) + + # Convert RSA public key object to JWK (JSON Web Key) + public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj)) + + assert isinstance(public_jwk, dict) + + # set cache + cache = DualCache() + + await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + + jwt_handler = JWTHandler() + + jwt_handler.user_api_key_cache = cache + + # VALID TOKEN + ## GENERATE A TOKEN + # Assuming the current time is in UTC + expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp()) + + payload = { + "sub": "user123", + "exp": expiration_time, # set the token to expire in 10 minutes + "scope": "litellm-proxy-admin", + } + + # Generate the JWT token + # But before, you should convert bytes to string + private_key_str = private_key.decode("utf-8") + token = jwt.encode(payload, private_key_str, algorithm="RS256") + + ## VERIFY IT WORKS + + # verify token + + response = await jwt_handler.auth_jwt(token=token) + + assert response is not None + assert isinstance(response, dict) + + print(f"response: {response}") + + # INVALID TOKEN + ## GENERATE A TOKEN + # Assuming the current time is in UTC + expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp()) + + payload = { + "sub": "user123", + "exp": expiration_time, # set the token to expire in 10 minutes + "scope": "litellm-NO-SCOPE", + } + + # Generate the JWT token + # But before, you should convert bytes to string + private_key_str = private_key.decode("utf-8") + token = jwt.encode(payload, private_key_str, algorithm="RS256") + + ## VERIFY IT WORKS + + # verify token + + try: + response = await jwt_handler.auth_jwt(token=token) + except Exception as e: + pytest.fail(f"An exception occurred - {str(e)}")