forked from phoenix/litellm-mirror
feat(proxy_server.py): make team_id optional for jwt token auth (only enforced, if set)
Allows users to use jwt auth for internal chat apps
This commit is contained in:
parent
d9ad7c6218
commit
f48cd87cf3
5 changed files with 89 additions and 54 deletions
|
@ -24,6 +24,7 @@ public_key = {
|
|||
"alg": "RS256",
|
||||
}
|
||||
|
||||
|
||||
def test_load_config_with_custom_role_names():
|
||||
config = {
|
||||
"general_settings": {
|
||||
|
@ -77,7 +78,8 @@ async def test_token_single_public_key():
|
|||
== "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
|
||||
|
||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_invalid_token(audience):
|
||||
"""
|
||||
|
@ -90,7 +92,7 @@ async def test_valid_invalid_token(audience):
|
|||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
os.environ.pop('JWT_AUDIENCE', None)
|
||||
os.environ.pop("JWT_AUDIENCE", None)
|
||||
if audience:
|
||||
os.environ["JWT_AUDIENCE"] = audience
|
||||
|
||||
|
@ -138,7 +140,7 @@ async def test_valid_invalid_token(audience):
|
|||
"sub": "user123",
|
||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm-proxy-admin",
|
||||
"aud": audience
|
||||
"aud": audience,
|
||||
}
|
||||
|
||||
# Generate the JWT token
|
||||
|
@ -166,7 +168,7 @@ async def test_valid_invalid_token(audience):
|
|||
"sub": "user123",
|
||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm-NO-SCOPE",
|
||||
"aud": audience
|
||||
"aud": audience,
|
||||
}
|
||||
|
||||
# Generate the JWT token
|
||||
|
@ -183,6 +185,7 @@ async def test_valid_invalid_token(audience):
|
|||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prisma_client():
|
||||
import litellm
|
||||
|
@ -205,7 +208,7 @@ def prisma_client():
|
|||
return prisma_client
|
||||
|
||||
|
||||
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
|
||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_token_output(prisma_client, audience):
|
||||
import jwt, json
|
||||
|
@ -222,7 +225,7 @@ async def test_team_token_output(prisma_client, audience):
|
|||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
os.environ.pop('JWT_AUDIENCE', None)
|
||||
os.environ.pop("JWT_AUDIENCE", None)
|
||||
if audience:
|
||||
os.environ["JWT_AUDIENCE"] = audience
|
||||
|
||||
|
@ -261,7 +264,7 @@ async def test_team_token_output(prisma_client, audience):
|
|||
|
||||
jwt_handler.user_api_key_cache = cache
|
||||
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")
|
||||
|
||||
# VALID TOKEN
|
||||
## GENERATE A TOKEN
|
||||
|
@ -274,7 +277,7 @@ async def test_team_token_output(prisma_client, audience):
|
|||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm_team",
|
||||
"client_id": team_id,
|
||||
"aud": audience
|
||||
"aud": audience,
|
||||
}
|
||||
|
||||
# Generate the JWT token
|
||||
|
@ -289,7 +292,7 @@ async def test_team_token_output(prisma_client, audience):
|
|||
"sub": "user123",
|
||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm_proxy_admin",
|
||||
"aud": audience
|
||||
"aud": audience,
|
||||
}
|
||||
|
||||
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||
|
@ -315,7 +318,13 @@ async def test_team_token_output(prisma_client, audience):
|
|||
|
||||
## 1. INITIAL TEAM CALL - should fail
|
||||
# use generated key to auth in
|
||||
setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
|
||||
setattr(
|
||||
litellm.proxy.proxy_server,
|
||||
"general_settings",
|
||||
{
|
||||
"enable_jwt_auth": True,
|
||||
},
|
||||
)
|
||||
setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
|
||||
try:
|
||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||
|
@ -358,9 +367,10 @@ async def test_team_token_output(prisma_client, audience):
|
|||
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
|
||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||
@pytest.mark.parametrize("team_id_set", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_token_output(prisma_client, audience):
|
||||
async def test_user_token_output(prisma_client, audience, team_id_set):
|
||||
"""
|
||||
- If user required, check if it exists
|
||||
- fail initial request (when user doesn't exist)
|
||||
|
@ -381,7 +391,7 @@ async def test_user_token_output(prisma_client, audience):
|
|||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
os.environ.pop('JWT_AUDIENCE', None)
|
||||
os.environ.pop("JWT_AUDIENCE", None)
|
||||
if audience:
|
||||
os.environ["JWT_AUDIENCE"] = audience
|
||||
|
||||
|
@ -423,6 +433,8 @@ async def test_user_token_output(prisma_client, audience):
|
|||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
|
||||
|
||||
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
|
||||
if team_id_set:
|
||||
jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id"
|
||||
|
||||
# VALID TOKEN
|
||||
## GENERATE A TOKEN
|
||||
|
@ -436,7 +448,7 @@ async def test_user_token_output(prisma_client, audience):
|
|||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm_team",
|
||||
"client_id": team_id,
|
||||
"aud": audience
|
||||
"aud": audience,
|
||||
}
|
||||
|
||||
# Generate the JWT token
|
||||
|
@ -451,7 +463,7 @@ async def test_user_token_output(prisma_client, audience):
|
|||
"sub": user_id,
|
||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm_proxy_admin",
|
||||
"aud": audience
|
||||
"aud": audience,
|
||||
}
|
||||
|
||||
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||
|
@ -543,7 +555,8 @@ async def test_user_token_output(prisma_client, audience):
|
|||
|
||||
## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)
|
||||
|
||||
assert team_result.team_tpm_limit == 100
|
||||
assert team_result.team_rpm_limit == 99
|
||||
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||
if team_id_set:
|
||||
assert team_result.team_tpm_limit == 100
|
||||
assert team_result.team_rpm_limit == 99
|
||||
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||
assert team_result.user_id == user_id
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue