fix(proxy_server.py): have /user/info return user info + related user keys

This commit is contained in:
Krrish Dholakia 2024-01-12 22:52:13 +05:30
parent d29a9029ac
commit cd639a7f4b
2 changed files with 53 additions and 26 deletions

View file

@ -242,8 +242,20 @@ async def user_api_key_auth(
## check db ## check db
verbose_proxy_logger.debug(f"api key: {api_key}") verbose_proxy_logger.debug(f"api key: {api_key}")
valid_token = await prisma_client.get_data( valid_token = await prisma_client.get_data(
token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc) token=api_key,
) )
expires = datetime.utcnow().replace(tzinfo=timezone.utc)
# Token exists, now check expiration.
if valid_token.expires is not None and expires is not None:
if valid_token.expires >= expires:
# Token exists and is not expired.
return valid_token
else:
# Token exists but is expired.
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="expired user key",
)
verbose_proxy_logger.debug(f"valid token from prisma: {valid_token}") verbose_proxy_logger.debug(f"valid token from prisma: {valid_token}")
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60) user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None: elif valid_token is not None:
@ -1998,14 +2010,33 @@ async def user_auth(request: Request):
return "Email sent!" return "Email sent!"
@router.post( @router.get(
"/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)] "/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
) )
async def user_info(request: Request): async def user_info(
user_id: str = fastapi.Query(..., description="User ID in the request parameters")
):
""" """
[TODO]: Use this to get user information. (user row + all user key info) Use this to get user information. (user row + all user key info)
""" """
pass global prisma_client
try:
if prisma_client is None:
raise Exception(
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
## GET USER ROW ##
user_info = await prisma_client.get_data(user_id=user_id)
## GET ALL KEYS ##
keys = await prisma_client.get_data(
user_id=user_id, table_name="key", query_type="find_all"
)
return {"user_id": user_id, "user_info": user_info, "keys": keys}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
)
@router.post( @router.post(

View file

@ -344,35 +344,31 @@ class PrismaClient:
async def get_data( async def get_data(
self, self,
token: Optional[str] = None, token: Optional[str] = None,
expires: Optional[Any] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
table_name: Optional[Literal["user", "key", "config"]] = None,
query_type: Literal["find_unique", "find_all"] = "find_unique",
): ):
try: try:
print_verbose("PrismaClient: get_data") print_verbose("PrismaClient: get_data")
response = None response: Any = None
if token is not None: if token is not None or (table_name is not None and table_name == "key"):
# check if plain text or hash # check if plain text or hash
hashed_token = token if token is not None:
if token.startswith("sk-"): hashed_token = token
hashed_token = self.hash_token(token=token) if token.startswith("sk-"):
hashed_token = self.hash_token(token=token)
print_verbose("PrismaClient: find_unique") print_verbose("PrismaClient: find_unique")
response = await self.db.litellm_verificationtoken.find_unique( if query_type == "find_unique":
where={"token": hashed_token} response = await self.db.litellm_verificationtoken.find_unique(
) where={"token": hashed_token}
)
elif query_type == "find_all" and user_id is not None:
response = await self.db.litellm_verificationtoken.find_many(
where={"user_id": user_id}
)
print_verbose(f"PrismaClient: response={response}") print_verbose(f"PrismaClient: response={response}")
if response: if response is not None:
# Token exists, now check expiration.
if response.expires is not None and expires is not None:
if response.expires >= expires:
# Token exists and is not expired.
return response
else:
# Token exists but is expired.
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="expired user key",
)
return response return response
else: else:
# Token does not exist. # Token does not exist.