mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix: fix test (#8501)
This commit is contained in:
parent
12e6ae30dd
commit
1dc3c66630
2 changed files with 60 additions and 39 deletions
|
@ -1676,15 +1676,41 @@ async def regenerate_key_fn(
|
||||||
raise handle_exception_on_proxy(e)
|
raise handle_exception_on_proxy(e)
|
||||||
|
|
||||||
|
|
||||||
def validate_key_list_check(
|
async def validate_key_list_check(
|
||||||
complete_user_info: LiteLLM_UserTable,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
user_id: Optional[str],
|
user_id: Optional[str],
|
||||||
team_id: Optional[str],
|
team_id: Optional[str],
|
||||||
organization_id: Optional[str],
|
organization_id: Optional[str],
|
||||||
key_alias: Optional[str],
|
key_alias: Optional[str],
|
||||||
|
prisma_client: PrismaClient,
|
||||||
):
|
):
|
||||||
if complete_user_info.user_role == LitellmUserRoles.PROXY_ADMIN.value:
|
|
||||||
return # proxy admin can see all keys
|
if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
|
||||||
|
return
|
||||||
|
|
||||||
|
if user_api_key_dict.user_id is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.",
|
||||||
|
type=ProxyErrorTypes.bad_request_error,
|
||||||
|
param="user_id",
|
||||||
|
code=status.HTTP_403_FORBIDDEN,
|
||||||
|
)
|
||||||
|
complete_user_info_db_obj: Optional[BaseModel] = (
|
||||||
|
await prisma_client.db.litellm_usertable.find_unique(
|
||||||
|
where={"user_id": user_api_key_dict.user_id},
|
||||||
|
include={"organization_memberships": True},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if complete_user_info_db_obj is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.",
|
||||||
|
type=ProxyErrorTypes.bad_request_error,
|
||||||
|
param="user_id",
|
||||||
|
code=status.HTTP_403_FORBIDDEN,
|
||||||
|
)
|
||||||
|
|
||||||
|
complete_user_info = LiteLLM_UserTable(**complete_user_info_db_obj.model_dump())
|
||||||
|
|
||||||
# internal user can only see their own keys
|
# internal user can only see their own keys
|
||||||
if user_id:
|
if user_id:
|
||||||
|
@ -1779,42 +1805,16 @@ async def list_keys(
|
||||||
verbose_proxy_logger.error("Database not connected")
|
verbose_proxy_logger.error("Database not connected")
|
||||||
raise Exception("Database not connected")
|
raise Exception("Database not connected")
|
||||||
|
|
||||||
if not user_api_key_dict.user_id:
|
await validate_key_list_check(
|
||||||
raise ProxyException(
|
user_api_key_dict=user_api_key_dict,
|
||||||
message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.",
|
|
||||||
type=ProxyErrorTypes.bad_request_error,
|
|
||||||
param="user_id",
|
|
||||||
code=status.HTTP_403_FORBIDDEN,
|
|
||||||
)
|
|
||||||
|
|
||||||
complete_user_info: Optional[BaseModel] = (
|
|
||||||
await prisma_client.db.litellm_usertable.find_unique(
|
|
||||||
where={"user_id": user_api_key_dict.user_id},
|
|
||||||
include={"organization_memberships": True},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if complete_user_info is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message="You are not authorized to access this endpoint. No 'user_id' is associated with your API key.",
|
|
||||||
type=ProxyErrorTypes.bad_request_error,
|
|
||||||
param="user_id",
|
|
||||||
code=status.HTTP_403_FORBIDDEN,
|
|
||||||
)
|
|
||||||
|
|
||||||
complete_user_info_pydantic_obj = LiteLLM_UserTable(
|
|
||||||
**complete_user_info.model_dump()
|
|
||||||
)
|
|
||||||
|
|
||||||
validate_key_list_check(
|
|
||||||
complete_user_info=complete_user_info_pydantic_obj,
|
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
team_id=team_id,
|
team_id=team_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
key_alias=key_alias,
|
key_alias=key_alias,
|
||||||
|
prisma_client=prisma_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_id is None and complete_user_info_pydantic_obj.user_role != [
|
if user_id is None and user_api_key_dict.user_role not in [
|
||||||
LitellmUserRoles.PROXY_ADMIN.value,
|
LitellmUserRoles.PROXY_ADMIN.value,
|
||||||
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
|
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
|
||||||
]:
|
]:
|
||||||
|
|
|
@ -3359,6 +3359,7 @@ async def test_list_keys(prisma_client):
|
||||||
from fastapi import Query
|
from fastapi import Query
|
||||||
|
|
||||||
from litellm.proxy.proxy_server import hash_token
|
from litellm.proxy.proxy_server import hash_token
|
||||||
|
from litellm.proxy._types import LitellmUserRoles
|
||||||
|
|
||||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
@ -3368,7 +3369,9 @@ async def test_list_keys(prisma_client):
|
||||||
request = Request(scope={"type": "http", "query_string": b""})
|
request = Request(scope={"type": "http", "query_string": b""})
|
||||||
response = await list_keys(
|
response = await list_keys(
|
||||||
request,
|
request,
|
||||||
UserAPIKeyAuth(),
|
UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN.value,
|
||||||
|
),
|
||||||
page=1,
|
page=1,
|
||||||
size=10,
|
size=10,
|
||||||
)
|
)
|
||||||
|
@ -3380,7 +3383,12 @@ async def test_list_keys(prisma_client):
|
||||||
assert "total_pages" in response
|
assert "total_pages" in response
|
||||||
|
|
||||||
# Test pagination
|
# Test pagination
|
||||||
response = await list_keys(request, UserAPIKeyAuth(), page=1, size=2)
|
response = await list_keys(
|
||||||
|
request,
|
||||||
|
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
|
||||||
|
page=1,
|
||||||
|
size=2,
|
||||||
|
)
|
||||||
print("pagination response=", response)
|
print("pagination response=", response)
|
||||||
assert len(response["keys"]) == 2
|
assert len(response["keys"]) == 2
|
||||||
assert response["current_page"] == 1
|
assert response["current_page"] == 1
|
||||||
|
@ -3406,7 +3414,11 @@ async def test_list_keys(prisma_client):
|
||||||
|
|
||||||
# Test filtering by user_id
|
# Test filtering by user_id
|
||||||
response = await list_keys(
|
response = await list_keys(
|
||||||
request, UserAPIKeyAuth(), user_id=user_id, page=1, size=10
|
request,
|
||||||
|
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
|
||||||
|
user_id=user_id,
|
||||||
|
page=1,
|
||||||
|
size=10,
|
||||||
)
|
)
|
||||||
print("filtered user_id response=", response)
|
print("filtered user_id response=", response)
|
||||||
assert len(response["keys"]) == 1
|
assert len(response["keys"]) == 1
|
||||||
|
@ -3414,7 +3426,11 @@ async def test_list_keys(prisma_client):
|
||||||
|
|
||||||
# Test filtering by key_alias
|
# Test filtering by key_alias
|
||||||
response = await list_keys(
|
response = await list_keys(
|
||||||
request, UserAPIKeyAuth(), key_alias=key_alias, page=1, size=10
|
request,
|
||||||
|
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
|
||||||
|
key_alias=key_alias,
|
||||||
|
page=1,
|
||||||
|
size=10,
|
||||||
)
|
)
|
||||||
assert len(response["keys"]) == 1
|
assert len(response["keys"]) == 1
|
||||||
assert _key in response["keys"]
|
assert _key in response["keys"]
|
||||||
|
@ -3436,7 +3452,12 @@ async def test_key_list_unsupported_params(prisma_client):
|
||||||
request = Request(scope={"type": "http", "query_string": b"alias=foo"})
|
request = Request(scope={"type": "http", "query_string": b"alias=foo"})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await list_keys(request, UserAPIKeyAuth(), page=1, size=10)
|
await list_keys(
|
||||||
|
request,
|
||||||
|
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
|
||||||
|
page=1,
|
||||||
|
size=10,
|
||||||
|
)
|
||||||
pytest.fail("Expected this call to fail")
|
pytest.fail("Expected this call to fail")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("error str=", str(e.message))
|
print("error str=", str(e.message))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue