mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
feat(handle_jwt.py): support multiple jwt url's
This commit is contained in:
parent
5b2eb1f6bb
commit
5280a914cd
2 changed files with 70 additions and 35 deletions
|
@ -64,7 +64,7 @@ def test_load_config_with_custom_role_names():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_single_public_key():
|
||||
async def test_token_single_public_key(monkeypatch):
|
||||
import jwt
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
|
@ -80,10 +80,15 @@ async def test_token_single_public_key():
|
|||
]
|
||||
}
|
||||
|
||||
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
|
||||
|
||||
# set cache
|
||||
cache = DualCache()
|
||||
|
||||
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=backend_keys["keys"])
|
||||
await cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys_https://example.com/public-key",
|
||||
value=backend_keys["keys"],
|
||||
)
|
||||
|
||||
jwt_handler.user_api_key_cache = cache
|
||||
|
||||
|
@ -99,7 +104,7 @@ async def test_token_single_public_key():
|
|||
|
||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_invalid_token(audience):
|
||||
async def test_valid_invalid_token(audience, monkeypatch):
|
||||
"""
|
||||
Tests
|
||||
- valid token
|
||||
|
@ -116,6 +121,8 @@ async def test_valid_invalid_token(audience):
|
|||
if audience:
|
||||
os.environ["JWT_AUDIENCE"] = audience
|
||||
|
||||
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
|
||||
|
||||
# Generate a private / public key pair using RSA algorithm
|
||||
key = rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||
|
@ -145,7 +152,9 @@ async def test_valid_invalid_token(audience):
|
|||
# set cache
|
||||
cache = DualCache()
|
||||
|
||||
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
|
||||
await cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
|
||||
)
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
|
||||
|
@ -294,7 +303,7 @@ def team_token_tuple():
|
|||
|
||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_token_output(prisma_client, audience):
|
||||
async def test_team_token_output(prisma_client, audience, monkeypatch):
|
||||
import json
|
||||
import uuid
|
||||
|
||||
|
@ -316,6 +325,8 @@ async def test_team_token_output(prisma_client, audience):
|
|||
if audience:
|
||||
os.environ["JWT_AUDIENCE"] = audience
|
||||
|
||||
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
|
||||
|
||||
# Generate a private / public key pair using RSA algorithm
|
||||
key = rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||
|
@ -345,7 +356,9 @@ async def test_team_token_output(prisma_client, audience):
|
|||
# set cache
|
||||
cache = DualCache()
|
||||
|
||||
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
|
||||
await cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
|
||||
)
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
|
||||
|
@ -463,7 +476,7 @@ async def test_team_token_output(prisma_client, audience):
|
|||
@pytest.mark.parametrize("user_id_upsert", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def aaaatest_user_token_output(
|
||||
prisma_client, audience, team_id_set, default_team_id, user_id_upsert
|
||||
prisma_client, audience, team_id_set, default_team_id, user_id_upsert, monkeypatch
|
||||
):
|
||||
import uuid
|
||||
|
||||
|
@ -528,10 +541,14 @@ async def aaaatest_user_token_output(
|
|||
|
||||
assert isinstance(public_jwk, dict)
|
||||
|
||||
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
|
||||
|
||||
# set cache
|
||||
cache = DualCache()
|
||||
|
||||
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
|
||||
await cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
|
||||
)
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
|
||||
|
@ -699,7 +716,9 @@ async def aaaatest_user_token_output(
|
|||
@pytest.mark.parametrize("admin_allowed_routes", [None, ["ui_routes"]])
|
||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_routes):
|
||||
async def test_allowed_routes_admin(
|
||||
prisma_client, audience, admin_allowed_routes, monkeypatch
|
||||
):
|
||||
"""
|
||||
Add a check to make sure jwt proxy admin scope can access all allowed admin routes
|
||||
|
||||
|
@ -723,6 +742,8 @@ async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_route
|
|||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
|
||||
|
||||
os.environ.pop("JWT_AUDIENCE", None)
|
||||
if audience:
|
||||
os.environ["JWT_AUDIENCE"] = audience
|
||||
|
@ -756,7 +777,9 @@ async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_route
|
|||
# set cache
|
||||
cache = DualCache()
|
||||
|
||||
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
|
||||
await cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
|
||||
)
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
|
||||
|
@ -910,7 +933,9 @@ def mock_user_object(*args, **kwargs):
|
|||
"user_email, should_work", [("ishaan@berri.ai", True), ("krrish@tassle.xyz", False)]
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_access_by_email(public_jwt_key, user_email, should_work):
|
||||
async def test_allow_access_by_email(
|
||||
public_jwt_key, user_email, should_work, monkeypatch
|
||||
):
|
||||
"""
|
||||
Allow anyone with an `@xyz.com` email make a request to the proxy.
|
||||
|
||||
|
@ -925,10 +950,14 @@ async def test_allow_access_by_email(public_jwt_key, user_email, should_work):
|
|||
public_jwk = public_jwt_key["public_jwk"]
|
||||
private_key = public_jwt_key["private_key"]
|
||||
|
||||
monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
|
||||
|
||||
# set cache
|
||||
cache = DualCache()
|
||||
|
||||
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
|
||||
await cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
|
||||
)
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
|
||||
|
@ -1074,7 +1103,7 @@ async def test_end_user_jwt_auth(monkeypatch):
|
|||
]
|
||||
|
||||
cache.set_cache(
|
||||
key="litellm_jwt_auth_keys",
|
||||
key="litellm_jwt_auth_keys_https://example.com/public-key",
|
||||
value=keys,
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue