mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(key_management_endpoints.py): enforce user_id / team_id checks on key generate (#7437)
* fix(key_management_endpoints.py): enforce user_id / team_id checks on key generate Fixes https://github.com/BerriAI/litellm/issues/7336 * test: fix tests
This commit is contained in:
parent
0774fc71ce
commit
08145fa89e
2 changed files with 170 additions and 15 deletions
|
@ -61,6 +61,38 @@ def _get_user_in_team(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_allowed_to_create_key(
|
||||||
|
user_api_key_dict: UserAPIKeyAuth, user_id: Optional[str], team_id: Optional[str]
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Assert user only creates keys for themselves
|
||||||
|
|
||||||
|
Relevant issue: https://github.com/BerriAI/litellm/issues/7336
|
||||||
|
"""
|
||||||
|
## BASE CASE - PROXY ADMIN
|
||||||
|
if (
|
||||||
|
user_api_key_dict.user_role is not None
|
||||||
|
and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if user_id is not None:
|
||||||
|
assert (
|
||||||
|
user_id == user_api_key_dict.user_id
|
||||||
|
), "User can only create keys for themselves. Got user_id={}, Your ID={}".format(
|
||||||
|
user_id, user_api_key_dict.user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if team_id is not None:
|
||||||
|
assert (
|
||||||
|
user_api_key_dict.team_id == team_id
|
||||||
|
), "User can only create keys for their own team. Got={}, Your Team ID={}".format(
|
||||||
|
team_id, user_api_key_dict.team_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _team_key_generation_team_member_check(
|
def _team_key_generation_team_member_check(
|
||||||
team_table: LiteLLM_TeamTableCachedObj,
|
team_table: LiteLLM_TeamTableCachedObj,
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
@ -315,6 +347,24 @@ async def generate_key_fn( # noqa: PLR0915
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_is_allowed_to_create_key(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
user_id=data.user_id,
|
||||||
|
team_id=data.team_id,
|
||||||
|
)
|
||||||
|
except AssertionError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
# check if user set default key/generate params on config.yaml
|
# check if user set default key/generate params on config.yaml
|
||||||
if litellm.default_key_generate_params is not None:
|
if litellm.default_key_generate_params is not None:
|
||||||
for elem in data:
|
for elem in data:
|
||||||
|
|
|
@ -439,7 +439,14 @@ async def test_call_with_valid_model_using_all_models(prisma_client):
|
||||||
request = GenerateKeyRequest(
|
request = GenerateKeyRequest(
|
||||||
models=["all-team-models"], team_id=created_team_id
|
models=["all-team-models"], team_id=created_team_id
|
||||||
)
|
)
|
||||||
key = await generate_key_fn(data=request)
|
key = await generate_key_fn(
|
||||||
|
data=request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
print(key)
|
print(key)
|
||||||
|
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
|
@ -1430,7 +1437,14 @@ def test_key_generate_with_custom_auth(prisma_client):
|
||||||
try:
|
try:
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
request = GenerateKeyRequest()
|
request = GenerateKeyRequest()
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(
|
||||||
|
request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
pytest.fail(f"Expected an exception. Got {key}")
|
pytest.fail(f"Expected an exception. Got {key}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# this should fail
|
# this should fail
|
||||||
|
@ -1446,7 +1460,14 @@ def test_key_generate_with_custom_auth(prisma_client):
|
||||||
team_id="litellm-core-infra@gmail.com",
|
team_id="litellm-core-infra@gmail.com",
|
||||||
)
|
)
|
||||||
|
|
||||||
key = await generate_key_fn(request_2)
|
key = await generate_key_fn(
|
||||||
|
request_2,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
print(key)
|
print(key)
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
|
|
||||||
|
@ -1466,7 +1487,14 @@ def test_call_with_key_over_budget(prisma_client):
|
||||||
async def test():
|
async def test():
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
request = GenerateKeyRequest(max_budget=0.00001)
|
request = GenerateKeyRequest(max_budget=0.00001)
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(
|
||||||
|
request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
print(key)
|
print(key)
|
||||||
|
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
|
@ -1578,7 +1606,14 @@ def test_call_with_key_over_budget_no_cache(prisma_client):
|
||||||
async def test():
|
async def test():
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
request = GenerateKeyRequest(max_budget=0.00001)
|
request = GenerateKeyRequest(max_budget=0.00001)
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(
|
||||||
|
request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
print(key)
|
print(key)
|
||||||
|
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
|
@ -1731,7 +1766,14 @@ async def test_call_with_key_over_model_budget(
|
||||||
max_budget=100000, # the key itself has a very high budget
|
max_budget=100000, # the key itself has a very high budget
|
||||||
model_max_budget=model_max_budget,
|
model_max_budget=model_max_budget,
|
||||||
)
|
)
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(
|
||||||
|
request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
print(key)
|
print(key)
|
||||||
|
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
|
@ -1798,7 +1840,14 @@ async def test_call_with_key_never_over_budget(prisma_client):
|
||||||
try:
|
try:
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
request = GenerateKeyRequest(max_budget=None)
|
request = GenerateKeyRequest(max_budget=None)
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(
|
||||||
|
request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
print(key)
|
print(key)
|
||||||
|
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
|
@ -1882,7 +1931,14 @@ async def test_call_with_key_over_budget_stream(prisma_client):
|
||||||
try:
|
try:
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
request = GenerateKeyRequest(max_budget=0.00001)
|
request = GenerateKeyRequest(max_budget=0.00001)
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(
|
||||||
|
request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
print(key)
|
print(key)
|
||||||
|
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
|
@ -2006,7 +2062,14 @@ async def test_key_name_null(prisma_client):
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
try:
|
try:
|
||||||
request = GenerateKeyRequest()
|
request = GenerateKeyRequest()
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(
|
||||||
|
request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
print("generated key=", key)
|
print("generated key=", key)
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
result = await info_key_fn(
|
result = await info_key_fn(
|
||||||
|
@ -2035,7 +2098,14 @@ async def test_key_name_set(prisma_client):
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
try:
|
try:
|
||||||
request = GenerateKeyRequest()
|
request = GenerateKeyRequest()
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(
|
||||||
|
request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
result = await info_key_fn(
|
result = await info_key_fn(
|
||||||
key=generated_key,
|
key=generated_key,
|
||||||
|
@ -2062,7 +2132,14 @@ async def test_default_key_params(prisma_client):
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
try:
|
try:
|
||||||
request = GenerateKeyRequest()
|
request = GenerateKeyRequest()
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(
|
||||||
|
request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
result = await info_key_fn(
|
result = await info_key_fn(
|
||||||
key=generated_key,
|
key=generated_key,
|
||||||
|
@ -2093,7 +2170,14 @@ async def test_upperbound_key_param_larger_budget(prisma_client):
|
||||||
max_budget=200000,
|
max_budget=200000,
|
||||||
budget_duration="30d",
|
budget_duration="30d",
|
||||||
)
|
)
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(
|
||||||
|
request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
# print(result)
|
# print(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
assert e.code == str(400)
|
assert e.code == str(400)
|
||||||
|
@ -2112,7 +2196,14 @@ async def test_upperbound_key_param_larger_duration(prisma_client):
|
||||||
max_budget=10,
|
max_budget=10,
|
||||||
duration="30d",
|
duration="30d",
|
||||||
)
|
)
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(
|
||||||
|
request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
pytest.fail("Expected this to fail but it passed")
|
pytest.fail("Expected this to fail but it passed")
|
||||||
# print(result)
|
# print(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -2131,7 +2222,14 @@ async def test_upperbound_key_param_none_duration(prisma_client):
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
try:
|
try:
|
||||||
request = GenerateKeyRequest()
|
request = GenerateKeyRequest()
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(
|
||||||
|
request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
print(key)
|
print(key)
|
||||||
# print(result)
|
# print(result)
|
||||||
|
@ -2362,7 +2460,14 @@ async def test_proxy_load_test_db(prisma_client):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
request = GenerateKeyRequest(max_budget=0.00001)
|
request = GenerateKeyRequest(max_budget=0.00001)
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(
|
||||||
|
request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
print(key)
|
print(key)
|
||||||
|
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue