mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(proxy_server.py): have /user/info return user info + related user keys
This commit is contained in:
parent
d29a9029ac
commit
cd639a7f4b
2 changed files with 53 additions and 26 deletions
|
@ -242,7 +242,19 @@ async def user_api_key_auth(
|
||||||
## check db
|
## check db
|
||||||
verbose_proxy_logger.debug(f"api key: {api_key}")
|
verbose_proxy_logger.debug(f"api key: {api_key}")
|
||||||
valid_token = await prisma_client.get_data(
|
valid_token = await prisma_client.get_data(
|
||||||
token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc)
|
token=api_key,
|
||||||
|
)
|
||||||
|
expires = datetime.utcnow().replace(tzinfo=timezone.utc)
|
||||||
|
# Token exists, now check expiration.
|
||||||
|
if valid_token.expires is not None and expires is not None:
|
||||||
|
if valid_token.expires >= expires:
|
||||||
|
# Token exists and is not expired.
|
||||||
|
return valid_token
|
||||||
|
else:
|
||||||
|
# Token exists but is expired.
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="expired user key",
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug(f"valid token from prisma: {valid_token}")
|
verbose_proxy_logger.debug(f"valid token from prisma: {valid_token}")
|
||||||
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
|
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
|
||||||
|
@ -1998,14 +2010,33 @@ async def user_auth(request: Request):
|
||||||
return "Email sent!"
|
return "Email sent!"
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.get(
|
||||||
"/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
|
"/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
|
||||||
)
|
)
|
||||||
async def user_info(request: Request):
|
async def user_info(
|
||||||
|
user_id: str = fastapi.Query(..., description="User ID in the request parameters")
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
[TODO]: Use this to get user information. (user row + all user key info)
|
Use this to get user information. (user row + all user key info)
|
||||||
"""
|
"""
|
||||||
pass
|
global prisma_client
|
||||||
|
try:
|
||||||
|
if prisma_client is None:
|
||||||
|
raise Exception(
|
||||||
|
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
|
||||||
|
)
|
||||||
|
## GET USER ROW ##
|
||||||
|
user_info = await prisma_client.get_data(user_id=user_id)
|
||||||
|
## GET ALL KEYS ##
|
||||||
|
keys = await prisma_client.get_data(
|
||||||
|
user_id=user_id, table_name="key", query_type="find_all"
|
||||||
|
)
|
||||||
|
return {"user_id": user_id, "user_info": user_info, "keys": keys}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail={"error": str(e)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
|
|
@ -344,35 +344,31 @@ class PrismaClient:
|
||||||
async def get_data(
|
async def get_data(
|
||||||
self,
|
self,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
expires: Optional[Any] = None,
|
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
|
table_name: Optional[Literal["user", "key", "config"]] = None,
|
||||||
|
query_type: Literal["find_unique", "find_all"] = "find_unique",
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
print_verbose("PrismaClient: get_data")
|
print_verbose("PrismaClient: get_data")
|
||||||
|
|
||||||
response = None
|
response: Any = None
|
||||||
if token is not None:
|
if token is not None or (table_name is not None and table_name == "key"):
|
||||||
# check if plain text or hash
|
# check if plain text or hash
|
||||||
|
if token is not None:
|
||||||
hashed_token = token
|
hashed_token = token
|
||||||
if token.startswith("sk-"):
|
if token.startswith("sk-"):
|
||||||
hashed_token = self.hash_token(token=token)
|
hashed_token = self.hash_token(token=token)
|
||||||
print_verbose("PrismaClient: find_unique")
|
print_verbose("PrismaClient: find_unique")
|
||||||
|
if query_type == "find_unique":
|
||||||
response = await self.db.litellm_verificationtoken.find_unique(
|
response = await self.db.litellm_verificationtoken.find_unique(
|
||||||
where={"token": hashed_token}
|
where={"token": hashed_token}
|
||||||
)
|
)
|
||||||
print_verbose(f"PrismaClient: response={response}")
|
elif query_type == "find_all" and user_id is not None:
|
||||||
if response:
|
response = await self.db.litellm_verificationtoken.find_many(
|
||||||
# Token exists, now check expiration.
|
where={"user_id": user_id}
|
||||||
if response.expires is not None and expires is not None:
|
|
||||||
if response.expires >= expires:
|
|
||||||
# Token exists and is not expired.
|
|
||||||
return response
|
|
||||||
else:
|
|
||||||
# Token exists but is expired.
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="expired user key",
|
|
||||||
)
|
)
|
||||||
|
print_verbose(f"PrismaClient: response={response}")
|
||||||
|
if response is not None:
|
||||||
return response
|
return response
|
||||||
else:
|
else:
|
||||||
# Token does not exist.
|
# Token does not exist.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue