diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index 39cc73751..98ee231b6 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -6,6 +6,7 @@ from litellm.integrations.custom_logger import CustomLogger from fastapi import HTTPException class MaxParallelRequestsHandler(CustomLogger): + user_api_key_cache = None # Class variables or attributes def __init__(self): pass diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 3e4e7f25d..0264ad2f9 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1,7 +1,7 @@ import sys, os, platform, time, copy, re, asyncio, inspect import threading, ast import shutil, random, traceback, requests -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Optional, List import secrets, subprocess import hashlib, uuid @@ -271,7 +271,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap if valid_token is None: ## check db print(f"api key: {api_key}") - valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow()) + valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc)) print(f"valid token from prisma: {valid_token}") user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60) elif valid_token is not None: diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 2dc62d664..980b518d0 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -6,6 +6,8 @@ from litellm.caching import DualCache from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter from litellm.integrations.custom_logger import CustomLogger +from fastapi import HTTPException, status + def print_verbose(print_statement): if litellm.set_verbose: print(f"LiteLLM Proxy: {print_statement}") # noqa @@ -173,22 +175,25 @@ class PrismaClient: hashed_token = token if token.startswith("sk-"): hashed_token = self.hash_token(token=token) - if expires: - response = await self.db.litellm_verificationtoken.find_first( - where={ - "token": hashed_token, - "expires": {"gte": expires} # Check if the token is not expired - } - ) - else: - response = await self.db.litellm_verificationtoken.find_unique( + response = await self.db.litellm_verificationtoken.find_unique( where={ "token": hashed_token } ) - return response + if response: + # Token exists, now check expiration. + if response.expires is not None and expires is not None: + if response.expires >= expires: + # Token exists and is not expired. + return response + else: + # Token exists but is expired. + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="expired user key") + else: + # Token does not exist. + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid user key") elif user_id is not None: - response = await self.db.litellm_usertable.find_first( # type: ignore + response = await self.db.litellm_usertable.find_unique( # type: ignore where={ "user_id": user_id, }