mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
fix(key_management_endpoints.py): enforce user_id / team_id checks on key generate (#7437)
* fix(key_management_endpoints.py): enforce user_id / team_id checks on key generate Fixes https://github.com/BerriAI/litellm/issues/7336 * test: fix tests
This commit is contained in:
parent
0774fc71ce
commit
08145fa89e
2 changed files with 170 additions and 15 deletions
|
@ -61,6 +61,38 @@ def _get_user_in_team(
|
|||
return None
|
||||
|
||||
|
||||
def _is_allowed_to_create_key(
|
||||
user_api_key_dict: UserAPIKeyAuth, user_id: Optional[str], team_id: Optional[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Assert user only creates keys for themselves
|
||||
|
||||
Relevant issue: https://github.com/BerriAI/litellm/issues/7336
|
||||
"""
|
||||
## BASE CASE - PROXY ADMIN
|
||||
if (
|
||||
user_api_key_dict.user_role is not None
|
||||
and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||
):
|
||||
return True
|
||||
|
||||
if user_id is not None:
|
||||
assert (
|
||||
user_id == user_api_key_dict.user_id
|
||||
), "User can only create keys for themselves. Got user_id={}, Your ID={}".format(
|
||||
user_id, user_api_key_dict.user_id
|
||||
)
|
||||
|
||||
if team_id is not None:
|
||||
assert (
|
||||
user_api_key_dict.team_id == team_id
|
||||
), "User can only create keys for their own team. Got={}, Your Team ID={}".format(
|
||||
team_id, user_api_key_dict.team_id
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _team_key_generation_team_member_check(
|
||||
team_table: LiteLLM_TeamTableCachedObj,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
|
@ -315,6 +347,24 @@ async def generate_key_fn( # noqa: PLR0915
|
|||
user_api_key_dict=user_api_key_dict,
|
||||
data=data,
|
||||
)
|
||||
|
||||
try:
|
||||
_is_allowed_to_create_key(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
user_id=data.user_id,
|
||||
team_id=data.team_id,
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# check if user set default key/generate params on config.yaml
|
||||
if litellm.default_key_generate_params is not None:
|
||||
for elem in data:
|
||||
|
|
|
@ -439,7 +439,14 @@ async def test_call_with_valid_model_using_all_models(prisma_client):
|
|||
request = GenerateKeyRequest(
|
||||
models=["all-team-models"], team_id=created_team_id
|
||||
)
|
||||
key = await generate_key_fn(data=request)
|
||||
key = await generate_key_fn(
|
||||
data=request,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
print(key)
|
||||
|
||||
generated_key = key.key
|
||||
|
@ -1430,7 +1437,14 @@ def test_key_generate_with_custom_auth(prisma_client):
|
|||
try:
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
request = GenerateKeyRequest()
|
||||
key = await generate_key_fn(request)
|
||||
key = await generate_key_fn(
|
||||
request,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
pytest.fail(f"Expected an exception. Got {key}")
|
||||
except Exception as e:
|
||||
# this should fail
|
||||
|
@ -1446,7 +1460,14 @@ def test_key_generate_with_custom_auth(prisma_client):
|
|||
team_id="litellm-core-infra@gmail.com",
|
||||
)
|
||||
|
||||
key = await generate_key_fn(request_2)
|
||||
key = await generate_key_fn(
|
||||
request_2,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
print(key)
|
||||
generated_key = key.key
|
||||
|
||||
|
@ -1466,7 +1487,14 @@ def test_call_with_key_over_budget(prisma_client):
|
|||
async def test():
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
request = GenerateKeyRequest(max_budget=0.00001)
|
||||
key = await generate_key_fn(request)
|
||||
key = await generate_key_fn(
|
||||
request,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
print(key)
|
||||
|
||||
generated_key = key.key
|
||||
|
@ -1578,7 +1606,14 @@ def test_call_with_key_over_budget_no_cache(prisma_client):
|
|||
async def test():
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
request = GenerateKeyRequest(max_budget=0.00001)
|
||||
key = await generate_key_fn(request)
|
||||
key = await generate_key_fn(
|
||||
request,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
print(key)
|
||||
|
||||
generated_key = key.key
|
||||
|
@ -1731,7 +1766,14 @@ async def test_call_with_key_over_model_budget(
|
|||
max_budget=100000, # the key itself has a very high budget
|
||||
model_max_budget=model_max_budget,
|
||||
)
|
||||
key = await generate_key_fn(request)
|
||||
key = await generate_key_fn(
|
||||
request,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
print(key)
|
||||
|
||||
generated_key = key.key
|
||||
|
@ -1798,7 +1840,14 @@ async def test_call_with_key_never_over_budget(prisma_client):
|
|||
try:
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
request = GenerateKeyRequest(max_budget=None)
|
||||
key = await generate_key_fn(request)
|
||||
key = await generate_key_fn(
|
||||
request,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
print(key)
|
||||
|
||||
generated_key = key.key
|
||||
|
@ -1882,7 +1931,14 @@ async def test_call_with_key_over_budget_stream(prisma_client):
|
|||
try:
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
request = GenerateKeyRequest(max_budget=0.00001)
|
||||
key = await generate_key_fn(request)
|
||||
key = await generate_key_fn(
|
||||
request,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
print(key)
|
||||
|
||||
generated_key = key.key
|
||||
|
@ -2006,7 +2062,14 @@ async def test_key_name_null(prisma_client):
|
|||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
try:
|
||||
request = GenerateKeyRequest()
|
||||
key = await generate_key_fn(request)
|
||||
key = await generate_key_fn(
|
||||
request,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
print("generated key=", key)
|
||||
generated_key = key.key
|
||||
result = await info_key_fn(
|
||||
|
@ -2035,7 +2098,14 @@ async def test_key_name_set(prisma_client):
|
|||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
try:
|
||||
request = GenerateKeyRequest()
|
||||
key = await generate_key_fn(request)
|
||||
key = await generate_key_fn(
|
||||
request,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
generated_key = key.key
|
||||
result = await info_key_fn(
|
||||
key=generated_key,
|
||||
|
@ -2062,7 +2132,14 @@ async def test_default_key_params(prisma_client):
|
|||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
try:
|
||||
request = GenerateKeyRequest()
|
||||
key = await generate_key_fn(request)
|
||||
key = await generate_key_fn(
|
||||
request,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
generated_key = key.key
|
||||
result = await info_key_fn(
|
||||
key=generated_key,
|
||||
|
@ -2093,7 +2170,14 @@ async def test_upperbound_key_param_larger_budget(prisma_client):
|
|||
max_budget=200000,
|
||||
budget_duration="30d",
|
||||
)
|
||||
key = await generate_key_fn(request)
|
||||
key = await generate_key_fn(
|
||||
request,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
# print(result)
|
||||
except Exception as e:
|
||||
assert e.code == str(400)
|
||||
|
@ -2112,7 +2196,14 @@ async def test_upperbound_key_param_larger_duration(prisma_client):
|
|||
max_budget=10,
|
||||
duration="30d",
|
||||
)
|
||||
key = await generate_key_fn(request)
|
||||
key = await generate_key_fn(
|
||||
request,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
pytest.fail("Expected this to fail but it passed")
|
||||
# print(result)
|
||||
except Exception as e:
|
||||
|
@ -2131,7 +2222,14 @@ async def test_upperbound_key_param_none_duration(prisma_client):
|
|||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
try:
|
||||
request = GenerateKeyRequest()
|
||||
key = await generate_key_fn(request)
|
||||
key = await generate_key_fn(
|
||||
request,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
|
||||
print(key)
|
||||
# print(result)
|
||||
|
@ -2362,7 +2460,14 @@ async def test_proxy_load_test_db(prisma_client):
|
|||
start_time = time.time()
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
request = GenerateKeyRequest(max_budget=0.00001)
|
||||
key = await generate_key_fn(request)
|
||||
key = await generate_key_fn(
|
||||
request,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
print(key)
|
||||
|
||||
generated_key = key.key
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue