fix(proxy/utils.py): return different exceptions if key is invalid vs. expired

https://github.com/BerriAI/litellm/issues/1230
This commit is contained in:
Krrish Dholakia 2023-12-25 10:29:44 +05:30
parent a6d1c7e221
commit 9f79f75635
3 changed files with 19 additions and 13 deletions

View file

@ -6,6 +6,7 @@ from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
class MaxParallelRequestsHandler(CustomLogger):
user_api_key_cache = None
# Class variables or attributes
def __init__(self):
pass

View file

@ -1,7 +1,7 @@
import sys, os, platform, time, copy, re, asyncio, inspect
import threading, ast
import shutil, random, traceback, requests
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Optional, List
import secrets, subprocess
import hashlib, uuid
@ -271,7 +271,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
if valid_token is None:
## check db
print(f"api key: {api_key}")
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc))
print(f"valid token from prisma: {valid_token}")
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None:

View file

@ -6,6 +6,8 @@ from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException, status
def print_verbose(print_statement):
if litellm.set_verbose:
print(f"LiteLLM Proxy: {print_statement}") # noqa
@ -173,22 +175,25 @@ class PrismaClient:
hashed_token = token
if token.startswith("sk-"):
hashed_token = self.hash_token(token=token)
if expires:
response = await self.db.litellm_verificationtoken.find_first(
where={
"token": hashed_token,
"expires": {"gte": expires} # Check if the token is not expired
}
)
else:
response = await self.db.litellm_verificationtoken.find_unique(
response = await self.db.litellm_verificationtoken.find_unique(
where={
"token": hashed_token
}
)
return response
if response:
# Token exists, now check expiration.
if response.expires is not None and expires is not None:
if response.expires >= expires:
# Token exists and is not expired.
return response
else:
# Token exists but is expired.
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="expired user key")
else:
# Token does not exist.
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid user key")
elif user_id is not None:
response = await self.db.litellm_usertable.find_first( # type: ignore
response = await self.db.litellm_usertable.find_unique( # type: ignore
where={
"user_id": user_id,
}