fix: fix test (#8501)

This commit is contained in:
Krish Dholakia 2025-02-12 18:38:15 -08:00 committed by GitHub
parent 12e6ae30dd
commit 1dc3c66630
2 changed files with 60 additions and 39 deletions

View file

@ -1676,15 +1676,41 @@ async def regenerate_key_fn(
raise handle_exception_on_proxy(e) raise handle_exception_on_proxy(e)
def validate_key_list_check( async def validate_key_list_check(
complete_user_info: LiteLLM_UserTable, user_api_key_dict: UserAPIKeyAuth,
user_id: Optional[str], user_id: Optional[str],
team_id: Optional[str], team_id: Optional[str],
organization_id: Optional[str], organization_id: Optional[str],
key_alias: Optional[str], key_alias: Optional[str],
prisma_client: PrismaClient,
): ):
if complete_user_info.user_role == LitellmUserRoles.PROXY_ADMIN.value:
return # proxy admin can see all keys if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
return
if user_api_key_dict.user_id is None:
raise ProxyException(
message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.",
type=ProxyErrorTypes.bad_request_error,
param="user_id",
code=status.HTTP_403_FORBIDDEN,
)
complete_user_info_db_obj: Optional[BaseModel] = (
await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)
)
if complete_user_info_db_obj is None:
raise ProxyException(
message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.",
type=ProxyErrorTypes.bad_request_error,
param="user_id",
code=status.HTTP_403_FORBIDDEN,
)
complete_user_info = LiteLLM_UserTable(**complete_user_info_db_obj.model_dump())
# internal user can only see their own keys # internal user can only see their own keys
if user_id: if user_id:
@ -1779,42 +1805,16 @@ async def list_keys(
verbose_proxy_logger.error("Database not connected") verbose_proxy_logger.error("Database not connected")
raise Exception("Database not connected") raise Exception("Database not connected")
if not user_api_key_dict.user_id: await validate_key_list_check(
raise ProxyException( user_api_key_dict=user_api_key_dict,
message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.",
type=ProxyErrorTypes.bad_request_error,
param="user_id",
code=status.HTTP_403_FORBIDDEN,
)
complete_user_info: Optional[BaseModel] = (
await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)
)
if complete_user_info is None:
raise ProxyException(
message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.",
type=ProxyErrorTypes.bad_request_error,
param="user_id",
code=status.HTTP_403_FORBIDDEN,
)
complete_user_info_pydantic_obj = LiteLLM_UserTable(
**complete_user_info.model_dump()
)
validate_key_list_check(
complete_user_info=complete_user_info_pydantic_obj,
user_id=user_id, user_id=user_id,
team_id=team_id, team_id=team_id,
organization_id=organization_id, organization_id=organization_id,
key_alias=key_alias, key_alias=key_alias,
prisma_client=prisma_client,
) )
if user_id is None and complete_user_info_pydantic_obj.user_role != [ if user_id is None and user_api_key_dict.user_role not in [
LitellmUserRoles.PROXY_ADMIN.value, LitellmUserRoles.PROXY_ADMIN.value,
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value, LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
]: ]:

View file

@ -3359,6 +3359,7 @@ async def test_list_keys(prisma_client):
from fastapi import Query from fastapi import Query
from litellm.proxy.proxy_server import hash_token from litellm.proxy.proxy_server import hash_token
from litellm.proxy._types import LitellmUserRoles
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
@ -3368,7 +3369,9 @@ async def test_list_keys(prisma_client):
request = Request(scope={"type": "http", "query_string": b""}) request = Request(scope={"type": "http", "query_string": b""})
response = await list_keys( response = await list_keys(
request, request,
UserAPIKeyAuth(), UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN.value,
),
page=1, page=1,
size=10, size=10,
) )
@ -3380,7 +3383,12 @@ async def test_list_keys(prisma_client):
assert "total_pages" in response assert "total_pages" in response
# Test pagination # Test pagination
response = await list_keys(request, UserAPIKeyAuth(), page=1, size=2) response = await list_keys(
request,
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
page=1,
size=2,
)
print("pagination response=", response) print("pagination response=", response)
assert len(response["keys"]) == 2 assert len(response["keys"]) == 2
assert response["current_page"] == 1 assert response["current_page"] == 1
@ -3406,7 +3414,11 @@ async def test_list_keys(prisma_client):
# Test filtering by user_id # Test filtering by user_id
response = await list_keys( response = await list_keys(
request, UserAPIKeyAuth(), user_id=user_id, page=1, size=10 request,
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
user_id=user_id,
page=1,
size=10,
) )
print("filtered user_id response=", response) print("filtered user_id response=", response)
assert len(response["keys"]) == 1 assert len(response["keys"]) == 1
@ -3414,7 +3426,11 @@ async def test_list_keys(prisma_client):
# Test filtering by key_alias # Test filtering by key_alias
response = await list_keys( response = await list_keys(
request, UserAPIKeyAuth(), key_alias=key_alias, page=1, size=10 request,
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
key_alias=key_alias,
page=1,
size=10,
) )
assert len(response["keys"]) == 1 assert len(response["keys"]) == 1
assert _key in response["keys"] assert _key in response["keys"]
@ -3436,7 +3452,12 @@ async def test_key_list_unsupported_params(prisma_client):
request = Request(scope={"type": "http", "query_string": b"alias=foo"}) request = Request(scope={"type": "http", "query_string": b"alias=foo"})
try: try:
await list_keys(request, UserAPIKeyAuth(), page=1, size=10) await list_keys(
request,
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
page=1,
size=10,
)
pytest.fail("Expected this call to fail") pytest.fail("Expected this call to fail")
except Exception as e: except Exception as e:
print("error str=", str(e.message)) print("error str=", str(e.message))