mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix: fix test (#8501)
This commit is contained in:
parent
54811cf595
commit
aee90f1dfe
2 changed files with 60 additions and 39 deletions
|
@ -1676,15 +1676,41 @@ async def regenerate_key_fn(
|
|||
raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
def validate_key_list_check(
|
||||
complete_user_info: LiteLLM_UserTable,
|
||||
async def validate_key_list_check(
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
user_id: Optional[str],
|
||||
team_id: Optional[str],
|
||||
organization_id: Optional[str],
|
||||
key_alias: Optional[str],
|
||||
prisma_client: PrismaClient,
|
||||
):
|
||||
if complete_user_info.user_role == LitellmUserRoles.PROXY_ADMIN.value:
|
||||
return # proxy admin can see all keys
|
||||
|
||||
if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
|
||||
return
|
||||
|
||||
if user_api_key_dict.user_id is None:
|
||||
raise ProxyException(
|
||||
message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.",
|
||||
type=ProxyErrorTypes.bad_request_error,
|
||||
param="user_id",
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
complete_user_info_db_obj: Optional[BaseModel] = (
|
||||
await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_api_key_dict.user_id},
|
||||
include={"organization_memberships": True},
|
||||
)
|
||||
)
|
||||
|
||||
if complete_user_info_db_obj is None:
|
||||
raise ProxyException(
|
||||
message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.",
|
||||
type=ProxyErrorTypes.bad_request_error,
|
||||
param="user_id",
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
complete_user_info = LiteLLM_UserTable(**complete_user_info_db_obj.model_dump())
|
||||
|
||||
# internal user can only see their own keys
|
||||
if user_id:
|
||||
|
@ -1779,42 +1805,16 @@ async def list_keys(
|
|||
verbose_proxy_logger.error("Database not connected")
|
||||
raise Exception("Database not connected")
|
||||
|
||||
if not user_api_key_dict.user_id:
|
||||
raise ProxyException(
|
||||
message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.",
|
||||
type=ProxyErrorTypes.bad_request_error,
|
||||
param="user_id",
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
complete_user_info: Optional[BaseModel] = (
|
||||
await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_api_key_dict.user_id},
|
||||
include={"organization_memberships": True},
|
||||
)
|
||||
)
|
||||
|
||||
if complete_user_info is None:
|
||||
raise ProxyException(
|
||||
message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.",
|
||||
type=ProxyErrorTypes.bad_request_error,
|
||||
param="user_id",
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
complete_user_info_pydantic_obj = LiteLLM_UserTable(
|
||||
**complete_user_info.model_dump()
|
||||
)
|
||||
|
||||
validate_key_list_check(
|
||||
complete_user_info=complete_user_info_pydantic_obj,
|
||||
await validate_key_list_check(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
user_id=user_id,
|
||||
team_id=team_id,
|
||||
organization_id=organization_id,
|
||||
key_alias=key_alias,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
if user_id is None and complete_user_info_pydantic_obj.user_role != [
|
||||
if user_id is None and user_api_key_dict.user_role not in [
|
||||
LitellmUserRoles.PROXY_ADMIN.value,
|
||||
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
|
||||
]:
|
||||
|
|
|
@ -3359,6 +3359,7 @@ async def test_list_keys(prisma_client):
|
|||
from fastapi import Query
|
||||
|
||||
from litellm.proxy.proxy_server import hash_token
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
|
@ -3368,7 +3369,9 @@ async def test_list_keys(prisma_client):
|
|||
request = Request(scope={"type": "http", "query_string": b""})
|
||||
response = await list_keys(
|
||||
request,
|
||||
UserAPIKeyAuth(),
|
||||
UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN.value,
|
||||
),
|
||||
page=1,
|
||||
size=10,
|
||||
)
|
||||
|
@ -3380,7 +3383,12 @@ async def test_list_keys(prisma_client):
|
|||
assert "total_pages" in response
|
||||
|
||||
# Test pagination
|
||||
response = await list_keys(request, UserAPIKeyAuth(), page=1, size=2)
|
||||
response = await list_keys(
|
||||
request,
|
||||
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
|
||||
page=1,
|
||||
size=2,
|
||||
)
|
||||
print("pagination response=", response)
|
||||
assert len(response["keys"]) == 2
|
||||
assert response["current_page"] == 1
|
||||
|
@ -3406,7 +3414,11 @@ async def test_list_keys(prisma_client):
|
|||
|
||||
# Test filtering by user_id
|
||||
response = await list_keys(
|
||||
request, UserAPIKeyAuth(), user_id=user_id, page=1, size=10
|
||||
request,
|
||||
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
|
||||
user_id=user_id,
|
||||
page=1,
|
||||
size=10,
|
||||
)
|
||||
print("filtered user_id response=", response)
|
||||
assert len(response["keys"]) == 1
|
||||
|
@ -3414,7 +3426,11 @@ async def test_list_keys(prisma_client):
|
|||
|
||||
# Test filtering by key_alias
|
||||
response = await list_keys(
|
||||
request, UserAPIKeyAuth(), key_alias=key_alias, page=1, size=10
|
||||
request,
|
||||
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
|
||||
key_alias=key_alias,
|
||||
page=1,
|
||||
size=10,
|
||||
)
|
||||
assert len(response["keys"]) == 1
|
||||
assert _key in response["keys"]
|
||||
|
@ -3436,7 +3452,12 @@ async def test_key_list_unsupported_params(prisma_client):
|
|||
request = Request(scope={"type": "http", "query_string": b"alias=foo"})
|
||||
|
||||
try:
|
||||
await list_keys(request, UserAPIKeyAuth(), page=1, size=10)
|
||||
await list_keys(
|
||||
request,
|
||||
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
|
||||
page=1,
|
||||
size=10,
|
||||
)
|
||||
pytest.fail("Expected this call to fail")
|
||||
except Exception as e:
|
||||
print("error str=", str(e.message))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue