forked from phoenix/litellm-mirror
fix(proxy/utils.py): return different exceptions if key is invalid vs. expired
https://github.com/BerriAI/litellm/issues/1230
This commit is contained in:
parent
a6d1c7e221
commit
9f79f75635
3 changed files with 19 additions and 13 deletions
|
@ -6,6 +6,7 @@ from litellm.integrations.custom_logger import CustomLogger
|
|||
from fastapi import HTTPException
|
||||
|
||||
class MaxParallelRequestsHandler(CustomLogger):
|
||||
user_api_key_cache = None
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
pass
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import sys, os, platform, time, copy, re, asyncio, inspect
|
||||
import threading, ast
|
||||
import shutil, random, traceback, requests
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, List
|
||||
import secrets, subprocess
|
||||
import hashlib, uuid
|
||||
|
@ -271,7 +271,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
if valid_token is None:
|
||||
## check db
|
||||
print(f"api key: {api_key}")
|
||||
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
|
||||
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc))
|
||||
print(f"valid token from prisma: {valid_token}")
|
||||
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
|
||||
elif valid_token is not None:
|
||||
|
|
|
@ -6,6 +6,8 @@ from litellm.caching import DualCache
|
|||
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
|
||||
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
def print_verbose(print_statement):
|
||||
if litellm.set_verbose:
|
||||
print(f"LiteLLM Proxy: {print_statement}") # noqa
|
||||
|
@ -173,22 +175,25 @@ class PrismaClient:
|
|||
hashed_token = token
|
||||
if token.startswith("sk-"):
|
||||
hashed_token = self.hash_token(token=token)
|
||||
if expires:
|
||||
response = await self.db.litellm_verificationtoken.find_first(
|
||||
where={
|
||||
"token": hashed_token,
|
||||
"expires": {"gte": expires} # Check if the token is not expired
|
||||
}
|
||||
)
|
||||
else:
|
||||
response = await self.db.litellm_verificationtoken.find_unique(
|
||||
where={
|
||||
"token": hashed_token
|
||||
}
|
||||
)
|
||||
if response:
|
||||
# Token exists, now check expiration.
|
||||
if response.expires is not None and expires is not None:
|
||||
if response.expires >= expires:
|
||||
# Token exists and is not expired.
|
||||
return response
|
||||
else:
|
||||
# Token exists but is expired.
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="expired user key")
|
||||
else:
|
||||
# Token does not exist.
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid user key")
|
||||
elif user_id is not None:
|
||||
response = await self.db.litellm_usertable.find_first( # type: ignore
|
||||
response = await self.db.litellm_usertable.find_unique( # type: ignore
|
||||
where={
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue