mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
Merge pull request #9047 from BerriAI/litellm_dev_03_06_2025_p4
feat(handle_jwt.py): support multiple jwt url's
This commit is contained in:
commit
c93a5e2301
3 changed files with 71 additions and 36 deletions
|
@ -345,32 +345,38 @@ class JWTHandler:
|
|||
if keys_url is None:
|
||||
raise Exception("Missing JWT Public Key URL from environment.")
|
||||
|
||||
cached_keys = await self.user_api_key_cache.async_get_cache(
|
||||
"litellm_jwt_auth_keys"
|
||||
)
|
||||
if cached_keys is None:
|
||||
response = await self.http_handler.get(keys_url)
|
||||
keys_url_list = [url.strip() for url in keys_url.split(",")]
|
||||
|
||||
response_json = response.json()
|
||||
if "keys" in response_json:
|
||||
keys: JWKKeyValue = response.json()["keys"]
|
||||
for key_url in keys_url_list:
|
||||
|
||||
cache_key = f"litellm_jwt_auth_keys_{key_url}"
|
||||
|
||||
cached_keys = await self.user_api_key_cache.async_get_cache(cache_key)
|
||||
|
||||
if cached_keys is None:
|
||||
response = await self.http_handler.get(key_url)
|
||||
|
||||
response_json = response.json()
|
||||
if "keys" in response_json:
|
||||
keys: JWKKeyValue = response.json()["keys"]
|
||||
else:
|
||||
keys = response_json
|
||||
|
||||
await self.user_api_key_cache.async_set_cache(
|
||||
key=cache_key,
|
||||
value=keys,
|
||||
ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins
|
||||
)
|
||||
else:
|
||||
keys = response_json
|
||||
keys = cached_keys
|
||||
|
||||
await self.user_api_key_cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys",
|
||||
value=keys,
|
||||
ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins
|
||||
)
|
||||
else:
|
||||
keys = cached_keys
|
||||
public_key = self.parse_keys(keys=keys, kid=kid)
|
||||
if public_key is not None:
|
||||
return cast(dict, public_key)
|
||||
|
||||
public_key = self.parse_keys(keys=keys, kid=kid)
|
||||
if public_key is None:
|
||||
raise Exception(
|
||||
f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}, len(keys)={len(keys)}"
|
||||
)
|
||||
return cast(dict, public_key)
|
||||
raise Exception(
|
||||
f"No matching public key found. keys={keys_url_list}, kid={kid}"
|
||||
)
|
||||
|
||||
def parse_keys(self, keys: JWKKeyValue, kid: Optional[str]) -> Optional[JWTKeyItem]:
|
||||
public_key: Optional[JWTKeyItem] = None
|
||||
|
|
|
@ -64,7 +64,7 @@ def test_load_config_with_custom_role_names():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_single_public_key():
|
||||
async def test_token_single_public_key(monkeypatch):
|
||||
import jwt
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
|
@ -80,10 +80,15 @@ async def test_token_single_public_key():
|
|||
]
|
||||
}
|
||||
|
||||
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
|
||||
|
||||
# set cache
|
||||
cache = DualCache()
|
||||
|
||||
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=backend_keys["keys"])
|
||||
await cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys_https://example.com/public-key",
|
||||
value=backend_keys["keys"],
|
||||
)
|
||||
|
||||
jwt_handler.user_api_key_cache = cache
|
||||
|
||||
|
@ -99,7 +104,7 @@ async def test_token_single_public_key():
|
|||
|
||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_invalid_token(audience):
|
||||
async def test_valid_invalid_token(audience, monkeypatch):
|
||||
"""
|
||||
Tests
|
||||
- valid token
|
||||
|
@ -116,6 +121,8 @@ async def test_valid_invalid_token(audience):
|
|||
if audience:
|
||||
os.environ["JWT_AUDIENCE"] = audience
|
||||
|
||||
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
|
||||
|
||||
# Generate a private / public key pair using RSA algorithm
|
||||
key = rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||
|
@ -145,7 +152,9 @@ async def test_valid_invalid_token(audience):
|
|||
# set cache
|
||||
cache = DualCache()
|
||||
|
||||
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
|
||||
await cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
|
||||
)
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
|
||||
|
@ -294,7 +303,7 @@ def team_token_tuple():
|
|||
|
||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_token_output(prisma_client, audience):
|
||||
async def test_team_token_output(prisma_client, audience, monkeypatch):
|
||||
import json
|
||||
import uuid
|
||||
|
||||
|
@ -316,6 +325,8 @@ async def test_team_token_output(prisma_client, audience):
|
|||
if audience:
|
||||
os.environ["JWT_AUDIENCE"] = audience
|
||||
|
||||
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
|
||||
|
||||
# Generate a private / public key pair using RSA algorithm
|
||||
key = rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||
|
@ -345,7 +356,9 @@ async def test_team_token_output(prisma_client, audience):
|
|||
# set cache
|
||||
cache = DualCache()
|
||||
|
||||
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
|
||||
await cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
|
||||
)
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
|
||||
|
@ -463,7 +476,7 @@ async def test_team_token_output(prisma_client, audience):
|
|||
@pytest.mark.parametrize("user_id_upsert", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def aaaatest_user_token_output(
|
||||
prisma_client, audience, team_id_set, default_team_id, user_id_upsert
|
||||
prisma_client, audience, team_id_set, default_team_id, user_id_upsert, monkeypatch
|
||||
):
|
||||
import uuid
|
||||
|
||||
|
@ -528,10 +541,14 @@ async def aaaatest_user_token_output(
|
|||
|
||||
assert isinstance(public_jwk, dict)
|
||||
|
||||
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
|
||||
|
||||
# set cache
|
||||
cache = DualCache()
|
||||
|
||||
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
|
||||
await cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
|
||||
)
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
|
||||
|
@ -699,7 +716,9 @@ async def aaaatest_user_token_output(
|
|||
@pytest.mark.parametrize("admin_allowed_routes", [None, ["ui_routes"]])
|
||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_routes):
|
||||
async def test_allowed_routes_admin(
|
||||
prisma_client, audience, admin_allowed_routes, monkeypatch
|
||||
):
|
||||
"""
|
||||
Add a check to make sure jwt proxy admin scope can access all allowed admin routes
|
||||
|
||||
|
@ -723,6 +742,8 @@ async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_route
|
|||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
|
||||
|
||||
os.environ.pop("JWT_AUDIENCE", None)
|
||||
if audience:
|
||||
os.environ["JWT_AUDIENCE"] = audience
|
||||
|
@ -756,7 +777,9 @@ async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_route
|
|||
# set cache
|
||||
cache = DualCache()
|
||||
|
||||
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
|
||||
await cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
|
||||
)
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
|
||||
|
@ -910,7 +933,9 @@ def mock_user_object(*args, **kwargs):
|
|||
"user_email, should_work", [("ishaan@berri.ai", True), ("krrish@tassle.xyz", False)]
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_access_by_email(public_jwt_key, user_email, should_work):
|
||||
async def test_allow_access_by_email(
|
||||
public_jwt_key, user_email, should_work, monkeypatch
|
||||
):
|
||||
"""
|
||||
Allow anyone with an `@xyz.com` email make a request to the proxy.
|
||||
|
||||
|
@ -925,10 +950,14 @@ async def test_allow_access_by_email(public_jwt_key, user_email, should_work):
|
|||
public_jwk = public_jwt_key["public_jwk"]
|
||||
private_key = public_jwt_key["private_key"]
|
||||
|
||||
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
|
||||
|
||||
# set cache
|
||||
cache = DualCache()
|
||||
|
||||
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
|
||||
await cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
|
||||
)
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
|
||||
|
@ -1074,7 +1103,7 @@ async def test_end_user_jwt_auth(monkeypatch):
|
|||
]
|
||||
|
||||
cache.set_cache(
|
||||
key="litellm_jwt_auth_keys",
|
||||
key="litellm_jwt_auth_keys_https://example.com/public-key",
|
||||
value=keys,
|
||||
)
|
||||
|
||||
|
|
|
@ -826,7 +826,7 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa
|
|||
]
|
||||
|
||||
local_cache.set_cache(
|
||||
key="litellm_jwt_auth_keys",
|
||||
key="litellm_jwt_auth_keys_my-fake-url",
|
||||
value=keys,
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue