feat(proxy_server.py): add in-memory caching for user api keys

Krrish Dholakia 2023-11-23 13:19:48 -08:00
parent 8291f239a4
commit 8c4e8d6c62
2 changed files with 54 additions and 41 deletions
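
For context, the change applies a simple cache-aside pattern: each incoming API key is first looked up in an in-memory cache, and only on a miss does the proxy query the database through prisma, storing the result for 60 seconds. Below is a minimal, runnable sketch of that pattern; InMemoryTTLCache, fake_db_lookup, authenticate and main are illustrative stand-ins written for this note, not LiteLLM's actual DualCache or proxy code, and they only mirror the get_cache/set_cache calls visible in the diff.

import asyncio
import time
from typing import Any, Optional


class InMemoryTTLCache:
    # Illustrative stand-in for the in-memory side of a cache like DualCache.
    def __init__(self) -> None:
        self._store: dict[str, tuple[Any, float]] = {}

    def get_cache(self, key: str) -> Optional[Any]:
        entry = self._store.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if time.time() > expires_at:
            self._store.pop(key, None)  # expired -> treat as a miss
            return None
        return value

    def set_cache(self, key: str, value: Any, ttl: int = 60) -> None:
        self._store[key] = (value, time.time() + ttl)


user_api_key_cache = InMemoryTTLCache()


async def fake_db_lookup(api_key: str) -> dict:
    # Hypothetical stand-in for the prisma find_first() query in the diff.
    await asyncio.sleep(0.05)  # pretend this round trip is expensive
    return {"token": api_key, "models": []}


async def authenticate(api_key: str) -> dict:
    # Cache-aside: check memory first, hit the DB only on a miss,
    # then keep the result for 60 seconds.
    valid_token = user_api_key_cache.get_cache(key=api_key)
    if valid_token is None:
        valid_token = await fake_db_lookup(api_key)
        user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
    return valid_token


async def main() -> None:
    await authenticate("sk-test")  # miss -> DB lookup, then cache fill
    await authenticate("sk-test")  # hit -> served from memory, no DB call

asyncio.run(main())

The 60-second TTL bounds how long a revoked or newly expired key could keep authenticating from the cache before the database is consulted again.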

proxy_server.py
@@ -81,6 +81,7 @@ def generate_feedback_box():
     print()
 
 import litellm
+from litellm.caching import DualCache
 litellm.suppress_debug_info = True
 from fastapi import FastAPI, Request, HTTPException, status, Depends
 from fastapi.routing import APIRouter
@@ -124,7 +125,7 @@ log_file = "api_log.json"
 worker_config = None
 master_key = None
 prisma_client = None
-config_cache: dict = {}
+user_api_key_cache = DualCache()
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -145,7 +146,7 @@ def usage_telemetry(
         ).start()
 
 async def user_api_key_auth(request: Request):
-    global master_key, prisma_client, config_cache, llm_model_list
+    global master_key, prisma_client, llm_model_list
     if master_key is None:
         return
     try:
@@ -157,24 +158,27 @@ async def user_api_key_auth(request: Request):
         if route == "/key/generate" and api_key != master_key:
             raise Exception(f"If master key is set, only master key can be used to generate new keys")
-        if api_key in config_cache:
-            llm_model_list = config_cache[api_key].get("model_list", [])
-            return
         if prisma_client:
-            valid_token = await prisma_client.litellm_verificationtoken.find_first(
-                where={
-                    "token": api_key,
-                    "expires": {"gte": datetime.utcnow()}  # Check if the token is not expired
-                }
-            )
+            ## check for cache hit (In-Memory Cache)
+            valid_token = user_api_key_cache.get_cache(key=api_key)
+            if valid_token is None:
+                ## check db
+                valid_token = await prisma_client.litellm_verificationtoken.find_first(
+                    where={
+                        "token": api_key,
+                        "expires": {"gte": datetime.utcnow()}  # Check if the token is not expired
+                    }
+                )
+                ## save to cache for 60s
+                user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
+            else:
+                print(f"API Key Cache Hit!")
             if valid_token:
                 litellm.model_alias_map = valid_token.aliases
                 config = valid_token.config
                 if config != {}:
                     model_list = config.get("model_list", [])
                     llm_model_list = model_list
-                    config_cache[api_key] = config
                     print("\n new llm router model list", llm_model_list)
                 if len(valid_token.models) == 0:  # assume an empty model list means all models are allowed to be called
                     return
@@ -579,6 +583,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         data = ast.literal_eval(body_str)
     except:
         data = json.loads(body_str)
+    print(f"receiving data: {data}")
     data["model"] = (
         general_settings.get("completion_model", None)  # server default
         or user_model  # model name passed via cli args
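
Aside from the new debug print, the last hunk shows how the proxy resolves which model to use: an or-chain in which a server-wide completion_model from general_settings takes precedence, then the model name passed via CLI args, with further fallbacks (such as the model in the request body) continuing past this excerpt. A tiny illustration of that fallback pattern follows; the dict and values below are made up for the example, not taken from the commit.

# Made-up values illustrating the or-chain used for data["model"].
general_settings = {}              # no server-level default configured
user_model = None                  # no model passed via CLI args
request_model = "gpt-3.5-turbo"    # model supplied in the request body

resolved_model = (
    general_settings.get("completion_model", None)  # server default wins if set
    or user_model                                   # then the CLI-provided model
    or request_model                                # finally the request's own model
)
print(resolved_model)  # -> "gpt-3.5-turbo"

Note that Python's or treats any falsy value (None, empty string) as unset, which is usually the intent for this kind of configuration precedence.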