forked from phoenix/litellm-mirror
[Feat] Add Error Handling for /key/list endpoint (#5787)
* raise error from unsupported param * add testing for key list endpoint * add testing for key list error handling * fix key list test
This commit is contained in:
parent
e6018a464f
commit
186db292ae
2 changed files with 125 additions and 6 deletions
|
@ -1134,6 +1134,7 @@ async def regenerate_key_fn(
|
||||||
)
|
)
|
||||||
@management_endpoint_wrapper
|
@management_endpoint_wrapper
|
||||||
async def list_keys(
|
async def list_keys(
|
||||||
|
request: Request,
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
page: int = Query(1, description="Page number", ge=1),
|
page: int = Query(1, description="Page number", ge=1),
|
||||||
size: int = Query(10, description="Page size", ge=1, le=100),
|
size: int = Query(10, description="Page size", ge=1, le=100),
|
||||||
|
@ -1152,13 +1153,24 @@ async def list_keys(
|
||||||
logging.error("Database not connected")
|
logging.error("Database not connected")
|
||||||
raise Exception("Database not connected")
|
raise Exception("Database not connected")
|
||||||
|
|
||||||
|
# Check for unsupported parameters
|
||||||
|
supported_params = {"page", "size", "user_id", "team_id", "key_alias"}
|
||||||
|
unsupported_params = set(request.query_params.keys()) - supported_params
|
||||||
|
if unsupported_params:
|
||||||
|
raise ProxyException(
|
||||||
|
message=f"Unsupported parameter(s): {', '.join(unsupported_params)}. Supported parameters: {', '.join(supported_params)}",
|
||||||
|
type=ProxyErrorTypes.bad_request_error,
|
||||||
|
param=", ".join(unsupported_params),
|
||||||
|
code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare filter conditions
|
# Prepare filter conditions
|
||||||
where = {}
|
where = {}
|
||||||
if user_id:
|
if user_id and isinstance(user_id, str):
|
||||||
where["user_id"] = user_id
|
where["user_id"] = user_id
|
||||||
if team_id:
|
if team_id and isinstance(team_id, str):
|
||||||
where["team_id"] = team_id
|
where["team_id"] = team_id
|
||||||
if key_alias:
|
if key_alias and isinstance(key_alias, str):
|
||||||
where["key_alias"] = key_alias
|
where["key_alias"] = key_alias
|
||||||
|
|
||||||
logging.debug(f"Filter conditions: {where}")
|
logging.debug(f"Filter conditions: {where}")
|
||||||
|
@ -1206,9 +1218,18 @@ async def list_keys(
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if isinstance(e, HTTPException):
|
||||||
|
raise ProxyException(
|
||||||
|
message=getattr(e, "detail", f"error({str(e)})"),
|
||||||
|
type=ProxyErrorTypes.internal_server_error,
|
||||||
|
param=getattr(e, "param", "None"),
|
||||||
|
code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
|
||||||
|
)
|
||||||
|
elif isinstance(e, ProxyException):
|
||||||
|
raise e
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
message=f"Error listing keys: {str(e)}",
|
message="Authentication Error, " + str(e),
|
||||||
type=ProxyErrorTypes.internal_server_error, # Use the enum value
|
type=ProxyErrorTypes.internal_server_error,
|
||||||
param=None,
|
param=getattr(e, "param", "None"),
|
||||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
)
|
)
|
||||||
|
|
|
@ -56,6 +56,7 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||||
generate_key_fn,
|
generate_key_fn,
|
||||||
generate_key_helper_fn,
|
generate_key_helper_fn,
|
||||||
info_key_fn,
|
info_key_fn,
|
||||||
|
list_keys,
|
||||||
regenerate_key_fn,
|
regenerate_key_fn,
|
||||||
update_key_fn,
|
update_key_fn,
|
||||||
)
|
)
|
||||||
|
@ -3123,3 +3124,100 @@ async def test_admin_only_routes(prisma_client):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
setattr(proxy_server, "general_settings", initial_general_settings)
|
setattr(proxy_server, "general_settings", initial_general_settings)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_keys(prisma_client):
|
||||||
|
"""
|
||||||
|
Test the list_keys function:
|
||||||
|
- Test basic key
|
||||||
|
- Test pagination
|
||||||
|
- Test filtering by user_id, and key_alias
|
||||||
|
"""
|
||||||
|
from fastapi import Query
|
||||||
|
|
||||||
|
from litellm.proxy.proxy_server import hash_token
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
|
||||||
|
# Test basic listing
|
||||||
|
request = Request(scope={"type": "http", "query_string": b""})
|
||||||
|
response = await list_keys(
|
||||||
|
request,
|
||||||
|
UserAPIKeyAuth(),
|
||||||
|
page=1,
|
||||||
|
size=10,
|
||||||
|
)
|
||||||
|
print("response=", response)
|
||||||
|
assert "keys" in response
|
||||||
|
assert len(response["keys"]) > 0
|
||||||
|
assert "total_count" in response
|
||||||
|
assert "current_page" in response
|
||||||
|
assert "total_pages" in response
|
||||||
|
|
||||||
|
# Test pagination
|
||||||
|
response = await list_keys(request, UserAPIKeyAuth(), page=1, size=2)
|
||||||
|
print("pagination response=", response)
|
||||||
|
assert len(response["keys"]) == 2
|
||||||
|
assert response["current_page"] == 1
|
||||||
|
|
||||||
|
# Test filtering by user_id
|
||||||
|
|
||||||
|
unique_id = str(uuid.uuid4())
|
||||||
|
team_id = f"key-list-team-{unique_id}"
|
||||||
|
key_alias = f"key-list-alias-{unique_id}"
|
||||||
|
user_id = f"key-list-user-{unique_id}"
|
||||||
|
response = await new_user(
|
||||||
|
data=NewUserRequest(
|
||||||
|
user_id=f"key-list-user-{unique_id}",
|
||||||
|
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||||
|
key_alias=f"key-list-alias-{unique_id}",
|
||||||
|
),
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||||
|
)
|
||||||
|
|
||||||
|
_key = hash_token(response.key)
|
||||||
|
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
# Test filtering by user_id
|
||||||
|
response = await list_keys(
|
||||||
|
request, UserAPIKeyAuth(), user_id=user_id, page=1, size=10
|
||||||
|
)
|
||||||
|
print("filtered user_id response=", response)
|
||||||
|
assert len(response["keys"]) == 1
|
||||||
|
assert _key in response["keys"]
|
||||||
|
|
||||||
|
# Test filtering by key_alias
|
||||||
|
response = await list_keys(
|
||||||
|
request, UserAPIKeyAuth(), key_alias=key_alias, page=1, size=10
|
||||||
|
)
|
||||||
|
assert len(response["keys"]) == 1
|
||||||
|
assert _key in response["keys"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_key_list_unsupported_params(prisma_client):
|
||||||
|
"""
|
||||||
|
Test the list_keys function:
|
||||||
|
- Test unsupported params
|
||||||
|
"""
|
||||||
|
|
||||||
|
from litellm.proxy.proxy_server import hash_token
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
|
||||||
|
request = Request(scope={"type": "http", "query_string": b"alias=foo"})
|
||||||
|
|
||||||
|
try:
|
||||||
|
await list_keys(request, UserAPIKeyAuth(), page=1, size=10)
|
||||||
|
pytest.fail("Expected this call to fail")
|
||||||
|
except Exception as e:
|
||||||
|
print("error str=", str(e.message))
|
||||||
|
error_str = str(e.message)
|
||||||
|
assert "Unsupported parameter" in error_str
|
||||||
|
pass
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue