From 70716b3373b8637e5a6f3cc1af1e081efb7e5b9d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 16 Apr 2024 13:08:34 -0700 Subject: [PATCH] fix(proxy_server.py): fix key create logic + add unit tests --- litellm/proxy/proxy_server.py | 80 ++++++++++++++++++----------------- tests/test_keys.py | 35 ++++++++++++++- 2 files changed, 75 insertions(+), 40 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 0f01f4e76..6125a3169 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1010,48 +1010,50 @@ async def user_api_key_auth( db=custom_db_client, ) ) - if route in LiteLLMRoutes.info_routes.value and ( - not _is_user_proxy_admin(user_id_information) - ): # check if user allowed to call an info route - if route == "/key/info": - # check if user can access this route - query_params = request.query_params - key = query_params.get("key") - if ( - key is not None - and prisma_client.hash_token(token=key) != api_key - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="user not allowed to access this key's info", - ) - elif route == "/user/info": - # check if user can access this route - query_params = request.query_params - user_id = query_params.get("user_id") - verbose_proxy_logger.debug( - f"user_id: {user_id} & valid_token.user_id: {valid_token.user_id}" - ) - if user_id != valid_token.user_id: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="key not allowed to access this user's info", - ) - elif route == "/model/info": - # /model/info just shows models user has access to - pass - elif route == "/team/info": - # check if key can access this team's info - query_params = request.query_params - team_id = query_params.get("team_id") - if team_id != valid_token.team_id: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="key not allowed to access this team's info", + + if not _is_user_proxy_admin(user_id_information): # if non-admin + if ( + route in LiteLLMRoutes.info_routes.value + ): # check if user allowed to call an info route + if route == "/key/info": + # check if user can access this route + query_params = request.query_params + key = query_params.get("key") + if ( + key is not None + and prisma_client.hash_token(token=key) != api_key + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="user not allowed to access this key's info", + ) + elif route == "/user/info": + # check if user can access this route + query_params = request.query_params + user_id = query_params.get("user_id") + verbose_proxy_logger.debug( + f"user_id: {user_id} & valid_token.user_id: {valid_token.user_id}" ) + if user_id != valid_token.user_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="key not allowed to access this user's info", + ) + elif route == "/model/info": + # /model/info just shows models user has access to + pass + elif route == "/team/info": + # check if key can access this team's info + query_params = request.query_params + team_id = query_params.get("team_id") + if team_id != valid_token.team_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="key not allowed to access this team's info", + ) else: raise Exception( - f"Only master key can be used to generate, delete, update info for new keys/users." + f"Only master key can be used to generate, delete, update info for new keys/users/teams." ) # check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions diff --git a/tests/test_keys.py b/tests/test_keys.py index b99e93b21..39787eb97 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -44,9 +44,13 @@ async def generate_key( models=["azure-models", "gpt-4", "dall-e-3"], max_parallel_requests: Optional[int] = None, user_id: Optional[str] = None, + calling_key="sk-1234", ): url = "http://0.0.0.0:4000/key/generate" - headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + headers = { + "Authorization": f"Bearer {calling_key}", + "Content-Type": "application/json", + } data = { "models": models, "aliases": {"mistral-7b": "gpt-3.5-turbo"}, @@ -80,6 +84,35 @@ async def test_key_gen(): await asyncio.gather(*tasks) +@pytest.mark.asyncio +async def test_key_gen_bad_key(): + """ + Test if you can create a key with a non-admin key, even with UI setup + """ + async with aiohttp.ClientSession() as session: + ## LOGIN TO UI + form_data = {"username": "admin", "password": "sk-1234"} + async with session.post( + "http://0.0.0.0:4000/login", data=form_data + ) as response: + assert ( + response.status == 200 + ) # Assuming the endpoint returns a 500 status code for error handling + text = await response.text() + print(text) + ## create user key with admin key -> expect to work + key_data = await generate_key(session=session, i=0, user_id="user-1234") + key = key_data["key"] + ## create new key with user key -> expect to fail + try: + await generate_key( + session=session, i=0, user_id="user-1234", calling_key=key + ) + pytest.fail("Expected to fail") + except Exception as e: + pass + + async def update_key(session, get_key): """ Make sure only models user has access to are returned