forked from phoenix/litellm-mirror
test(test_jwt.py): add testing to make sure user api key auth returns the expected values
This commit is contained in:
parent 67b5634417
commit 77097f8e7d
3 changed files with 189 additions and 4 deletions
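The new test in test_jwt.py hinges on minting an RS256 JWT from a locally generated RSA keypair, publishing the public half as a JWK, and verifying tokens against it. A minimal standalone sketch of that round trip, using only PyJWT and cryptography (the "sub"/"scope" values below are illustrative, not taken from this commit):

import json
import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

# generate an RSA keypair; the private half signs, the public half verifies
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
private_pem = key.private_bytes(
    encoding=serialization.Encoding.PEM,
    format=serialization.PrivateFormat.PKCS8,
    encryption_algorithm=serialization.NoEncryption(),
).decode("utf-8")

# JWK form of the public key - this is what the test caches under "litellm_jwt_auth_keys"
public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(key.public_key()))

token = jwt.encode({"sub": "user123", "scope": "litellm_team"}, private_pem, algorithm="RS256")
decoded = jwt.decode(
    token,
    key=jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(public_jwk)),
    algorithms=["RS256"],
)
assert decoded["sub"] == "user123"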
@@ -40,7 +40,7 @@ general_settings:
  allow_user_auth: true
  alerting: ["slack"]
  # store_model_in_db: True // set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True"
  # proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
  proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
  enable_jwt_auth: True
  alerting: ["slack"]
  litellm_jwtauth:
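With enable_jwt_auth turned on in general_settings, callers can present an IdP-signed JWT in the Authorization header instead of a litellm API key. A hypothetical client call for illustration only — the proxy URL and model name are assumptions, not part of this commit:

import requests

PROXY_BASE = "http://localhost:4000"  # assumed local proxy address
jwt_token = "<RS256 token signed with a key the proxy trusts>"

resp = requests.post(
    f"{PROXY_BASE}/chat/completions",
    headers={"Authorization": f"Bearer {jwt_token}"},  # JWT sent in place of an sk- key
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hi"}],
    },
)
print(resp.status_code, resp.json())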
@@ -514,6 +514,7 @@ async def user_api_key_auth(
                    team_rpm_limit=team_object.rpm_limit,
                    team_models=team_object.models,
                    user_role="app_owner",
                    user_id=user_id,
                )
        #### ELSE ####
        if master_key is None:
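The keyword arguments in this hunk (team limits, team models, user_role, user_id) are exactly the fields the new test asserts on the UserAPIKeyAuth object returned for a team-scoped JWT. A rough illustration of that shape — constructing the object directly here is only for demonstration:

from litellm.proxy._types import UserAPIKeyAuth

auth = UserAPIKeyAuth(
    team_tpm_limit=100,
    team_rpm_limit=99,
    team_models=["gpt-3.5-turbo", "gpt-4"],
    user_role="app_owner",
    user_id="user123",
)
assert auth.team_models == ["gpt-3.5-turbo", "gpt-4"]
assert auth.user_id == "user123"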
@@ -1341,8 +1342,6 @@ async def update_database(
            existing_token_obj = await user_api_key_cache.async_get_cache(
                key=hashed_token
            )
            if existing_token_obj is None:
                return
            existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
            if existing_user_obj is not None and isinstance(existing_user_obj, dict):
                existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
@@ -1364,7 +1363,9 @@ async def update_database(
            if end_user_id is not None:
                prisma_client.end_user_list_transactons[end_user_id] = (
                    response_cost
                    + prisma_client.user_list_transactons.get(end_user_id, 0)
                    + prisma_client.end_user_list_transactons.get(
                        end_user_id, 0
                    )
                )
        elif custom_db_client is not None:
            for id in user_ids:
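The update_database hunk above accumulates per-end-user spend in an in-memory dict keyed by end_user_id (the proxy_batch_write_at setting in the config hunk controls how often such batches are written out). Stripped of the Prisma plumbing, the accumulation pattern is just the following — names and amounts are illustrative:

end_user_spend: dict[str, float] = {}

def record_spend(end_user_id: str, response_cost: float) -> None:
    # add this request's cost to whatever has accumulated for the end user so far
    end_user_spend[end_user_id] = response_cost + end_user_spend.get(end_user_id, 0)

record_spend("end-user-1", 0.002)
record_spend("end-user-1", 0.003)
assert abs(end_user_spend["end-user-1"] - 0.005) < 1e-9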
@@ -345,3 +345,187 @@ async def test_team_token_output(prisma_client):
    assert team_result.team_tpm_limit == 100
    assert team_result.team_rpm_limit == 99
    assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]


@pytest.mark.asyncio
async def test_user_token_output(prisma_client):
    """
    - If user required, check if it exists
    - fail initial request (when user doesn't exist)
    - create user
    - retry -> it should pass now
    """
    import jwt, json
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from cryptography.hazmat.backends import default_backend
    from fastapi import Request
    from starlette.datastructures import URL
    from litellm.proxy.proxy_server import user_api_key_auth, new_team, new_user
    from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth, NewUserRequest
    import litellm
    import uuid

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    await litellm.proxy.proxy_server.prisma_client.connect()

    # Generate a private / public key pair using RSA algorithm
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    # Get private key in PEM format
    private_key = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # Get public key in PEM format
    public_key = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    public_key_obj = serialization.load_pem_public_key(
        public_key, backend=default_backend()
    )

    # Convert RSA public key object to JWK (JSON Web Key)
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))

    assert isinstance(public_jwk, dict)

    # set cache
    cache = DualCache()

    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])

    jwt_handler = JWTHandler()

    jwt_handler.user_api_key_cache = cache

    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()

    jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"

    # VALID TOKEN
    ## GENERATE A TOKEN
    # Assuming the current time is in UTC
    expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())

    team_id = f"team123_{uuid.uuid4()}"
    user_id = f"user123_{uuid.uuid4()}"
    payload = {
        "sub": user_id,
        "exp": expiration_time,  # set the token to expire in 10 minutes
        "scope": "litellm_team",
        "client_id": team_id,
    }

    # Generate the JWT token
    # But before, you should convert bytes to string
    private_key_str = private_key.decode("utf-8")

    ## team token
    token = jwt.encode(payload, private_key_str, algorithm="RS256")

    ## admin token
    payload = {
        "sub": user_id,
        "exp": expiration_time,  # set the token to expire in 10 minutes
        "scope": "litellm_proxy_admin",
    }

    admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")

    ## VERIFY IT WORKS

    # verify token

    response = await jwt_handler.auth_jwt(token=token)

    ## RUN IT THROUGH USER API KEY AUTH

    """
    - 1. Initial call should fail -> team doesn't exist
    - 2. Create team via admin token
    - 3. 2nd call w/ same team -> call should fail -> user doesn't exist
    - 4. Create user via admin token
    - 5. 3rd call w/ same team, same user -> call should succeed
    - 6. assert user api key auth format
    """

    bearer_token = "Bearer " + token

    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    ## 1. INITIAL TEAM CALL - should fail
    # use generated key to auth in
    setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
    try:
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        pytest.fail("Team doesn't exist. This should fail")
    except Exception as e:
        pass

    ## 2. CREATE TEAM W/ ADMIN TOKEN - should succeed
    try:
        bearer_token = "Bearer " + admin_token

        request._url = URL(url="/team/new")
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        await new_team(
            data=NewTeamRequest(
                team_id=team_id,
                tpm_limit=100,
                rpm_limit=99,
                models=["gpt-3.5-turbo", "gpt-4"],
            ),
            user_api_key_dict=result,
        )
    except Exception as e:
        pytest.fail(f"This should not fail - {str(e)}")

    ## 3. 2nd CALL W/ TEAM TOKEN - should fail
    bearer_token = "Bearer " + token
    request._url = URL(url="/chat/completions")
    try:
        team_result: UserAPIKeyAuth = await user_api_key_auth(
            request=request, api_key=bearer_token
        )
        pytest.fail(f"User doesn't exist. this should fail")
    except Exception as e:
        pass

    ## 4. Create user
    try:
        bearer_token = "Bearer " + admin_token

        request._url = URL(url="/team/new")
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        await new_user(
            data=NewUserRequest(
                user_id=user_id,
            ),
        )
    except Exception as e:
        pytest.fail(f"This should not fail - {str(e)}")

    ## 5. 3rd call w/ same team, same user -> call should succeed
    bearer_token = "Bearer " + token
    request._url = URL(url="/chat/completions")
    try:
        team_result: UserAPIKeyAuth = await user_api_key_auth(
            request=request, api_key=bearer_token
        )
    except Exception as e:
        pytest.fail(f"Team exists. This should not fail - {e}")

    ## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)

    assert team_result.team_tpm_limit == 100
    assert team_result.team_rpm_limit == 99
    assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
    assert team_result.user_id == user_id
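One note on the test setup: it primes the "litellm_jwt_auth_keys" cache entry by hand with the locally generated JWK. Against a real identity provider those keys would come from its JWKS endpoint instead; a hypothetical sketch — the URL, the import path, and the idea of fetching at startup are assumptions, not part of this commit:

import requests
from litellm.caching import DualCache  # import path assumed from litellm's caching module

JWKS_URL = "https://idp.example.com/.well-known/jwks.json"  # assumed IdP endpoint

async def prime_jwt_keys(cache: DualCache) -> None:
    # standard JWKS documents carry the public keys in a "keys" array
    keys = requests.get(JWKS_URL, timeout=10).json()["keys"]
    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=keys)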