feat(proxy_server.py): add in-memory caching for user api keys

Krrish Dholakia 2023-11-23 13:19:48 -08:00
parent 8291f239a4
commit 8c4e8d6c62
2 changed files with 54 additions and 41 deletions

litellm/proxy/proxy_server.py

@@ -81,6 +81,7 @@ def generate_feedback_box():
     print()
 
 import litellm
+from litellm.caching import DualCache
 litellm.suppress_debug_info = True
 from fastapi import FastAPI, Request, HTTPException, status, Depends
 from fastapi.routing import APIRouter
@@ -124,7 +125,7 @@ log_file = "api_log.json"
 worker_config = None
 master_key = None
 prisma_client = None
-config_cache: dict = {}
+user_api_key_cache = DualCache()
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -145,7 +146,7 @@ def usage_telemetry(
         ).start()
 
 async def user_api_key_auth(request: Request):
-    global master_key, prisma_client, config_cache, llm_model_list
+    global master_key, prisma_client, llm_model_list
     if master_key is None:
         return
     try:
@@ -157,24 +158,27 @@ async def user_api_key_auth(request: Request):
         if route == "/key/generate" and api_key != master_key:
             raise Exception(f"If master key is set, only master key can be used to generate new keys")
-        if api_key in config_cache:
-            llm_model_list = config_cache[api_key].get("model_list", [])
-            return
         if prisma_client:
-            valid_token = await prisma_client.litellm_verificationtoken.find_first(
-                where={
-                    "token": api_key,
-                    "expires": {"gte": datetime.utcnow()} # Check if the token is not expired
-                }
-            )
+            ## check for cache hit (In-Memory Cache)
+            valid_token = user_api_key_cache.get_cache(key=api_key)
+            if valid_token is None:
+                ## check db
+                valid_token = await prisma_client.litellm_verificationtoken.find_first(
+                    where={
+                        "token": api_key,
+                        "expires": {"gte": datetime.utcnow()} # Check if the token is not expired
+                    }
+                )
+                ## save to cache for 60s
+                user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
+            else:
+                print(f"API Key Cache Hit!")
             if valid_token:
                 litellm.model_alias_map = valid_token.aliases
                 config = valid_token.config
                 if config != {}:
                     model_list = config.get("model_list", [])
                     llm_model_list = model_list
-                    config_cache[api_key] = config
                     print("\n new llm router model list", llm_model_list)
                 if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
                     return
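For reference, a minimal standalone sketch of the read-through pattern the hunk above introduces: consult an in-memory cache first, fall back to the database on a miss, and store the result with a short TTL so repeated requests with the same key skip the DB. The names here (SimpleTTLCache, fake_db_lookup, lookup_token) are hypothetical stand-ins for litellm's DualCache and the Prisma query, not the proxy's actual code.

import time
import asyncio

class SimpleTTLCache:
    """Hypothetical stand-in for the in-memory layer of litellm's DualCache."""
    def __init__(self):
        self._store = {}  # key -> (value, expiry timestamp)

    def get_cache(self, key):
        entry = self._store.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if time.time() > expires_at:  # an expired entry counts as a miss
            self._store.pop(key, None)
            return None
        return value

    def set_cache(self, key, value, ttl=60):
        self._store[key] = (value, time.time() + ttl)

user_api_key_cache = SimpleTTLCache()

async def fake_db_lookup(api_key):
    # stand-in for prisma_client.litellm_verificationtoken.find_first(...)
    await asyncio.sleep(0.1)  # simulate DB latency
    return {"token": api_key, "models": []}

async def lookup_token(api_key):
    valid_token = user_api_key_cache.get_cache(key=api_key)  # 1) cache hit?
    if valid_token is None:
        valid_token = await fake_db_lookup(api_key)  # 2) miss: query the DB
        user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)  # 3) cache for 60s
    return valid_token

async def main():
    await lookup_token("sk-test")  # miss: hits the fake DB, then caches the token
    await lookup_token("sk-test")  # hit: served from memory, no DB round trip

asyncio.run(main())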
@@ -579,6 +583,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
             data = ast.literal_eval(body_str)
         except:
             data = json.loads(body_str)
+        print(f"receiving data: {data}")
         data["model"] = (
             general_settings.get("completion_model", None) # server default
             or user_model # model name passed via cli args
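As a side note, the data["model"] assignment above leans on Python's or-chaining to take the first truthy value in priority order (server default, then the CLI model, and so on; the hunk is cut off before the chain ends). A tiny illustration with made-up values:

general_settings = {}  # no server-wide default configured
user_model = None  # no model passed via CLI args
requested_model = "gpt-3.5-turbo"  # hypothetical value from the request body

model = (
    general_settings.get("completion_model", None)  # server default wins if set
    or user_model  # otherwise the CLI model
    or requested_model  # otherwise fall through to the request value
)
print(model)  # -> "gpt-3.5-turbo"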

litellm/router.py

@@ -251,33 +251,39 @@ class Router:
     def function_with_retries(self, *args, **kwargs):
         # we'll backoff exponentially with each retry
         self.print_verbose(f"Inside function with retries: args - {args}; kwargs - {kwargs}")
         backoff_factor = 1
         original_function = kwargs.pop("original_function")
         num_retries = kwargs.pop("num_retries")
-        for current_attempt in range(num_retries):
-            self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
-            try:
-                # if the function call is successful, no exception will be raised and we'll break out of the loop
-                response = original_function(*args, **kwargs)
-                return response
+        try:
+            # if the function call is successful, no exception will be raised and we'll break out of the loop
+            response = original_function(*args, **kwargs)
+            return response
+        except Exception as e:
+            for current_attempt in range(num_retries):
+                num_retries -= 1 # decrement the number of retries
+                self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
+                try:
+                    # if the function call is successful, no exception will be raised and we'll break out of the loop
+                    response = original_function(*args, **kwargs)
+                    return response
 
-            except openai.RateLimitError as e:
-                if num_retries > 0:
-                    # on RateLimitError we'll wait for an exponential time before trying again
-                    time.sleep(backoff_factor)
+                except openai.RateLimitError as e:
+                    if num_retries > 0:
+                        # on RateLimitError we'll wait for an exponential time before trying again
+                        time.sleep(backoff_factor)
 
-                    # increase backoff factor for next run
-                    backoff_factor *= 2
-                else:
-                    raise e
-            except Exception as e:
-                # for any other exception types, immediately retry
-                if num_retries > 0:
-                    pass
-                else:
-                    raise e
+                        num_retries -= 1 # decrement the number of retries
+                        # increase backoff factor for next run
+                        backoff_factor *= 2
+                    else:
+                        raise e
+                except Exception as e:
+                    # for any other exception types, immediately retry
+                    if num_retries > 0:
+                        pass
+                    else:
+                        raise e
 
     ### COMPLETION + EMBEDDING FUNCTIONS
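For reference, a compact standalone sketch of the control flow the new version of function_with_retries uses: try the call once, and only on failure enter a bounded retry loop that sleeps with a doubling backoff on rate-limit errors and retries other errors immediately. The names here (call_api, a local RateLimitError class) are hypothetical stand-ins for the router's original_function and openai.RateLimitError; this is a simplified sketch, not the router's exact logic.

import time

class RateLimitError(Exception):
    """Hypothetical stand-in for openai.RateLimitError."""

attempts = {"count": 0}

def call_api():
    # stand-in for original_function(*args, **kwargs): fails twice, then succeeds
    attempts["count"] += 1
    if attempts["count"] <= 2:
        raise RateLimitError("slow down")
    return "ok"

def function_with_retries(num_retries=3):
    backoff_factor = 1
    try:
        return call_api()  # happy path: no retry machinery involved
    except Exception:
        for _current_attempt in range(num_retries):
            num_retries -= 1  # one fewer retry left
            try:
                return call_api()
            except RateLimitError:
                if num_retries > 0:
                    time.sleep(backoff_factor)  # wait before the next attempt
                    backoff_factor *= 2  # exponential backoff: 1s, 2s, 4s, ...
                else:
                    raise  # out of retries: surface the rate-limit error
            except Exception:
                if num_retries <= 0:
                    raise  # other errors retry immediately while retries remain

print(function_with_retries())  # first call and first retry fail, second retry returns "ok"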
@@ -289,12 +295,14 @@ class Router:
         Example usage:
         response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}]
         """
-        kwargs["model"] = model
-        kwargs["messages"] = messages
-        kwargs["original_function"] = self._completion
-        kwargs["num_retries"] = self.num_retries
-        return self.function_with_retries(**kwargs)
+        try:
+            kwargs["model"] = model
+            kwargs["messages"] = messages
+            kwargs["original_function"] = self._completion
+            kwargs["num_retries"] = self.num_retries
+            return self.function_with_retries(**kwargs)
+        except Exception as e:
+            raise e
 
     def _completion(
             self,
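Finally, a hedged usage sketch of the retry path from the caller's side. It assumes Router accepts num_retries in its constructor (the self.num_retries reference above suggests it is configurable there) and uses litellm's usual model_list shape; the API key is a placeholder, not a real credential.

from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",  # alias callers will use
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": "sk-...",  # placeholder
        },
    }
]

# assumption: num_retries is a constructor argument backing self.num_retries
router = Router(model_list=model_list, num_retries=3)

# completion() packs model/messages/num_retries into kwargs and hands them to
# function_with_retries(), which backs off exponentially on rate-limit errors
response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response)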