Merge pull request #9047 from BerriAI/litellm_dev_03_06_2025_p4

feat(handle_jwt.py): support multiple jwt url's
This commit is contained in:
Krish Dholakia 2025-03-10 22:37:35 -07:00 committed by GitHub
commit c93a5e2301
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 71 additions and 36 deletions

View file

@ -345,32 +345,38 @@ class JWTHandler:
if keys_url is None:
raise Exception("Missing JWT Public Key URL from environment.")
cached_keys = await self.user_api_key_cache.async_get_cache(
"litellm_jwt_auth_keys"
)
if cached_keys is None:
response = await self.http_handler.get(keys_url)
keys_url_list = [url.strip() for url in keys_url.split(",")]
response_json = response.json()
if "keys" in response_json:
keys: JWKKeyValue = response.json()["keys"]
for key_url in keys_url_list:
cache_key = f"litellm_jwt_auth_keys_{key_url}"
cached_keys = await self.user_api_key_cache.async_get_cache(cache_key)
if cached_keys is None:
response = await self.http_handler.get(key_url)
response_json = response.json()
if "keys" in response_json:
keys: JWKKeyValue = response.json()["keys"]
else:
keys = response_json
await self.user_api_key_cache.async_set_cache(
key=cache_key,
value=keys,
ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins
)
else:
keys = response_json
keys = cached_keys
await self.user_api_key_cache.async_set_cache(
key="litellm_jwt_auth_keys",
value=keys,
ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins
)
else:
keys = cached_keys
public_key = self.parse_keys(keys=keys, kid=kid)
if public_key is not None:
return cast(dict, public_key)
public_key = self.parse_keys(keys=keys, kid=kid)
if public_key is None:
raise Exception(
f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}, len(keys)={len(keys)}"
)
return cast(dict, public_key)
raise Exception(
f"No matching public key found. keys={keys_url_list}, kid={kid}"
)
def parse_keys(self, keys: JWKKeyValue, kid: Optional[str]) -> Optional[JWTKeyItem]:
public_key: Optional[JWTKeyItem] = None

View file

@ -64,7 +64,7 @@ def test_load_config_with_custom_role_names():
@pytest.mark.asyncio
async def test_token_single_public_key():
async def test_token_single_public_key(monkeypatch):
import jwt
jwt_handler = JWTHandler()
@ -80,10 +80,15 @@ async def test_token_single_public_key():
]
}
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
# set cache
cache = DualCache()
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=backend_keys["keys"])
await cache.async_set_cache(
key="litellm_jwt_auth_keys_https://example.com/public-key",
value=backend_keys["keys"],
)
jwt_handler.user_api_key_cache = cache
@ -99,7 +104,7 @@ async def test_token_single_public_key():
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_valid_invalid_token(audience):
async def test_valid_invalid_token(audience, monkeypatch):
"""
Tests
- valid token
@ -116,6 +121,8 @@ async def test_valid_invalid_token(audience):
if audience:
os.environ["JWT_AUDIENCE"] = audience
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
# Generate a private / public key pair using RSA algorithm
key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
@ -145,7 +152,9 @@ async def test_valid_invalid_token(audience):
# set cache
cache = DualCache()
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
await cache.async_set_cache(
key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
)
jwt_handler = JWTHandler()
@ -294,7 +303,7 @@ def team_token_tuple():
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_team_token_output(prisma_client, audience):
async def test_team_token_output(prisma_client, audience, monkeypatch):
import json
import uuid
@ -316,6 +325,8 @@ async def test_team_token_output(prisma_client, audience):
if audience:
os.environ["JWT_AUDIENCE"] = audience
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
# Generate a private / public key pair using RSA algorithm
key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
@ -345,7 +356,9 @@ async def test_team_token_output(prisma_client, audience):
# set cache
cache = DualCache()
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
await cache.async_set_cache(
key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
)
jwt_handler = JWTHandler()
@ -463,7 +476,7 @@ async def test_team_token_output(prisma_client, audience):
@pytest.mark.parametrize("user_id_upsert", [True, False])
@pytest.mark.asyncio
async def aaaatest_user_token_output(
prisma_client, audience, team_id_set, default_team_id, user_id_upsert
prisma_client, audience, team_id_set, default_team_id, user_id_upsert, monkeypatch
):
import uuid
@ -528,10 +541,14 @@ async def aaaatest_user_token_output(
assert isinstance(public_jwk, dict)
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
# set cache
cache = DualCache()
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
await cache.async_set_cache(
key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
)
jwt_handler = JWTHandler()
@ -699,7 +716,9 @@ async def aaaatest_user_token_output(
@pytest.mark.parametrize("admin_allowed_routes", [None, ["ui_routes"]])
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_routes):
async def test_allowed_routes_admin(
prisma_client, audience, admin_allowed_routes, monkeypatch
):
"""
Add a check to make sure jwt proxy admin scope can access all allowed admin routes
@ -723,6 +742,8 @@ async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_route
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
await litellm.proxy.proxy_server.prisma_client.connect()
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
os.environ.pop("JWT_AUDIENCE", None)
if audience:
os.environ["JWT_AUDIENCE"] = audience
@ -756,7 +777,9 @@ async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_route
# set cache
cache = DualCache()
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
await cache.async_set_cache(
key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
)
jwt_handler = JWTHandler()
@ -910,7 +933,9 @@ def mock_user_object(*args, **kwargs):
"user_email, should_work", [("ishaan@berri.ai", True), ("krrish@tassle.xyz", False)]
)
@pytest.mark.asyncio
async def test_allow_access_by_email(public_jwt_key, user_email, should_work):
async def test_allow_access_by_email(
public_jwt_key, user_email, should_work, monkeypatch
):
"""
Allow anyone with an `@xyz.com` email make a request to the proxy.
@ -925,10 +950,14 @@ async def test_allow_access_by_email(public_jwt_key, user_email, should_work):
public_jwk = public_jwt_key["public_jwk"]
private_key = public_jwt_key["private_key"]
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
# set cache
cache = DualCache()
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
await cache.async_set_cache(
key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
)
jwt_handler = JWTHandler()
@ -1074,7 +1103,7 @@ async def test_end_user_jwt_auth(monkeypatch):
]
cache.set_cache(
key="litellm_jwt_auth_keys",
key="litellm_jwt_auth_keys_https://example.com/public-key",
value=keys,
)

View file

@ -826,7 +826,7 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa
]
local_cache.set_cache(
key="litellm_jwt_auth_keys",
key="litellm_jwt_auth_keys_my-fake-url",
value=keys,
)