forked from phoenix/litellm-mirror
[Feat] Add Error Handling for /key/list endpoint (#5787)
* raise error from unsupported param * add testing for key list endpoint * add testing for key list error handling * fix key list test
This commit is contained in:
parent
e6018a464f
commit
186db292ae
2 changed files with 125 additions and 6 deletions
|
@ -1134,6 +1134,7 @@ async def regenerate_key_fn(
|
|||
)
|
||||
@management_endpoint_wrapper
|
||||
async def list_keys(
|
||||
request: Request,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
page: int = Query(1, description="Page number", ge=1),
|
||||
size: int = Query(10, description="Page size", ge=1, le=100),
|
||||
|
@ -1152,13 +1153,24 @@ async def list_keys(
|
|||
logging.error("Database not connected")
|
||||
raise Exception("Database not connected")
|
||||
|
||||
# Check for unsupported parameters
|
||||
supported_params = {"page", "size", "user_id", "team_id", "key_alias"}
|
||||
unsupported_params = set(request.query_params.keys()) - supported_params
|
||||
if unsupported_params:
|
||||
raise ProxyException(
|
||||
message=f"Unsupported parameter(s): {', '.join(unsupported_params)}. Supported parameters: {', '.join(supported_params)}",
|
||||
type=ProxyErrorTypes.bad_request_error,
|
||||
param=", ".join(unsupported_params),
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Prepare filter conditions
|
||||
where = {}
|
||||
if user_id:
|
||||
if user_id and isinstance(user_id, str):
|
||||
where["user_id"] = user_id
|
||||
if team_id:
|
||||
if team_id and isinstance(team_id, str):
|
||||
where["team_id"] = team_id
|
||||
if key_alias:
|
||||
if key_alias and isinstance(key_alias, str):
|
||||
where["key_alias"] = key_alias
|
||||
|
||||
logging.debug(f"Filter conditions: {where}")
|
||||
|
@ -1206,9 +1218,18 @@ async def list_keys(
|
|||
return response
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"error({str(e)})"),
|
||||
type=ProxyErrorTypes.internal_server_error,
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
|
||||
)
|
||||
elif isinstance(e, ProxyException):
|
||||
raise e
|
||||
raise ProxyException(
|
||||
message=f"Error listing keys: {str(e)}",
|
||||
type=ProxyErrorTypes.internal_server_error, # Use the enum value
|
||||
param=None,
|
||||
message="Authentication Error, " + str(e),
|
||||
type=ProxyErrorTypes.internal_server_error,
|
||||
param=getattr(e, "param", "None"),
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
|
|
@ -56,6 +56,7 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
|
|||
generate_key_fn,
|
||||
generate_key_helper_fn,
|
||||
info_key_fn,
|
||||
list_keys,
|
||||
regenerate_key_fn,
|
||||
update_key_fn,
|
||||
)
|
||||
|
@ -3123,3 +3124,100 @@ async def test_admin_only_routes(prisma_client):
|
|||
pass
|
||||
|
||||
setattr(proxy_server, "general_settings", initial_general_settings)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_keys(prisma_client):
|
||||
"""
|
||||
Test the list_keys function:
|
||||
- Test basic key
|
||||
- Test pagination
|
||||
- Test filtering by user_id, and key_alias
|
||||
"""
|
||||
from fastapi import Query
|
||||
|
||||
from litellm.proxy.proxy_server import hash_token
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
# Test basic listing
|
||||
request = Request(scope={"type": "http", "query_string": b""})
|
||||
response = await list_keys(
|
||||
request,
|
||||
UserAPIKeyAuth(),
|
||||
page=1,
|
||||
size=10,
|
||||
)
|
||||
print("response=", response)
|
||||
assert "keys" in response
|
||||
assert len(response["keys"]) > 0
|
||||
assert "total_count" in response
|
||||
assert "current_page" in response
|
||||
assert "total_pages" in response
|
||||
|
||||
# Test pagination
|
||||
response = await list_keys(request, UserAPIKeyAuth(), page=1, size=2)
|
||||
print("pagination response=", response)
|
||||
assert len(response["keys"]) == 2
|
||||
assert response["current_page"] == 1
|
||||
|
||||
# Test filtering by user_id
|
||||
|
||||
unique_id = str(uuid.uuid4())
|
||||
team_id = f"key-list-team-{unique_id}"
|
||||
key_alias = f"key-list-alias-{unique_id}"
|
||||
user_id = f"key-list-user-{unique_id}"
|
||||
response = await new_user(
|
||||
data=NewUserRequest(
|
||||
user_id=f"key-list-user-{unique_id}",
|
||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||
key_alias=f"key-list-alias-{unique_id}",
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||
)
|
||||
|
||||
_key = hash_token(response.key)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Test filtering by user_id
|
||||
response = await list_keys(
|
||||
request, UserAPIKeyAuth(), user_id=user_id, page=1, size=10
|
||||
)
|
||||
print("filtered user_id response=", response)
|
||||
assert len(response["keys"]) == 1
|
||||
assert _key in response["keys"]
|
||||
|
||||
# Test filtering by key_alias
|
||||
response = await list_keys(
|
||||
request, UserAPIKeyAuth(), key_alias=key_alias, page=1, size=10
|
||||
)
|
||||
assert len(response["keys"]) == 1
|
||||
assert _key in response["keys"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_key_list_unsupported_params(prisma_client):
|
||||
"""
|
||||
Test the list_keys function:
|
||||
- Test unsupported params
|
||||
"""
|
||||
|
||||
from litellm.proxy.proxy_server import hash_token
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
request = Request(scope={"type": "http", "query_string": b"alias=foo"})
|
||||
|
||||
try:
|
||||
await list_keys(request, UserAPIKeyAuth(), page=1, size=10)
|
||||
pytest.fail("Expected this call to fail")
|
||||
except Exception as e:
|
||||
print("error str=", str(e.message))
|
||||
error_str = str(e.message)
|
||||
assert "Unsupported parameter" in error_str
|
||||
pass
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue