diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 0f34cd454..f41a0bdcd 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -40,7 +40,7 @@ general_settings: allow_user_auth: true alerting: ["slack"] # store_model_in_db: True // set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True" - # proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds) + proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds) enable_jwt_auth: True alerting: ["slack"] litellm_jwtauth: diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index a322c32c3..97023c64b 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -514,6 +514,7 @@ async def user_api_key_auth( team_rpm_limit=team_object.rpm_limit, team_models=team_object.models, user_role="app_owner", + user_id=user_id, ) #### ELSE #### if master_key is None: @@ -1341,8 +1342,6 @@ async def update_database( existing_token_obj = await user_api_key_cache.async_get_cache( key=hashed_token ) - if existing_token_obj is None: - return existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) if existing_user_obj is not None and isinstance(existing_user_obj, dict): existing_user_obj = LiteLLM_UserTable(**existing_user_obj) @@ -1364,7 +1363,9 @@ async def update_database( if end_user_id is not None: prisma_client.end_user_list_transactons[end_user_id] = ( response_cost - + prisma_client.user_list_transactons.get(end_user_id, 0) + + prisma_client.end_user_list_transactons.get( + end_user_id, 0 + ) ) elif custom_db_client is not None: for id in user_ids: diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index 0699137cc..407814e84 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -345,3 +345,187 @@ async def test_team_token_output(prisma_client): assert team_result.team_tpm_limit == 100 
# Closing assertions of test_team_token_output (diff hunk context); that
# test's header and earlier body are on the preceding line of this diff.
assert team_result.team_rpm_limit == 99
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]


@pytest.mark.asyncio
async def test_user_token_output(prisma_client):
    """
    Check that a team-scoped JWT carrying a user id ("sub") only authenticates
    once both the team and the user exist, and that the resulting
    UserAPIKeyAuth object carries the team limits and the user id.

    Flow:
    - 1. Initial call with the team token fails -> team doesn't exist
    - 2. Create team via admin token
    - 3. 2nd call w/ same team -> fails -> user doesn't exist
    - 4. Create user via admin token
    - 5. 3rd call w/ same team, same user -> succeeds
    - 6. Assert the UserAPIKeyAuth fields

    Args:
        prisma_client: client object that is installed as
            ``litellm.proxy.proxy_server.prisma_client`` and connected before
            any request is made (presumably supplied as a pytest fixture).
    """
    import json
    import uuid
    from datetime import timezone

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from fastapi import Request
    from starlette.datastructures import URL

    import litellm
    from litellm.proxy._types import NewTeamRequest, NewUserRequest, UserAPIKeyAuth
    from litellm.proxy.proxy_server import new_team, new_user, user_api_key_auth

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    await litellm.proxy.proxy_server.prisma_client.connect()

    # Generate a private / public key pair using RSA algorithm
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    # Private key in PEM format - used to sign the tokens below
    private_key = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # Convert the RSA public key to a JWK (JSON Web Key). The key object is
    # passed directly; serialising it to PEM and loading it back yields the
    # same key.
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(key.public_key()))

    assert isinstance(public_jwk, dict)

    # set cache - the handler reads its verification keys from here
    cache = DualCache()

    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])

    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
    # Read the user id from the "sub" claim
    jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"

    # VALID TOKEN
    ## GENERATE A TOKEN
    # Use an aware UTC datetime: calling .timestamp() on the naive value from
    # datetime.utcnow() interprets it as local time, which shifts "exp" by the
    # machine's UTC offset and can produce an already-expired token.
    expiration_time = int(
        (datetime.now(timezone.utc) + timedelta(minutes=10)).timestamp()
    )

    team_id = f"team123_{uuid.uuid4()}"
    user_id = f"user123_{uuid.uuid4()}"

    # Generate the JWT tokens
    # But before, convert the key bytes to string
    private_key_str = private_key.decode("utf-8")

    ## team token
    team_payload = {
        "sub": user_id,
        "exp": expiration_time,  # set the token to expire in 10 minutes
        "scope": "litellm_team",
        "client_id": team_id,
    }
    token = jwt.encode(team_payload, private_key_str, algorithm="RS256")

    ## admin token
    admin_payload = {
        "sub": user_id,
        "exp": expiration_time,  # set the token to expire in 10 minutes
        "scope": "litellm_proxy_admin",
    }
    admin_token = jwt.encode(admin_payload, private_key_str, algorithm="RS256")

    ## VERIFY IT WORKS
    # verify token - raises if the handler cannot validate it
    await jwt_handler.auth_jwt(token=token)

    ## RUN IT THROUGH USER API KEY AUTH
    team_bearer = "Bearer " + token
    admin_bearer = "Bearer " + admin_token

    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    ## 1. INITIAL TEAM CALL - should fail
    # use generated key to auth in
    setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
    try:
        await user_api_key_auth(request=request, api_key=team_bearer)
    except Exception:
        pass  # expected: the team has not been created yet
    else:
        # Kept outside the try so the failure never depends on how the
        # handler above treats pytest's own failure exception.
        pytest.fail("Team doesn't exist. This should fail")

    ## 2. CREATE TEAM W/ ADMIN TOKEN - should succeed
    try:
        request._url = URL(url="/team/new")
        result = await user_api_key_auth(request=request, api_key=admin_bearer)
        await new_team(
            data=NewTeamRequest(
                team_id=team_id,
                tpm_limit=100,
                rpm_limit=99,
                models=["gpt-3.5-turbo", "gpt-4"],
            ),
            user_api_key_dict=result,
        )
    except Exception as e:
        pytest.fail(f"This should not fail - {str(e)}")

    ## 3. 2nd CALL W/ TEAM TOKEN - should fail
    request._url = URL(url="/chat/completions")
    try:
        await user_api_key_auth(request=request, api_key=team_bearer)
    except Exception:
        pass  # expected: the team exists now, but the user does not
    else:
        pytest.fail("User doesn't exist. this should fail")

    ## 4. Create user
    try:
        # NOTE(review): this route looks copied from step 2; "/user/new" is
        # probably intended. Left unchanged - confirm the admin token is
        # allowed on that route before switching.
        request._url = URL(url="/team/new")
        await user_api_key_auth(request=request, api_key=admin_bearer)
        await new_user(
            data=NewUserRequest(
                user_id=user_id,
            ),
        )
    except Exception as e:
        pytest.fail(f"This should not fail - {str(e)}")

    ## 5. 3rd call w/ same team, same user -> call should succeed
    request._url = URL(url="/chat/completions")
    try:
        team_result: UserAPIKeyAuth = await user_api_key_auth(
            request=request, api_key=team_bearer
        )
    except Exception as e:
        pytest.fail(f"Team exists. This should not fail - {e}")

    ## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)

    assert team_result.team_tpm_limit == 100
    assert team_result.team_rpm_limit == 99
    assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
    assert team_result.user_id == user_id