fix(key_management_endpoints.py): enforce user_id / team_id checks on key generate (#7437)

* fix(key_management_endpoints.py): enforce user_id / team_id checks on key generate

Fixes https://github.com/BerriAI/litellm/issues/7336

* test: fix tests
This commit is contained in:
Krish Dholakia 2024-12-27 10:15:48 -08:00 committed by GitHub
parent 0774fc71ce
commit 08145fa89e
2 changed files with 170 additions and 15 deletions

View file

@ -439,7 +439,14 @@ async def test_call_with_valid_model_using_all_models(prisma_client):
request = GenerateKeyRequest(
models=["all-team-models"], team_id=created_team_id
)
key = await generate_key_fn(data=request)
key = await generate_key_fn(
data=request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key)
generated_key = key.key
@ -1430,7 +1437,14 @@ def test_key_generate_with_custom_auth(prisma_client):
try:
await litellm.proxy.proxy_server.prisma_client.connect()
request = GenerateKeyRequest()
key = await generate_key_fn(request)
key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
pytest.fail(f"Expected an exception. Got {key}")
except Exception as e:
# this should fail
@ -1446,7 +1460,14 @@ def test_key_generate_with_custom_auth(prisma_client):
team_id="litellm-core-infra@gmail.com",
)
key = await generate_key_fn(request_2)
key = await generate_key_fn(
request_2,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key)
generated_key = key.key
@ -1466,7 +1487,14 @@ def test_call_with_key_over_budget(prisma_client):
async def test():
await litellm.proxy.proxy_server.prisma_client.connect()
request = GenerateKeyRequest(max_budget=0.00001)
key = await generate_key_fn(request)
key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key)
generated_key = key.key
@ -1578,7 +1606,14 @@ def test_call_with_key_over_budget_no_cache(prisma_client):
async def test():
await litellm.proxy.proxy_server.prisma_client.connect()
request = GenerateKeyRequest(max_budget=0.00001)
key = await generate_key_fn(request)
key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key)
generated_key = key.key
@ -1731,7 +1766,14 @@ async def test_call_with_key_over_model_budget(
max_budget=100000, # the key itself has a very high budget
model_max_budget=model_max_budget,
)
key = await generate_key_fn(request)
key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key)
generated_key = key.key
@ -1798,7 +1840,14 @@ async def test_call_with_key_never_over_budget(prisma_client):
try:
await litellm.proxy.proxy_server.prisma_client.connect()
request = GenerateKeyRequest(max_budget=None)
key = await generate_key_fn(request)
key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key)
generated_key = key.key
@ -1882,7 +1931,14 @@ async def test_call_with_key_over_budget_stream(prisma_client):
try:
await litellm.proxy.proxy_server.prisma_client.connect()
request = GenerateKeyRequest(max_budget=0.00001)
key = await generate_key_fn(request)
key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key)
generated_key = key.key
@ -2006,7 +2062,14 @@ async def test_key_name_null(prisma_client):
await litellm.proxy.proxy_server.prisma_client.connect()
try:
request = GenerateKeyRequest()
key = await generate_key_fn(request)
key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print("generated key=", key)
generated_key = key.key
result = await info_key_fn(
@ -2035,7 +2098,14 @@ async def test_key_name_set(prisma_client):
await litellm.proxy.proxy_server.prisma_client.connect()
try:
request = GenerateKeyRequest()
key = await generate_key_fn(request)
key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
generated_key = key.key
result = await info_key_fn(
key=generated_key,
@ -2062,7 +2132,14 @@ async def test_default_key_params(prisma_client):
await litellm.proxy.proxy_server.prisma_client.connect()
try:
request = GenerateKeyRequest()
key = await generate_key_fn(request)
key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
generated_key = key.key
result = await info_key_fn(
key=generated_key,
@ -2093,7 +2170,14 @@ async def test_upperbound_key_param_larger_budget(prisma_client):
max_budget=200000,
budget_duration="30d",
)
key = await generate_key_fn(request)
key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
# print(result)
except Exception as e:
assert e.code == str(400)
@ -2112,7 +2196,14 @@ async def test_upperbound_key_param_larger_duration(prisma_client):
max_budget=10,
duration="30d",
)
key = await generate_key_fn(request)
key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
pytest.fail("Expected this to fail but it passed")
# print(result)
except Exception as e:
@ -2131,7 +2222,14 @@ async def test_upperbound_key_param_none_duration(prisma_client):
await litellm.proxy.proxy_server.prisma_client.connect()
try:
request = GenerateKeyRequest()
key = await generate_key_fn(request)
key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key)
# print(result)
@ -2362,7 +2460,14 @@ async def test_proxy_load_test_db(prisma_client):
start_time = time.time()
await litellm.proxy.proxy_server.prisma_client.connect()
request = GenerateKeyRequest(max_budget=0.00001)
key = await generate_key_fn(request)
key = await generate_key_fn(
request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key)
generated_key = key.key