fix(key_management_endpoints.py): enforce user_id / team_id checks on key generate (#7437)

* fix(key_management_endpoints.py): enforce user_id / team_id checks on key generate

Fixes https://github.com/BerriAI/litellm/issues/7336

* test: fix tests
This commit is contained in:
Krish Dholakia 2024-12-27 10:15:48 -08:00 committed by GitHub
parent 0774fc71ce
commit 08145fa89e
2 changed files with 170 additions and 15 deletions

View file

@ -61,6 +61,38 @@ def _get_user_in_team(
return None return None
def _is_allowed_to_create_key(
user_api_key_dict: UserAPIKeyAuth, user_id: Optional[str], team_id: Optional[str]
) -> bool:
"""
Assert user only creates keys for themselves
Relevant issue: https://github.com/BerriAI/litellm/issues/7336
"""
## BASE CASE - PROXY ADMIN
if (
user_api_key_dict.user_role is not None
and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
):
return True
if user_id is not None:
assert (
user_id == user_api_key_dict.user_id
), "User can only create keys for themselves. Got user_id={}, Your ID={}".format(
user_id, user_api_key_dict.user_id
)
if team_id is not None:
assert (
user_api_key_dict.team_id == team_id
), "User can only create keys for their own team. Got={}, Your Team ID={}".format(
team_id, user_api_key_dict.team_id
)
return True
def _team_key_generation_team_member_check( def _team_key_generation_team_member_check(
team_table: LiteLLM_TeamTableCachedObj, team_table: LiteLLM_TeamTableCachedObj,
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
@ -315,6 +347,24 @@ async def generate_key_fn( # noqa: PLR0915
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,
data=data, data=data,
) )
try:
_is_allowed_to_create_key(
user_api_key_dict=user_api_key_dict,
user_id=data.user_id,
team_id=data.team_id,
)
except AssertionError as e:
raise HTTPException(
status_code=403,
detail=str(e),
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=str(e),
)
# check if user set default key/generate params on config.yaml # check if user set default key/generate params on config.yaml
if litellm.default_key_generate_params is not None: if litellm.default_key_generate_params is not None:
for elem in data: for elem in data:

View file

@ -439,7 +439,14 @@ async def test_call_with_valid_model_using_all_models(prisma_client):
request = GenerateKeyRequest( request = GenerateKeyRequest(
models=["all-team-models"], team_id=created_team_id models=["all-team-models"], team_id=created_team_id
) )
key = await generate_key_fn(data=request) key = await generate_key_fn(
data=request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key) print(key)
generated_key = key.key generated_key = key.key
@ -1430,7 +1437,14 @@ def test_key_generate_with_custom_auth(prisma_client):
try: try:
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
request = GenerateKeyRequest() request = GenerateKeyRequest()
key = await generate_key_fn(request) key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
pytest.fail(f"Expected an exception. Got {key}") pytest.fail(f"Expected an exception. Got {key}")
except Exception as e: except Exception as e:
# this should fail # this should fail
@ -1446,7 +1460,14 @@ def test_key_generate_with_custom_auth(prisma_client):
team_id="litellm-core-infra@gmail.com", team_id="litellm-core-infra@gmail.com",
) )
key = await generate_key_fn(request_2) key = await generate_key_fn(
request_2,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key) print(key)
generated_key = key.key generated_key = key.key
@ -1466,7 +1487,14 @@ def test_call_with_key_over_budget(prisma_client):
async def test(): async def test():
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
request = GenerateKeyRequest(max_budget=0.00001) request = GenerateKeyRequest(max_budget=0.00001)
key = await generate_key_fn(request) key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key) print(key)
generated_key = key.key generated_key = key.key
@ -1578,7 +1606,14 @@ def test_call_with_key_over_budget_no_cache(prisma_client):
async def test(): async def test():
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
request = GenerateKeyRequest(max_budget=0.00001) request = GenerateKeyRequest(max_budget=0.00001)
key = await generate_key_fn(request) key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key) print(key)
generated_key = key.key generated_key = key.key
@ -1731,7 +1766,14 @@ async def test_call_with_key_over_model_budget(
max_budget=100000, # the key itself has a very high budget max_budget=100000, # the key itself has a very high budget
model_max_budget=model_max_budget, model_max_budget=model_max_budget,
) )
key = await generate_key_fn(request) key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key) print(key)
generated_key = key.key generated_key = key.key
@ -1798,7 +1840,14 @@ async def test_call_with_key_never_over_budget(prisma_client):
try: try:
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
request = GenerateKeyRequest(max_budget=None) request = GenerateKeyRequest(max_budget=None)
key = await generate_key_fn(request) key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key) print(key)
generated_key = key.key generated_key = key.key
@ -1882,7 +1931,14 @@ async def test_call_with_key_over_budget_stream(prisma_client):
try: try:
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
request = GenerateKeyRequest(max_budget=0.00001) request = GenerateKeyRequest(max_budget=0.00001)
key = await generate_key_fn(request) key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key) print(key)
generated_key = key.key generated_key = key.key
@ -2006,7 +2062,14 @@ async def test_key_name_null(prisma_client):
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
try: try:
request = GenerateKeyRequest() request = GenerateKeyRequest()
key = await generate_key_fn(request) key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print("generated key=", key) print("generated key=", key)
generated_key = key.key generated_key = key.key
result = await info_key_fn( result = await info_key_fn(
@ -2035,7 +2098,14 @@ async def test_key_name_set(prisma_client):
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
try: try:
request = GenerateKeyRequest() request = GenerateKeyRequest()
key = await generate_key_fn(request) key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
generated_key = key.key generated_key = key.key
result = await info_key_fn( result = await info_key_fn(
key=generated_key, key=generated_key,
@ -2062,7 +2132,14 @@ async def test_default_key_params(prisma_client):
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
try: try:
request = GenerateKeyRequest() request = GenerateKeyRequest()
key = await generate_key_fn(request) key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
generated_key = key.key generated_key = key.key
result = await info_key_fn( result = await info_key_fn(
key=generated_key, key=generated_key,
@ -2093,7 +2170,14 @@ async def test_upperbound_key_param_larger_budget(prisma_client):
max_budget=200000, max_budget=200000,
budget_duration="30d", budget_duration="30d",
) )
key = await generate_key_fn(request) key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
# print(result) # print(result)
except Exception as e: except Exception as e:
assert e.code == str(400) assert e.code == str(400)
@ -2112,7 +2196,14 @@ async def test_upperbound_key_param_larger_duration(prisma_client):
max_budget=10, max_budget=10,
duration="30d", duration="30d",
) )
key = await generate_key_fn(request) key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
pytest.fail("Expected this to fail but it passed") pytest.fail("Expected this to fail but it passed")
# print(result) # print(result)
except Exception as e: except Exception as e:
@ -2131,7 +2222,14 @@ async def test_upperbound_key_param_none_duration(prisma_client):
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
try: try:
request = GenerateKeyRequest() request = GenerateKeyRequest()
key = await generate_key_fn(request) key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key) print(key)
# print(result) # print(result)
@ -2362,7 +2460,14 @@ async def test_proxy_load_test_db(prisma_client):
start_time = time.time() start_time = time.time()
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
request = GenerateKeyRequest(max_budget=0.00001) request = GenerateKeyRequest(max_budget=0.00001)
key = await generate_key_fn(request) key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key) print(key)
generated_key = key.key generated_key = key.key