diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index 407814e84..3dacebe27 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -24,7 +24,6 @@ public_key = { "alg": "RS256", } - def test_load_config_with_custom_role_names(): config = { "general_settings": { @@ -136,6 +135,8 @@ async def test_valid_invalid_token(): "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm-proxy-admin", } + if os.getenv("JWT_AUDIENCE"): + payload["aud"] = os.getenv("JWT_AUDIENCE") # Generate the JWT token # But before, you should convert bytes to string @@ -163,6 +164,8 @@ async def test_valid_invalid_token(): "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm-NO-SCOPE", } + if os.getenv("JWT_AUDIENCE"): + payload["aud"] = os.getenv("JWT_AUDIENCE") # Generate the JWT token # But before, you should convert bytes to string @@ -266,6 +269,8 @@ async def test_team_token_output(prisma_client): "scope": "litellm_team", "client_id": team_id, } + if os.getenv("JWT_AUDIENCE"): + payload["aud"] = os.getenv("JWT_AUDIENCE") # Generate the JWT token # But before, you should convert bytes to string @@ -280,6 +285,8 @@ async def test_team_token_output(prisma_client): "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_proxy_admin", } + if os.getenv("JWT_AUDIENCE"): + payload["aud"] = os.getenv("JWT_AUDIENCE") admin_token = jwt.encode(payload, private_key_str, algorithm="RS256") @@ -421,6 +428,8 @@ async def test_user_token_output(prisma_client): "scope": "litellm_team", "client_id": team_id, } + if os.getenv("JWT_AUDIENCE"): + payload["aud"] = os.getenv("JWT_AUDIENCE") # Generate the JWT token # But before, you should convert bytes to string @@ -435,6 +444,8 @@ async def test_user_token_output(prisma_client): "exp": expiration_time, # set the token to expire in 10 minutes "scope": "litellm_proxy_admin", } + if os.getenv("JWT_AUDIENCE"): + payload["aud"] = os.getenv("JWT_AUDIENCE") admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")