diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 248d553662..cc41050198 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -345,32 +345,38 @@ class JWTHandler: if keys_url is None: raise Exception("Missing JWT Public Key URL from environment.") - cached_keys = await self.user_api_key_cache.async_get_cache( - "litellm_jwt_auth_keys" - ) - if cached_keys is None: - response = await self.http_handler.get(keys_url) + keys_url_list = [url.strip() for url in keys_url.split(",")] - response_json = response.json() - if "keys" in response_json: - keys: JWKKeyValue = response.json()["keys"] + for key_url in keys_url_list: + + cache_key = f"litellm_jwt_auth_keys_{key_url}" + + cached_keys = await self.user_api_key_cache.async_get_cache(cache_key) + + if cached_keys is None: + response = await self.http_handler.get(key_url) + + response_json = response.json() + if "keys" in response_json: + keys: JWKKeyValue = response.json()["keys"] + else: + keys = response_json + + await self.user_api_key_cache.async_set_cache( + key=cache_key, + value=keys, + ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins + ) else: - keys = response_json + keys = cached_keys - await self.user_api_key_cache.async_set_cache( - key="litellm_jwt_auth_keys", - value=keys, - ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins - ) - else: - keys = cached_keys + public_key = self.parse_keys(keys=keys, kid=kid) + if public_key is not None: + return cast(dict, public_key) - public_key = self.parse_keys(keys=keys, kid=kid) - if public_key is None: - raise Exception( - f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}, len(keys)={len(keys)}" - ) - return cast(dict, public_key) + raise Exception( + f"No matching public key found. keys={keys_url_list}, kid={kid}" + ) def parse_keys(self, keys: JWKKeyValue, kid: Optional[str]) -> Optional[JWTKeyItem]: public_key: Optional[JWTKeyItem] = None diff --git a/tests/proxy_unit_tests/test_jwt.py b/tests/proxy_unit_tests/test_jwt.py index 7a9d2f0019..d96fb691f7 100644 --- a/tests/proxy_unit_tests/test_jwt.py +++ b/tests/proxy_unit_tests/test_jwt.py @@ -64,7 +64,7 @@ def test_load_config_with_custom_role_names(): @pytest.mark.asyncio -async def test_token_single_public_key(): +async def test_token_single_public_key(monkeypatch): import jwt jwt_handler = JWTHandler() @@ -80,10 +80,15 @@ async def test_token_single_public_key(): ] } + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=backend_keys["keys"]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", + value=backend_keys["keys"], + ) jwt_handler.user_api_key_cache = cache @@ -99,7 +104,7 @@ async def test_token_single_public_key(): @pytest.mark.parametrize("audience", [None, "litellm-proxy"]) @pytest.mark.asyncio -async def test_valid_invalid_token(audience): +async def test_valid_invalid_token(audience, monkeypatch): """ Tests - valid token @@ -116,6 +121,8 @@ async def test_valid_invalid_token(audience): if audience: os.environ["JWT_AUDIENCE"] = audience + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + # Generate a private / public key pair using RSA algorithm key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() @@ -145,7 +152,9 @@ async def test_valid_invalid_token(audience): # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk] + ) jwt_handler = JWTHandler() @@ -294,7 +303,7 @@ def team_token_tuple(): @pytest.mark.parametrize("audience", [None, "litellm-proxy"]) @pytest.mark.asyncio -async def test_team_token_output(prisma_client, audience): +async def test_team_token_output(prisma_client, audience, monkeypatch): import json import uuid @@ -316,6 +325,8 @@ async def test_team_token_output(prisma_client, audience): if audience: os.environ["JWT_AUDIENCE"] = audience + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + # Generate a private / public key pair using RSA algorithm key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() @@ -345,7 +356,9 @@ async def test_team_token_output(prisma_client, audience): # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk] + ) jwt_handler = JWTHandler() @@ -463,7 +476,7 @@ async def test_team_token_output(prisma_client, audience): @pytest.mark.parametrize("user_id_upsert", [True, False]) @pytest.mark.asyncio async def aaaatest_user_token_output( - prisma_client, audience, team_id_set, default_team_id, user_id_upsert + prisma_client, audience, team_id_set, default_team_id, user_id_upsert, monkeypatch ): import uuid @@ -528,10 +541,14 @@ async def aaaatest_user_token_output( assert isinstance(public_jwk, dict) + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk] + ) jwt_handler = JWTHandler() @@ -699,7 +716,9 @@ async def aaaatest_user_token_output( @pytest.mark.parametrize("admin_allowed_routes", [None, ["ui_routes"]]) @pytest.mark.parametrize("audience", [None, "litellm-proxy"]) @pytest.mark.asyncio -async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_routes): +async def test_allowed_routes_admin( + prisma_client, audience, admin_allowed_routes, monkeypatch +): """ Add a check to make sure jwt proxy admin scope can access all allowed admin routes @@ -723,6 +742,8 @@ async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_route setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) await litellm.proxy.proxy_server.prisma_client.connect() + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + os.environ.pop("JWT_AUDIENCE", None) if audience: os.environ["JWT_AUDIENCE"] = audience @@ -756,7 +777,9 @@ async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_route # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk] + ) jwt_handler = JWTHandler() @@ -910,7 +933,9 @@ def mock_user_object(*args, **kwargs): "user_email, should_work", [("ishaan@berri.ai", True), ("krrish@tassle.xyz", False)] ) @pytest.mark.asyncio -async def test_allow_access_by_email(public_jwt_key, user_email, should_work): +async def test_allow_access_by_email( + public_jwt_key, user_email, should_work, monkeypatch +): """ Allow anyone with an `@xyz.com` email make a request to the proxy. @@ -925,10 +950,14 @@ async def test_allow_access_by_email(public_jwt_key, user_email, should_work): public_jwk = public_jwt_key["public_jwk"] private_key = public_jwt_key["private_key"] + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk] + ) jwt_handler = JWTHandler() @@ -1074,7 +1103,7 @@ async def test_end_user_jwt_auth(monkeypatch): ] cache.set_cache( - key="litellm_jwt_auth_keys", + key="litellm_jwt_auth_keys_https://example.com/public-key", value=keys, ) diff --git a/tests/proxy_unit_tests/test_user_api_key_auth.py b/tests/proxy_unit_tests/test_user_api_key_auth.py index dbe49a560d..e956a22282 100644 --- a/tests/proxy_unit_tests/test_user_api_key_auth.py +++ b/tests/proxy_unit_tests/test_user_api_key_auth.py @@ -826,7 +826,7 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa ] local_cache.set_cache( - key="litellm_jwt_auth_keys", + key="litellm_jwt_auth_keys_my-fake-url", value=keys, )