updated tests to also check for audience if found

This commit is contained in:
Sara Ghaemi 2024-05-07 12:10:47 -04:00
parent 66b2b5fab9
commit 7017899d37

View file

@ -24,7 +24,6 @@ public_key = {
"alg": "RS256", "alg": "RS256",
} }
def test_load_config_with_custom_role_names(): def test_load_config_with_custom_role_names():
config = { config = {
"general_settings": { "general_settings": {
@ -136,6 +135,8 @@ async def test_valid_invalid_token():
"exp": expiration_time, # set the token to expire in 10 minutes "exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm-proxy-admin", "scope": "litellm-proxy-admin",
} }
if os.getenv("JWT_AUDIENCE"):
payload["aud"] = os.getenv("JWT_AUDIENCE")
# Generate the JWT token # Generate the JWT token
# But before, you should convert bytes to string # But before, you should convert bytes to string
@ -163,6 +164,8 @@ async def test_valid_invalid_token():
"exp": expiration_time, # set the token to expire in 10 minutes "exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm-NO-SCOPE", "scope": "litellm-NO-SCOPE",
} }
if os.getenv("JWT_AUDIENCE"):
payload["aud"] = os.getenv("JWT_AUDIENCE")
# Generate the JWT token # Generate the JWT token
# But before, you should convert bytes to string # But before, you should convert bytes to string
@ -266,6 +269,8 @@ async def test_team_token_output(prisma_client):
"scope": "litellm_team", "scope": "litellm_team",
"client_id": team_id, "client_id": team_id,
} }
if os.getenv("JWT_AUDIENCE"):
payload["aud"] = os.getenv("JWT_AUDIENCE")
# Generate the JWT token # Generate the JWT token
# But before, you should convert bytes to string # But before, you should convert bytes to string
@ -280,6 +285,8 @@ async def test_team_token_output(prisma_client):
"exp": expiration_time, # set the token to expire in 10 minutes "exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_proxy_admin", "scope": "litellm_proxy_admin",
} }
if os.getenv("JWT_AUDIENCE"):
payload["aud"] = os.getenv("JWT_AUDIENCE")
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256") admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
@ -421,6 +428,8 @@ async def test_user_token_output(prisma_client):
"scope": "litellm_team", "scope": "litellm_team",
"client_id": team_id, "client_id": team_id,
} }
if os.getenv("JWT_AUDIENCE"):
payload["aud"] = os.getenv("JWT_AUDIENCE")
# Generate the JWT token # Generate the JWT token
# But before, you should convert bytes to string # But before, you should convert bytes to string
@ -435,6 +444,8 @@ async def test_user_token_output(prisma_client):
"exp": expiration_time, # set the token to expire in 10 minutes "exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_proxy_admin", "scope": "litellm_proxy_admin",
} }
if os.getenv("JWT_AUDIENCE"):
payload["aud"] = os.getenv("JWT_AUDIENCE")
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256") admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")