From 08145fa89ed33c6ac17e18e950d5b632b0ee3d33 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Fri, 27 Dec 2024 10:15:48 -0800 Subject: [PATCH] fix(key_management_endpoints.py): enforce user_id / team_id checks on key generate (#7437) * fix(key_management_endpoints.py): enforce user_id / team_id checks on key generate Fixes https://github.com/BerriAI/litellm/issues/7336 * test: fix tests --- .../key_management_endpoints.py | 50 +++++++ .../test_key_generate_prisma.py | 135 ++++++++++++++++-- 2 files changed, 170 insertions(+), 15 deletions(-) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index caf48e4342..526cc05c9a 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -61,6 +61,38 @@ def _get_user_in_team( return None +def _is_allowed_to_create_key( + user_api_key_dict: UserAPIKeyAuth, user_id: Optional[str], team_id: Optional[str] +) -> bool: + """ + Assert user only creates keys for themselves + + Relevant issue: https://github.com/BerriAI/litellm/issues/7336 + """ + ## BASE CASE - PROXY ADMIN + if ( + user_api_key_dict.user_role is not None + and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value + ): + return True + + if user_id is not None: + assert ( + user_id == user_api_key_dict.user_id + ), "User can only create keys for themselves. Got user_id={}, Your ID={}".format( + user_id, user_api_key_dict.user_id + ) + + if team_id is not None: + assert ( + user_api_key_dict.team_id == team_id + ), "User can only create keys for their own team. Got={}, Your Team ID={}".format( + team_id, user_api_key_dict.team_id + ) + + return True + + def _team_key_generation_team_member_check( team_table: LiteLLM_TeamTableCachedObj, user_api_key_dict: UserAPIKeyAuth, @@ -315,6 +347,24 @@ async def generate_key_fn( # noqa: PLR0915 user_api_key_dict=user_api_key_dict, data=data, ) + + try: + _is_allowed_to_create_key( + user_api_key_dict=user_api_key_dict, + user_id=data.user_id, + team_id=data.team_id, + ) + except AssertionError as e: + raise HTTPException( + status_code=403, + detail=str(e), + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e), + ) + # check if user set default key/generate params on config.yaml if litellm.default_key_generate_params is not None: for elem in data: diff --git a/tests/proxy_unit_tests/test_key_generate_prisma.py b/tests/proxy_unit_tests/test_key_generate_prisma.py index d2b0c765d5..3d2ef87573 100644 --- a/tests/proxy_unit_tests/test_key_generate_prisma.py +++ b/tests/proxy_unit_tests/test_key_generate_prisma.py @@ -439,7 +439,14 @@ async def test_call_with_valid_model_using_all_models(prisma_client): request = GenerateKeyRequest( models=["all-team-models"], team_id=created_team_id ) - key = await generate_key_fn(data=request) + key = await generate_key_fn( + data=request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -1430,7 +1437,14 @@ def test_key_generate_with_custom_auth(prisma_client): try: await litellm.proxy.proxy_server.prisma_client.connect() request = GenerateKeyRequest() - key = await generate_key_fn(request) + key = await generate_key_fn( + request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) pytest.fail(f"Expected an exception. Got {key}") except Exception as e: # this should fail @@ -1446,7 +1460,14 @@ def test_key_generate_with_custom_auth(prisma_client): team_id="litellm-core-infra@gmail.com", ) - key = await generate_key_fn(request_2) + key = await generate_key_fn( + request_2, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -1466,7 +1487,14 @@ def test_call_with_key_over_budget(prisma_client): async def test(): await litellm.proxy.proxy_server.prisma_client.connect() request = GenerateKeyRequest(max_budget=0.00001) - key = await generate_key_fn(request) + key = await generate_key_fn( + request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -1578,7 +1606,14 @@ def test_call_with_key_over_budget_no_cache(prisma_client): async def test(): await litellm.proxy.proxy_server.prisma_client.connect() request = GenerateKeyRequest(max_budget=0.00001) - key = await generate_key_fn(request) + key = await generate_key_fn( + request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -1731,7 +1766,14 @@ async def test_call_with_key_over_model_budget( max_budget=100000, # the key itself has a very high budget model_max_budget=model_max_budget, ) - key = await generate_key_fn(request) + key = await generate_key_fn( + request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -1798,7 +1840,14 @@ async def test_call_with_key_never_over_budget(prisma_client): try: await litellm.proxy.proxy_server.prisma_client.connect() request = GenerateKeyRequest(max_budget=None) - key = await generate_key_fn(request) + key = await generate_key_fn( + request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -1882,7 +1931,14 @@ async def test_call_with_key_over_budget_stream(prisma_client): try: await litellm.proxy.proxy_server.prisma_client.connect() request = GenerateKeyRequest(max_budget=0.00001) - key = await generate_key_fn(request) + key = await generate_key_fn( + request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key @@ -2006,7 +2062,14 @@ async def test_key_name_null(prisma_client): await litellm.proxy.proxy_server.prisma_client.connect() try: request = GenerateKeyRequest() - key = await generate_key_fn(request) + key = await generate_key_fn( + request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print("generated key=", key) generated_key = key.key result = await info_key_fn( @@ -2035,7 +2098,14 @@ async def test_key_name_set(prisma_client): await litellm.proxy.proxy_server.prisma_client.connect() try: request = GenerateKeyRequest() - key = await generate_key_fn(request) + key = await generate_key_fn( + request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) generated_key = key.key result = await info_key_fn( key=generated_key, @@ -2062,7 +2132,14 @@ async def test_default_key_params(prisma_client): await litellm.proxy.proxy_server.prisma_client.connect() try: request = GenerateKeyRequest() - key = await generate_key_fn(request) + key = await generate_key_fn( + request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) generated_key = key.key result = await info_key_fn( key=generated_key, @@ -2093,7 +2170,14 @@ async def test_upperbound_key_param_larger_budget(prisma_client): max_budget=200000, budget_duration="30d", ) - key = await generate_key_fn(request) + key = await generate_key_fn( + request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) # print(result) except Exception as e: assert e.code == str(400) @@ -2112,7 +2196,14 @@ async def test_upperbound_key_param_larger_duration(prisma_client): max_budget=10, duration="30d", ) - key = await generate_key_fn(request) + key = await generate_key_fn( + request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) pytest.fail("Expected this to fail but it passed") # print(result) except Exception as e: @@ -2131,7 +2222,14 @@ async def test_upperbound_key_param_none_duration(prisma_client): await litellm.proxy.proxy_server.prisma_client.connect() try: request = GenerateKeyRequest() - key = await generate_key_fn(request) + key = await generate_key_fn( + request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) # print(result) @@ -2362,7 +2460,14 @@ async def test_proxy_load_test_db(prisma_client): start_time = time.time() await litellm.proxy.proxy_server.prisma_client.connect() request = GenerateKeyRequest(max_budget=0.00001) - key = await generate_key_fn(request) + key = await generate_key_fn( + request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) print(key) generated_key = key.key