fix(proxy_server.py): have /user/info return user info + related user keys

This commit is contained in:
Krrish Dholakia 2024-01-12 22:52:13 +05:30
parent d29a9029ac
commit cd639a7f4b
2 changed files with 53 additions and 26 deletions

View file

@ -242,8 +242,20 @@ async def user_api_key_auth(
## check db
verbose_proxy_logger.debug(f"api key: {api_key}")
valid_token = await prisma_client.get_data(
token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc)
token=api_key,
)
expires = datetime.utcnow().replace(tzinfo=timezone.utc)
# Token exists, now check expiration.
if valid_token.expires is not None and expires is not None:
if valid_token.expires >= expires:
# Token exists and is not expired.
return valid_token
else:
# Token exists but is expired.
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="expired user key",
)
verbose_proxy_logger.debug(f"valid token from prisma: {valid_token}")
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None:
@ -1998,14 +2010,33 @@ async def user_auth(request: Request):
return "Email sent!"
@router.post(
@router.get(
"/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
)
async def user_info(request: Request):
async def user_info(
user_id: str = fastapi.Query(..., description="User ID in the request parameters")
):
"""
[TODO]: Use this to get user information. (user row + all user key info)
Use this to get user information. (user row + all user key info)
"""
pass
global prisma_client
try:
if prisma_client is None:
raise Exception(
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
## GET USER ROW ##
user_info = await prisma_client.get_data(user_id=user_id)
## GET ALL KEYS ##
keys = await prisma_client.get_data(
user_id=user_id, table_name="key", query_type="find_all"
)
return {"user_id": user_id, "user_info": user_info, "keys": keys}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
)
@router.post(