forked from phoenix/litellm-mirror
feat(proxy_server.py): add in-memory caching for user api keys
This commit is contained in:
parent 8291f239a4
commit 8c4e8d6c62
2 changed files with 54 additions and 41 deletions
@@ -81,6 +81,7 @@ def generate_feedback_box():
     print()

 import litellm
+from litellm.caching import DualCache
 litellm.suppress_debug_info = True
 from fastapi import FastAPI, Request, HTTPException, status, Depends
 from fastapi.routing import APIRouter
@@ -124,7 +125,7 @@ log_file = "api_log.json"
 worker_config = None
 master_key = None
 prisma_client = None
-config_cache: dict = {}
+user_api_key_cache = DualCache()
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
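Note: the module-level config_cache dict is replaced by a DualCache instance. A minimal sketch of the interface this diff relies on -- only DualCache(), get_cache(), and set_cache() appear in the change itself; the example key and value below are made up for illustration:

# sketch only - illustrative key/value, not taken from the diff
from litellm.caching import DualCache

user_api_key_cache = DualCache()

# cache a value for 60 seconds
user_api_key_cache.set_cache(key="sk-example", value={"models": []}, ttl=60)

# the auth code below treats a None return from get_cache as a miss
cached = user_api_key_cache.get_cache(key="sk-example")
print("cache hit" if cached is not None else "cache miss")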
@@ -145,7 +146,7 @@ def usage_telemetry(
     ).start()

 async def user_api_key_auth(request: Request):
-    global master_key, prisma_client, config_cache, llm_model_list
+    global master_key, prisma_client, llm_model_list
     if master_key is None:
         return
     try:
@@ -157,24 +158,27 @@ async def user_api_key_auth(request: Request):
         if route == "/key/generate" and api_key != master_key:
             raise Exception(f"If master key is set, only master key can be used to generate new keys")

-        if api_key in config_cache:
-            llm_model_list = config_cache[api_key].get("model_list", [])
-            return
-
         if prisma_client:
-            valid_token = await prisma_client.litellm_verificationtoken.find_first(
-                where={
-                    "token": api_key,
-                    "expires": {"gte": datetime.utcnow()}  # Check if the token is not expired
-                }
-            )
+            ## check for cache hit (In-Memory Cache)
+            valid_token = user_api_key_cache.get_cache(key=api_key)
+            if valid_token is None:
+                ## check db
+                valid_token = await prisma_client.litellm_verificationtoken.find_first(
+                    where={
+                        "token": api_key,
+                        "expires": {"gte": datetime.utcnow()}  # Check if the token is not expired
+                    }
+                )
+                ## save to cache for 60s
+                user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
+            else:
+                print(f"API Key Cache Hit!")
             if valid_token:
                 litellm.model_alias_map = valid_token.aliases
                 config = valid_token.config
                 if config != {}:
                     model_list = config.get("model_list", [])
                     llm_model_list = model_list
-                    config_cache[api_key] = config
                 print("\n new llm router model list", llm_model_list)
                 if len(valid_token.models) == 0:  # assume an empty model list means all models are allowed to be called
                     return
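The auth change above is a cache-aside lookup: try the in-memory cache, fall back to the database on a miss, then store the result for 60 seconds. A standalone sketch of that flow, with a hypothetical lookup_token_in_db standing in for the Prisma find_first query (function names and return value are illustrative):

import asyncio
from litellm.caching import DualCache

token_cache = DualCache()

async def lookup_token_in_db(api_key: str):
    # hypothetical stand-in for prisma_client.litellm_verificationtoken.find_first(...)
    await asyncio.sleep(0)
    return {"token": api_key, "models": []}

async def get_valid_token(api_key: str):
    valid_token = token_cache.get_cache(key=api_key)       # 1. check the in-memory cache
    if valid_token is None:
        valid_token = await lookup_token_in_db(api_key)    # 2. fall back to the database
        token_cache.set_cache(key=api_key, value=valid_token, ttl=60)  # 3. cache for 60s
    return valid_token

# asyncio.run(get_valid_token("sk-example"))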
@@ -579,6 +583,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         data = ast.literal_eval(body_str)
     except:
         data = json.loads(body_str)
+    print(f"receiving data: {data}")
     data["model"] = (
         general_settings.get("completion_model", None)  # server default
         or user_model  # model name passed via cli args
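The surrounding try/except in this hunk parses the request body as a Python literal and falls back to JSON on failure. A self-contained sketch of that fallback (the sample body string is illustrative):

import ast
import json

def parse_body(body_str: str) -> dict:
    try:
        # accepts Python-literal style bodies, e.g. single-quoted keys
        return ast.literal_eval(body_str)
    except Exception:  # the original uses a bare except
        return json.loads(body_str)

print(parse_body('{"model": "gpt-3.5-turbo", "messages": []}'))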