diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 4cf51d3de..cb04f32a5 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -121,11 +121,24 @@ class GenerateKeyRequest(LiteLLMBase): user_id: Optional[str] = None max_parallel_requests: Optional[int] = None +class UpdateKeyRequest(LiteLLMBase): + key: str + duration: Optional[str] = None + models: Optional[list] = None + aliases: Optional[dict] = None + config: Optional[dict] = None + spend: Optional[float] = None + user_id: Optional[str] = None + max_parallel_requests: Optional[int] = None + class GenerateKeyResponse(LiteLLMBase): key: str expires: datetime user_id: str + + + class _DeleteKeyObject(LiteLLMBase): key: str diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 7f269a835..af7bd3b4a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -252,15 +252,15 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap if api_key is None: # only require api key if master key is set raise Exception(f"No api key passed in.") - route = request.url.path + route: str = request.url.path # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead is_master_key_valid = secrets.compare_digest(api_key, master_key) if is_master_key_valid: return UserAPIKeyAuth(api_key=master_key) - if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid: - raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys") + if route.startswith("/key/") and not is_master_key_valid: + raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys") if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error raise Exception("No connected db.") @@ -676,6 +676,8 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) return {"token": token, "expires": new_verification_token.expires, "user_id": user_id} + + async def delete_verification_token(tokens: List): global prisma_client try: @@ -1140,6 +1142,30 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat response = await generate_key_helper_fn(**data_json) return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"]) +@router.post("/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)]) +async def update_key_fn(request: Request, data: UpdateKeyRequest): + """ + Update an existing key + """ + global prisma_client + try: + data_json: dict = data.json() + key = data_json.pop("key") + # get the row from db + if prisma_client is None: + raise Exception("Not connected to DB!") + + non_default_values = {k: v for k, v in data_json.items() if v is not None} + print(f"non_default_values: {non_default_values}") + response = await prisma_client.update_data(token=key, data={**non_default_values, "token": key}) + return {"key": key, **non_default_values} + # update based on remaining passed in values + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": str(e)}, + ) + @router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)]) async def delete_key_fn(request: Request, data: DeleteKeyRequest): try: diff --git a/litellm/tests/test_proxy_server_keys.py b/litellm/tests/test_proxy_server_keys.py index ace3c5527..14b239ae1 100644 --- a/litellm/tests/test_proxy_server_keys.py +++ b/litellm/tests/test_proxy_server_keys.py @@ -71,6 +71,38 @@ def test_add_new_key(client): except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}") + +def test_update_new_key(client): + try: + # Your test data + test_data = { + "models": ["gpt-3.5-turbo", "gpt-4", "claude-2", "azure-model"], + "aliases": {"mistral-7b": "gpt-3.5-turbo"}, + "duration": "20m" + } + print("testing proxy server") + # Your bearer token + token = os.getenv("PROXY_MASTER_KEY") + + headers = { + "Authorization": f"Bearer {token}" + } + response = client.post("/key/generate", json=test_data, headers=headers) + print(f"response: {response.text}") + assert response.status_code == 200 + result = response.json() + assert result["key"].startswith("sk-") + def _post_data(): + json_data = {'models': ['bedrock-models'], "key": result["key"]} + response = client.post("/key/update", json=json_data, headers=headers) + print(f"response text: {response.text}") + assert response.status_code == 200 + return response + _post_data() + print(f"Received response: {result}") + except Exception as e: + pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}") + # # Run the test - only runs via pytest