[Feat] Add Error Handling for /key/list endpoint (#5787)

* raise error from unsupported param

* add testing for key list endpoint

* add testing for key list error handling

* fix key list test
This commit is contained in:
Ishaan Jaff 2024-09-19 17:14:12 -07:00 committed by GitHub
parent e6018a464f
commit 186db292ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 125 additions and 6 deletions

View file

@ -1134,6 +1134,7 @@ async def regenerate_key_fn(
)
@management_endpoint_wrapper
async def list_keys(
request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
page: int = Query(1, description="Page number", ge=1),
size: int = Query(10, description="Page size", ge=1, le=100),
@ -1152,13 +1153,24 @@ async def list_keys(
logging.error("Database not connected")
raise Exception("Database not connected")
# Check for unsupported parameters
supported_params = {"page", "size", "user_id", "team_id", "key_alias"}
unsupported_params = set(request.query_params.keys()) - supported_params
if unsupported_params:
raise ProxyException(
message=f"Unsupported parameter(s): {', '.join(unsupported_params)}. Supported parameters: {', '.join(supported_params)}",
type=ProxyErrorTypes.bad_request_error,
param=", ".join(unsupported_params),
code=status.HTTP_400_BAD_REQUEST,
)
# Prepare filter conditions
where = {}
if user_id:
if user_id and isinstance(user_id, str):
where["user_id"] = user_id
if team_id:
if team_id and isinstance(team_id, str):
where["team_id"] = team_id
if key_alias:
if key_alias and isinstance(key_alias, str):
where["key_alias"] = key_alias
logging.debug(f"Filter conditions: {where}")
@ -1206,9 +1218,18 @@ async def list_keys(
return response
except Exception as e:
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"error({str(e)})"),
type=ProxyErrorTypes.internal_server_error,
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
)
elif isinstance(e, ProxyException):
raise e
raise ProxyException(
message=f"Error listing keys: {str(e)}",
type=ProxyErrorTypes.internal_server_error, # Use the enum value
param=None,
message="Authentication Error, " + str(e),
type=ProxyErrorTypes.internal_server_error,
param=getattr(e, "param", "None"),
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)

View file

@ -56,6 +56,7 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
generate_key_fn,
generate_key_helper_fn,
info_key_fn,
list_keys,
regenerate_key_fn,
update_key_fn,
)
@ -3123,3 +3124,100 @@ async def test_admin_only_routes(prisma_client):
pass
setattr(proxy_server, "general_settings", initial_general_settings)
@pytest.mark.asyncio
async def test_list_keys(prisma_client):
"""
Test the list_keys function:
- Test basic key
- Test pagination
- Test filtering by user_id, and key_alias
"""
from fastapi import Query
from litellm.proxy.proxy_server import hash_token
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
# Test basic listing
request = Request(scope={"type": "http", "query_string": b""})
response = await list_keys(
request,
UserAPIKeyAuth(),
page=1,
size=10,
)
print("response=", response)
assert "keys" in response
assert len(response["keys"]) > 0
assert "total_count" in response
assert "current_page" in response
assert "total_pages" in response
# Test pagination
response = await list_keys(request, UserAPIKeyAuth(), page=1, size=2)
print("pagination response=", response)
assert len(response["keys"]) == 2
assert response["current_page"] == 1
# Test filtering by user_id
unique_id = str(uuid.uuid4())
team_id = f"key-list-team-{unique_id}"
key_alias = f"key-list-alias-{unique_id}"
user_id = f"key-list-user-{unique_id}"
response = await new_user(
data=NewUserRequest(
user_id=f"key-list-user-{unique_id}",
user_role=LitellmUserRoles.INTERNAL_USER,
key_alias=f"key-list-alias-{unique_id}",
),
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
_key = hash_token(response.key)
await asyncio.sleep(2)
# Test filtering by user_id
response = await list_keys(
request, UserAPIKeyAuth(), user_id=user_id, page=1, size=10
)
print("filtered user_id response=", response)
assert len(response["keys"]) == 1
assert _key in response["keys"]
# Test filtering by key_alias
response = await list_keys(
request, UserAPIKeyAuth(), key_alias=key_alias, page=1, size=10
)
assert len(response["keys"]) == 1
assert _key in response["keys"]
@pytest.mark.asyncio
async def test_key_list_unsupported_params(prisma_client):
"""
Test the list_keys function:
- Test unsupported params
"""
from litellm.proxy.proxy_server import hash_token
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
request = Request(scope={"type": "http", "query_string": b"alias=foo"})
try:
await list_keys(request, UserAPIKeyAuth(), page=1, size=10)
pytest.fail("Expected this call to fail")
except Exception as e:
print("error str=", str(e.message))
error_str = str(e.message)
assert "Unsupported parameter" in error_str
pass