Merge pull request #5586 from BerriAI/litellm_add_key_list_endpoint

Feat - Proxy add /key/list endpoint
This commit is contained in:
Ishaan Jaff 2024-09-07 19:04:05 -07:00 committed by GitHub
commit 79d605d2dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 93 additions and 4 deletions

View file

@ -268,9 +268,7 @@ class LiteLLMRoutes(enum.Enum):
]
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
master_key_only_routes: List = [
"/global/spend/reset",
]
master_key_only_routes: List = ["/global/spend/reset", "/key/list"]
sso_only_routes: List = [
"/key/generate",

View file

@ -20,7 +20,7 @@ from datetime import datetime, timedelta, timezone
from typing import List, Optional
import fastapi
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
import litellm
from litellm._logging import verbose_proxy_logger
@ -1077,3 +1077,94 @@ async def regenerate_key_fn(
return GenerateKeyResponse(
**updated_token_dict,
)
@router.get(
"/key/list",
tags=["key management"],
dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def list_keys(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
page: int = Query(1, description="Page number", ge=1),
size: int = Query(10, description="Page size", ge=1, le=100),
user_id: Optional[str] = Query(None, description="Filter keys by user ID"),
team_id: Optional[str] = Query(None, description="Filter keys by team ID"),
key_alias: Optional[str] = Query(None, description="Filter keys by key alias"),
):
try:
import logging
from litellm.proxy.proxy_server import prisma_client
logging.debug("Entering list_keys function")
if prisma_client is None:
logging.error("Database not connected")
raise Exception("Database not connected")
# Prepare filter conditions
where = {}
if user_id:
where["user_id"] = user_id
if team_id:
where["team_id"] = team_id
if key_alias:
where["key_alias"] = key_alias
logging.debug(f"Filter conditions: {where}")
# Calculate skip for pagination
skip = (page - 1) * size
logging.debug(f"Pagination: skip={skip}, take={size}")
# Fetch keys with pagination
keys = await prisma_client.db.litellm_verificationtoken.find_many(
where=where, # type: ignore
skip=skip, # type: ignore
take=size, # type: ignore
)
logging.debug(f"Fetched {len(keys)} keys")
# Get total count of keys
total_count = await prisma_client.db.litellm_verificationtoken.count(
where=where # type: ignore
)
logging.debug(f"Total count of keys: {total_count}")
# Calculate total pages
total_pages = -(-total_count // size) # Ceiling division
# Prepare response
key_list = []
for key in keys:
key_dict = key.dict()
_token = key_dict.get("token")
key_list.append(_token)
response = {
"keys": key_list,
"total_count": total_count,
"current_page": page,
"total_pages": total_pages,
}
logging.debug("Successfully prepared response")
return response
except Exception as e:
logging.error(f"Error in list_keys: {str(e)}")
logging.error(f"Error type: {type(e)}")
logging.error(f"Error traceback: {traceback.format_exc()}")
raise ProxyException(
message=f"Error listing keys: {str(e)}",
type=ProxyErrorTypes.internal_server_error, # Use the enum value
param=None,
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)