From 2b79e44fc5ad9cd1b940362bb485a02197162fba Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 12 Aug 2024 12:11:23 -0700 Subject: [PATCH] fix internal user tests to pass --- .../internal_user_endpoints.py | 2 + litellm/tests/test_key_generate_prisma.py | 132 ++++++++++++++++-- 2 files changed, 119 insertions(+), 15 deletions(-) diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index bced1851ee..8e2358c992 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -47,6 +47,7 @@ router = APIRouter() @management_endpoint_wrapper async def new_user( data: NewUserRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ Use this to create a new INTERNAL user with a budget. @@ -461,6 +462,7 @@ async def user_info( @management_endpoint_wrapper async def user_update( data: UpdateUserRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ Example curl diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 344c9691b1..a757476a55 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -171,6 +171,11 @@ async def test_new_user_response(prisma_client): models=["azure-gpt-3.5"], team_id=_team_id, tpm_limit=20, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), ) ) print(_response) @@ -236,7 +241,14 @@ def test_generate_and_call_with_valid_key(prisma_client, api_route): from litellm.proxy.proxy_server import user_api_key_cache request = NewUserRequest(user_role=LitellmUserRoles.INTERNAL_USER) - key = await new_user(request) + key = await new_user( + request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) user_id = key.user_id @@ -312,7 +324,14 @@ def test_call_with_invalid_model(prisma_client): async def test(): await litellm.proxy.proxy_server.prisma_client.connect() request = NewUserRequest(models=["mistral"]) - key = await new_user(request) + key = await new_user( + data=request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -353,7 +372,14 @@ def test_call_with_valid_model(prisma_client): async def test(): await litellm.proxy.proxy_server.prisma_client.connect() request = NewUserRequest(models=["mistral"]) - key = await new_user(request) + key = await new_user( + request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -445,7 +471,14 @@ def test_call_with_user_over_budget(prisma_client): async def test(): await litellm.proxy.proxy_server.prisma_client.connect() request = NewUserRequest(max_budget=0.00001) - key = await new_user(request) + key = await new_user( + data=request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -640,7 +673,14 @@ def test_call_with_proxy_over_budget(prisma_client): async def test(): await litellm.proxy.proxy_server.prisma_client.connect() request = NewUserRequest() - key = await new_user(request) + key = await new_user( + data=request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -724,7 +764,14 @@ def test_call_with_user_over_budget_stream(prisma_client): async def test(): await litellm.proxy.proxy_server.prisma_client.connect() request = NewUserRequest(max_budget=0.00001) - key = await new_user(request) + key = await new_user( + data=request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -822,9 +869,15 @@ def test_call_with_proxy_over_budget_stream(prisma_client): # request = NewUserRequest( # max_budget=0.00001, user_id=litellm_proxy_budget_name # ) - # await new_user(request) request = NewUserRequest() - key = await new_user(request) + key = await new_user( + data=request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -900,7 +953,14 @@ def test_generate_and_call_with_valid_key_never_expires(prisma_client): async def test(): await litellm.proxy.proxy_server.prisma_client.connect() request = NewUserRequest(duration=None) - key = await new_user(request) + key = await new_user( + data=request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -930,7 +990,14 @@ def test_generate_and_call_with_expired_key(prisma_client): async def test(): await litellm.proxy.proxy_server.prisma_client.connect() request = NewUserRequest(duration="0s") - key = await new_user(request) + key = await new_user( + data=request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -969,7 +1036,14 @@ def test_delete_key(prisma_client): from litellm.proxy.proxy_server import user_api_key_cache request = NewUserRequest() - key = await new_user(request) + key = await new_user( + data=request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -1018,7 +1092,14 @@ def test_delete_key_auth(prisma_client): from litellm.proxy.proxy_server import user_api_key_cache request = NewUserRequest() - key = await new_user(request) + key = await new_user( + data=request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -1081,7 +1162,14 @@ def test_generate_and_call_key_info(prisma_client): request = NewUserRequest( metadata={"team": "litellm-team3", "project": "litellm-project3"} ) - key = await new_user(request) + key = await new_user( + data=request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -1166,7 +1254,14 @@ def test_generate_and_update_key(prisma_client): team_id=_team_1, ) - key = await new_user(request) + key = await new_user( + data=request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -2466,7 +2561,14 @@ async def test_enforced_params(prisma_client): await litellm.proxy.proxy_server.prisma_client.connect() request = NewUserRequest() - key = await new_user(request) + key = await new_user( + data=request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key