feat(proxy_server.py): make team_id optional for jwt token auth (only enforced, if set)

Allows users to use jwt auth for internal chat apps
This commit is contained in:
Krrish Dholakia 2024-05-15 21:05:14 -07:00
parent d9ad7c6218
commit f48cd87cf3
5 changed files with 89 additions and 54 deletions

View file

@ -24,6 +24,7 @@ public_key = {
"alg": "RS256",
}
def test_load_config_with_custom_role_names():
config = {
"general_settings": {
@ -77,7 +78,8 @@ async def test_token_single_public_key():
== "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"
)
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_valid_invalid_token(audience):
"""
@ -90,7 +92,7 @@ async def test_valid_invalid_token(audience):
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
os.environ.pop('JWT_AUDIENCE', None)
os.environ.pop("JWT_AUDIENCE", None)
if audience:
os.environ["JWT_AUDIENCE"] = audience
@ -138,7 +140,7 @@ async def test_valid_invalid_token(audience):
"sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm-proxy-admin",
"aud": audience
"aud": audience,
}
# Generate the JWT token
@ -166,7 +168,7 @@ async def test_valid_invalid_token(audience):
"sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm-NO-SCOPE",
"aud": audience
"aud": audience,
}
# Generate the JWT token
@ -183,6 +185,7 @@ async def test_valid_invalid_token(audience):
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
@pytest.fixture
def prisma_client():
import litellm
@ -205,7 +208,7 @@ def prisma_client():
return prisma_client
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_team_token_output(prisma_client, audience):
import jwt, json
@ -222,7 +225,7 @@ async def test_team_token_output(prisma_client, audience):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
await litellm.proxy.proxy_server.prisma_client.connect()
os.environ.pop('JWT_AUDIENCE', None)
os.environ.pop("JWT_AUDIENCE", None)
if audience:
os.environ["JWT_AUDIENCE"] = audience
@ -261,7 +264,7 @@ async def test_team_token_output(prisma_client, audience):
jwt_handler.user_api_key_cache = cache
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")
# VALID TOKEN
## GENERATE A TOKEN
@ -274,7 +277,7 @@ async def test_team_token_output(prisma_client, audience):
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_team",
"client_id": team_id,
"aud": audience
"aud": audience,
}
# Generate the JWT token
@ -289,7 +292,7 @@ async def test_team_token_output(prisma_client, audience):
"sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_proxy_admin",
"aud": audience
"aud": audience,
}
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
@ -315,7 +318,13 @@ async def test_team_token_output(prisma_client, audience):
## 1. INITIAL TEAM CALL - should fail
# use generated key to auth in
setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
setattr(
litellm.proxy.proxy_server,
"general_settings",
{
"enable_jwt_auth": True,
},
)
setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
try:
result = await user_api_key_auth(request=request, api_key=bearer_token)
@ -358,9 +367,10 @@ async def test_team_token_output(prisma_client, audience):
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.parametrize("team_id_set", [True, False])
@pytest.mark.asyncio
async def test_user_token_output(prisma_client, audience):
async def test_user_token_output(prisma_client, audience, team_id_set):
"""
- If user required, check if it exists
- fail initial request (when user doesn't exist)
@ -381,7 +391,7 @@ async def test_user_token_output(prisma_client, audience):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
await litellm.proxy.proxy_server.prisma_client.connect()
os.environ.pop('JWT_AUDIENCE', None)
os.environ.pop("JWT_AUDIENCE", None)
if audience:
os.environ["JWT_AUDIENCE"] = audience
@ -423,6 +433,8 @@ async def test_user_token_output(prisma_client, audience):
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
if team_id_set:
jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id"
# VALID TOKEN
## GENERATE A TOKEN
@ -436,7 +448,7 @@ async def test_user_token_output(prisma_client, audience):
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_team",
"client_id": team_id,
"aud": audience
"aud": audience,
}
# Generate the JWT token
@ -451,7 +463,7 @@ async def test_user_token_output(prisma_client, audience):
"sub": user_id,
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_proxy_admin",
"aud": audience
"aud": audience,
}
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
@ -543,7 +555,8 @@ async def test_user_token_output(prisma_client, audience):
## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)
assert team_result.team_tpm_limit == 100
assert team_result.team_rpm_limit == 99
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
if team_id_set:
assert team_result.team_tpm_limit == 100
assert team_result.team_rpm_limit == 99
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
assert team_result.user_id == user_id