From 805679becccbee2107b1d0f07a6c811fa8418067 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 6 Mar 2025 23:05:54 -0800 Subject: [PATCH 1/5] feat(handle_jwt.py): support multiple jwt url's --- litellm/proxy/auth/handle_jwt.py | 50 +++++++++++++++------------ tests/proxy_unit_tests/test_jwt.py | 55 +++++++++++++++++++++++------- 2 files changed, 70 insertions(+), 35 deletions(-) diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 29f4b31f9c..61da9825e6 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -344,32 +344,38 @@ class JWTHandler: if keys_url is None: raise Exception("Missing JWT Public Key URL from environment.") - cached_keys = await self.user_api_key_cache.async_get_cache( - "litellm_jwt_auth_keys" - ) - if cached_keys is None: - response = await self.http_handler.get(keys_url) + keys_url_list = [url.strip() for url in keys_url.split(",")] - response_json = response.json() - if "keys" in response_json: - keys: JWKKeyValue = response.json()["keys"] + for key_url in keys_url_list: + + cache_key = f"litellm_jwt_auth_keys_{key_url}" + + cached_keys = await self.user_api_key_cache.async_get_cache(cache_key) + + if cached_keys is None: + response = await self.http_handler.get(key_url) + + response_json = response.json() + if "keys" in response_json: + keys: JWKKeyValue = response.json()["keys"] + else: + keys = response_json + + await self.user_api_key_cache.async_set_cache( + key=cache_key, + value=keys, + ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins + ) else: - keys = response_json + keys = cached_keys - await self.user_api_key_cache.async_set_cache( - key="litellm_jwt_auth_keys", - value=keys, - ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins - ) - else: - keys = cached_keys + public_key = self.parse_keys(keys=keys, kid=kid) + if public_key is not None: + return cast(dict, public_key) - public_key = self.parse_keys(keys=keys, kid=kid) - if public_key is None: - raise Exception( - f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}, len(keys)={len(keys)}" - ) - return cast(dict, public_key) + raise Exception( + f"No matching public key found. keys={keys_url_list}, kid={kid}" + ) def parse_keys(self, keys: JWKKeyValue, kid: Optional[str]) -> Optional[JWTKeyItem]: public_key: Optional[JWTKeyItem] = None diff --git a/tests/proxy_unit_tests/test_jwt.py b/tests/proxy_unit_tests/test_jwt.py index 7a9d2f0019..d96fb691f7 100644 --- a/tests/proxy_unit_tests/test_jwt.py +++ b/tests/proxy_unit_tests/test_jwt.py @@ -64,7 +64,7 @@ def test_load_config_with_custom_role_names(): @pytest.mark.asyncio -async def test_token_single_public_key(): +async def test_token_single_public_key(monkeypatch): import jwt jwt_handler = JWTHandler() @@ -80,10 +80,15 @@ async def test_token_single_public_key(): ] } + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=backend_keys["keys"]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", + value=backend_keys["keys"], + ) jwt_handler.user_api_key_cache = cache @@ -99,7 +104,7 @@ async def test_token_single_public_key(): @pytest.mark.parametrize("audience", [None, "litellm-proxy"]) @pytest.mark.asyncio -async def test_valid_invalid_token(audience): +async def test_valid_invalid_token(audience, monkeypatch): """ Tests - valid token @@ -116,6 +121,8 @@ async def test_valid_invalid_token(audience): if audience: os.environ["JWT_AUDIENCE"] = audience + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + # Generate a private / public key pair using RSA algorithm key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() @@ -145,7 +152,9 @@ async def test_valid_invalid_token(audience): # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk] + ) jwt_handler = JWTHandler() @@ -294,7 +303,7 @@ def team_token_tuple(): @pytest.mark.parametrize("audience", [None, "litellm-proxy"]) @pytest.mark.asyncio -async def test_team_token_output(prisma_client, audience): +async def test_team_token_output(prisma_client, audience, monkeypatch): import json import uuid @@ -316,6 +325,8 @@ async def test_team_token_output(prisma_client, audience): if audience: os.environ["JWT_AUDIENCE"] = audience + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + # Generate a private / public key pair using RSA algorithm key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() @@ -345,7 +356,9 @@ async def test_team_token_output(prisma_client, audience): # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk] + ) jwt_handler = JWTHandler() @@ -463,7 +476,7 @@ async def test_team_token_output(prisma_client, audience): @pytest.mark.parametrize("user_id_upsert", [True, False]) @pytest.mark.asyncio async def aaaatest_user_token_output( - prisma_client, audience, team_id_set, default_team_id, user_id_upsert + prisma_client, audience, team_id_set, default_team_id, user_id_upsert, monkeypatch ): import uuid @@ -528,10 +541,14 @@ async def aaaatest_user_token_output( assert isinstance(public_jwk, dict) + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk] + ) jwt_handler = JWTHandler() @@ -699,7 +716,9 @@ async def aaaatest_user_token_output( @pytest.mark.parametrize("admin_allowed_routes", [None, ["ui_routes"]]) @pytest.mark.parametrize("audience", [None, "litellm-proxy"]) @pytest.mark.asyncio -async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_routes): +async def test_allowed_routes_admin( + prisma_client, audience, admin_allowed_routes, monkeypatch +): """ Add a check to make sure jwt proxy admin scope can access all allowed admin routes @@ -723,6 +742,8 @@ async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_route setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) await litellm.proxy.proxy_server.prisma_client.connect() + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + os.environ.pop("JWT_AUDIENCE", None) if audience: os.environ["JWT_AUDIENCE"] = audience @@ -756,7 +777,9 @@ async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_route # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk] + ) jwt_handler = JWTHandler() @@ -910,7 +933,9 @@ def mock_user_object(*args, **kwargs): "user_email, should_work", [("ishaan@berri.ai", True), ("krrish@tassle.xyz", False)] ) @pytest.mark.asyncio -async def test_allow_access_by_email(public_jwt_key, user_email, should_work): +async def test_allow_access_by_email( + public_jwt_key, user_email, should_work, monkeypatch +): """ Allow anyone with an `@xyz.com` email make a request to the proxy. @@ -925,10 +950,14 @@ async def test_allow_access_by_email(public_jwt_key, user_email, should_work): public_jwk = public_jwt_key["public_jwk"] private_key = public_jwt_key["private_key"] + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk] + ) jwt_handler = JWTHandler() @@ -1074,7 +1103,7 @@ async def test_end_user_jwt_auth(monkeypatch): ] cache.set_cache( - key="litellm_jwt_auth_keys", + key="litellm_jwt_auth_keys_https://example.com/public-key", value=keys, ) From 2c5b2da9558cbda394ff669785dc41ebf89f76d5 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sun, 9 Mar 2025 18:35:10 -0700 Subject: [PATCH 2/5] fix: make type object subscriptable --- litellm/router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/router.py b/litellm/router.py index d1c410e786..7fe5c2fb94 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -618,7 +618,7 @@ class Router: @staticmethod def _create_redis_cache( - cache_config: dict[str, Any] + cache_config: Dict[str, Any] ) -> RedisCache | RedisClusterCache: if cache_config.get("startup_nodes"): return RedisClusterCache(**cache_config) From c08705517bfcca1ad48cb6029a4899f0820ef20c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sun, 9 Mar 2025 19:40:03 -0700 Subject: [PATCH 3/5] test: fix test --- tests/proxy_unit_tests/test_user_api_key_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/proxy_unit_tests/test_user_api_key_auth.py b/tests/proxy_unit_tests/test_user_api_key_auth.py index dbe49a560d..e956a22282 100644 --- a/tests/proxy_unit_tests/test_user_api_key_auth.py +++ b/tests/proxy_unit_tests/test_user_api_key_auth.py @@ -826,7 +826,7 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa ] local_cache.set_cache( - key="litellm_jwt_auth_keys", + key="litellm_jwt_auth_keys_my-fake-url", value=keys, ) From 0b5deb275615abb1e9579ec07054532531b1315e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 10 Mar 2025 18:38:40 -0700 Subject: [PATCH 4/5] fix: fix type --- litellm/router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/router.py b/litellm/router.py index 7fe5c2fb94..b16a84e11a 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -619,7 +619,7 @@ class Router: @staticmethod def _create_redis_cache( cache_config: Dict[str, Any] - ) -> RedisCache | RedisClusterCache: + ) -> Union[RedisCache, RedisClusterCache]: if cache_config.get("startup_nodes"): return RedisClusterCache(**cache_config) else: From f3be2e1fc94560bbc195e2e7b043ddc79e5d1e3c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 10 Mar 2025 20:07:38 -0700 Subject: [PATCH 5/5] fix(transformation.py): fix linting error --- litellm/llms/triton/completion/transformation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/litellm/llms/triton/completion/transformation.py b/litellm/llms/triton/completion/transformation.py index 0a65e216df..4037c32365 100644 --- a/litellm/llms/triton/completion/transformation.py +++ b/litellm/llms/triton/completion/transformation.py @@ -69,11 +69,13 @@ class TritonConfig(BaseConfig): def get_complete_url( self, - api_base: str, + api_base: Optional[str], model: str, optional_params: dict, stream: Optional[bool] = None, ) -> str: + if api_base is None: + raise ValueError("api_base is required") llm_type = self._get_triton_llm_type(api_base) if llm_type == "generate" and stream: return api_base + "_stream"