fix(proxy/utils.py): raise different exceptions if key is invalid vs. expired

https://github.com/BerriAI/litellm/issues/1230
This commit is contained in:
Krrish Dholakia 2023-12-25 10:29:44 +05:30
parent a6d1c7e221
commit 9f79f75635
3 changed files with 19 additions and 13 deletions

View file

@ -6,6 +6,7 @@ from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException from fastapi import HTTPException
class MaxParallelRequestsHandler(CustomLogger): class MaxParallelRequestsHandler(CustomLogger):
user_api_key_cache = None
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
pass pass

View file

@ -1,7 +1,7 @@
import sys, os, platform, time, copy, re, asyncio, inspect import sys, os, platform, time, copy, re, asyncio, inspect
import threading, ast import threading, ast
import shutil, random, traceback, requests import shutil, random, traceback, requests
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from typing import Optional, List from typing import Optional, List
import secrets, subprocess import secrets, subprocess
import hashlib, uuid import hashlib, uuid
@ -271,7 +271,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
if valid_token is None: if valid_token is None:
## check db ## check db
print(f"api key: {api_key}") print(f"api key: {api_key}")
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow()) valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc))
print(f"valid token from prisma: {valid_token}") print(f"valid token from prisma: {valid_token}")
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60) user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None: elif valid_token is not None:

View file

@ -6,6 +6,8 @@ from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException, status
def print_verbose(print_statement): def print_verbose(print_statement):
if litellm.set_verbose: if litellm.set_verbose:
print(f"LiteLLM Proxy: {print_statement}") # noqa print(f"LiteLLM Proxy: {print_statement}") # noqa
@ -173,22 +175,25 @@ class PrismaClient:
hashed_token = token hashed_token = token
if token.startswith("sk-"): if token.startswith("sk-"):
hashed_token = self.hash_token(token=token) hashed_token = self.hash_token(token=token)
if expires:
response = await self.db.litellm_verificationtoken.find_first(
where={
"token": hashed_token,
"expires": {"gte": expires} # Check if the token is not expired
}
)
else:
response = await self.db.litellm_verificationtoken.find_unique( response = await self.db.litellm_verificationtoken.find_unique(
where={ where={
"token": hashed_token "token": hashed_token
} }
) )
if response:
# Token exists, now check expiration.
if response.expires is not None and expires is not None:
if response.expires >= expires:
# Token exists and is not expired.
return response return response
else:
# Token exists but is expired.
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="expired user key")
else:
# Token does not exist.
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid user key")
elif user_id is not None: elif user_id is not None:
response = await self.db.litellm_usertable.find_first( # type: ignore response = await self.db.litellm_usertable.find_unique( # type: ignore
where={ where={
"user_id": user_id, "user_id": user_id,
} }