From 8e1157fc926b5619905d2608eb49ce8df3d32616 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 24 Jan 2024 21:08:09 -0800 Subject: [PATCH] test(test_keys.py): reset proxy spend --- litellm/proxy/_types.py | 32 +++++++++++++++++++------------- litellm/proxy/proxy_server.py | 35 +++++++++++++++++++++++++++++++++-- litellm/proxy/utils.py | 3 ++- tests/test_keys.py | 23 +++++++++++++++++++++++ 4 files changed, 77 insertions(+), 16 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 86400b7e2..e64fc729e 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -122,11 +122,12 @@ class ModelParams(LiteLLMBase): return values -class GenerateKeyRequest(LiteLLMBase): - duration: Optional[str] = "1h" +class GenerateRequestBase(LiteLLMBase): + """ + Overlapping schema between key and user generate/update requests + """ + models: Optional[list] = [] - aliases: Optional[dict] = {} - config: Optional[dict] = {} spend: Optional[float] = 0 max_budget: Optional[float] = None user_id: Optional[str] = None @@ -138,21 +139,18 @@ class GenerateKeyRequest(LiteLLMBase): budget_duration: Optional[str] = None -class UpdateKeyRequest(LiteLLMBase): +class GenerateKeyRequest(GenerateRequestBase): + duration: Optional[str] = "1h" + aliases: Optional[dict] = {} + + +class UpdateKeyRequest(GenerateKeyRequest): # Note: the defaults of all Params here MUST BE NONE # else they will get overwritten key: str duration: Optional[str] = None - models: Optional[list] = None - aliases: Optional[dict] = None - config: Optional[dict] = None spend: Optional[float] = None - max_budget: Optional[float] = None - user_id: Optional[str] = None - max_parallel_requests: Optional[int] = None metadata: Optional[dict] = None - tpm_limit: Optional[int] = None - rpm_limit: Optional[int] = None class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth @@ -192,6 +190,14 @@ class NewUserResponse(GenerateKeyResponse): max_budget: Optional[float] = None +class UpdateUserRequest(GenerateRequestBase): + # Note: the defaults of all Params here MUST BE NONE + # else they will get overwritten + user_id: str + spend: Optional[float] = None + metadata: Optional[dict] = None + + class KeyManagementSystem(enum.Enum): GOOGLE_KMS = "google_kms" AZURE_KEY_VAULT = "azure_key_vault" diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index d5b91a392..027e29f94 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2799,11 +2799,42 @@ async def user_info( @router.post( "/user/update", tags=["user management"], dependencies=[Depends(user_api_key_auth)] ) -async def user_update(request: Request): +async def user_update(data: UpdateUserRequest): """ [TODO]: Use this to update user budget """ - pass + global prisma_client + try: + data_json: dict = data.json() + # get the row from db + if prisma_client is None: + raise Exception("Not connected to DB!") + + non_default_values = {k: v for k, v in data_json.items() if v is not None} + response = await prisma_client.update_data( + user_id=data_json["user_id"], + data=non_default_values, + update_key_values=non_default_values, + ) + return {"user_id": data_json["user_id"], **non_default_values} + # update based on remaining passed in values + except Exception as e: + traceback.print_exc() + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, + ) #### MODEL MANAGEMENT #### diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 0debcb235..789125b82 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -655,13 +655,14 @@ class PrismaClient: user_id is not None or (table_name is not None and table_name == "user") and query_type == "update" - and update_key_values is not None ): """ If data['spend'] + data['user'], update the user table with spend info as well """ if user_id is None: user_id = db_data["user_id"] + if update_key_values is None: + update_key_values = db_data update_user_row = await self.db.litellm_usertable.upsert( where={"user_id": user_id}, # type: ignore data={ diff --git a/tests/test_keys.py b/tests/test_keys.py index e9ec39256..f05204c03 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -67,6 +67,28 @@ async def update_key(session, get_key): return await response.json() +async def update_proxy_budget(session): + """ + Make sure only models user has access to are returned + """ + url = "http://0.0.0.0:4000/user/update" + headers = { + "Authorization": f"Bearer sk-1234", + "Content-Type": "application/json", + } + data = {"user_id": "litellm-proxy-budget", "spend": 0} + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + return await response.json() + + async def chat_completion(session, key, model="gpt-4"): url = "http://0.0.0.0:4000/chat/completions" headers = { @@ -135,6 +157,7 @@ async def test_key_update(): session=session, get_key=key, ) + await update_proxy_budget(session=session) # resets proxy spend await chat_completion(session=session, key=key)