From 8c4e8d6c62167b9b41797c0d9fd479361270918d Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 23 Nov 2023 13:19:48 -0800
Subject: [PATCH] feat(proxy_server.py): add in-memory caching for user api keys

---
 litellm/proxy/proxy_server.py | 31 ++++++++++-------
 litellm/router.py             | 64 ++++++++++++++++++++---------------
 2 files changed, 54 insertions(+), 41 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 3068ffb5a4..09956660b1 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -81,6 +81,7 @@ def generate_feedback_box():
     print()
 
 import litellm
+from litellm.caching import DualCache
 litellm.suppress_debug_info = True
 from fastapi import FastAPI, Request, HTTPException, status, Depends
 from fastapi.routing import APIRouter
@@ -124,7 +125,7 @@ log_file = "api_log.json"
 worker_config = None
 master_key = None
 prisma_client = None
-config_cache: dict = {}
+user_api_key_cache = DualCache()
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -145,7 +146,7 @@ def usage_telemetry(
     ).start()
 
 async def user_api_key_auth(request: Request):
-    global master_key, prisma_client, config_cache, llm_model_list
+    global master_key, prisma_client, llm_model_list
     if master_key is None:
         return
     try:
@@ -157,24 +158,27 @@ async def user_api_key_auth(request: Request):
         if route == "/key/generate" and api_key != master_key:
             raise Exception(f"If master key is set, only master key can be used to generate new keys")
 
-        if api_key in config_cache:
-            llm_model_list = config_cache[api_key].get("model_list", [])
-            return
-
         if prisma_client:
-            valid_token = await prisma_client.litellm_verificationtoken.find_first(
-                where={
-                    "token": api_key,
-                    "expires": {"gte": datetime.utcnow()}  # Check if the token is not expired
-                }
-            )
+            ## check for cache hit (In-Memory Cache)
+            valid_token = user_api_key_cache.get_cache(key=api_key)
+            if valid_token is None:
+                ## check db
+                valid_token = await prisma_client.litellm_verificationtoken.find_first(
+                    where={
+                        "token": api_key,
+                        "expires": {"gte": datetime.utcnow()}  # Check if the token is not expired
+                    }
+                )
+                ## save to cache for 60s
+                user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
+            else:
+                print(f"API Key Cache Hit!")
             if valid_token:
                 litellm.model_alias_map = valid_token.aliases
                 config = valid_token.config
                 if config != {}:
                     model_list = config.get("model_list", [])
                     llm_model_list = model_list
-                    config_cache[api_key] = config
                     print("\n new llm router model list", llm_model_list)
                 if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
                     return
@@ -579,6 +583,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         data = ast.literal_eval(body_str)
     except:
         data = json.loads(body_str)
+    print(f"receiving data: {data}")
     data["model"] = (
         general_settings.get("completion_model", None) # server default
         or user_model # model name passed via cli args
diff --git a/litellm/router.py b/litellm/router.py
index 935d0329cc..72533e06c2 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -251,33 +251,39 @@ class Router:
 
     def function_with_retries(self, *args, **kwargs):
         # we'll backoff exponentially with each retry
+        self.print_verbose(f"Inside function with retries: args - {args}; kwargs - {kwargs}")
         backoff_factor = 1
         original_function = kwargs.pop("original_function")
         num_retries = kwargs.pop("num_retries")
-        for current_attempt in range(num_retries):
-            self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
-            try:
-                # if the function call is successful, no exception will be raised and we'll break out of the loop
-                response = original_function(*args, **kwargs)
-                return response
+        try:
+            # if the function call is successful, no exception will be raised and we'll break out of the loop
+            response = original_function(*args, **kwargs)
+            return response
+        except Exception as e:
+            for current_attempt in range(num_retries):
+                num_retries -= 1 # decrement the number of retries
+                self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
+                try:
+                    # if the function call is successful, no exception will be raised and we'll break out of the loop
+                    response = original_function(*args, **kwargs)
+                    return response
 
-            except openai.RateLimitError as e:
-                if num_retries > 0:
-                    # on RateLimitError we'll wait for an exponential time before trying again
-                    time.sleep(backoff_factor)
+                except openai.RateLimitError as e:
+                    if num_retries > 0:
+                        # on RateLimitError we'll wait for an exponential time before trying again
+                        time.sleep(backoff_factor)
 
-                    # increase backoff factor for next run
-                    backoff_factor *= 2
-                else:
-                    raise e
-
-            except Exception as e:
-                # for any other exception types, immediately retry
-                if num_retries > 0:
-                    pass
-                else:
-                    raise e
-            num_retries -= 1 # decrement the number of retries
+                        # increase backoff factor for next run
+                        backoff_factor *= 2
+                    else:
+                        raise e
+
+                except Exception as e:
+                    # for any other exception types, immediately retry
+                    if num_retries > 0:
+                        pass
+                    else:
+                        raise e
 
 
     ### COMPLETION + EMBEDDING FUNCTIONS
@@ -289,12 +295,14 @@ class Router:
         Example usage:
         response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}]
         """
-
-        kwargs["model"] = model
-        kwargs["messages"] = messages
-        kwargs["original_function"] = self._completion
-        kwargs["num_retries"] = self.num_retries
-        return self.function_with_retries(**kwargs)
+        try:
+            kwargs["model"] = model
+            kwargs["messages"] = messages
+            kwargs["original_function"] = self._completion
+            kwargs["num_retries"] = self.num_retries
+            return self.function_with_retries(**kwargs)
+        except Exception as e:
+            raise e
 
     def _completion(
         self,
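
Note on the proxy change: the ad-hoc config_cache dict is replaced with litellm's DualCache, so user_api_key_auth now checks the in-memory cache first, falls back to the Prisma query on a miss, and caches the verified token for 60 seconds. Below is a minimal sketch of that lookup pattern, reusing the same DualCache calls the patch introduces; the fetch_token_from_db helper is a hypothetical stand-in for the litellm_verificationtoken.find_first query, not part of the patch.

import asyncio
from litellm.caching import DualCache

user_api_key_cache = DualCache()

async def fetch_token_from_db(api_key: str):
    # hypothetical stand-in for prisma_client.litellm_verificationtoken.find_first(...)
    # in proxy_server.py; a real lookup returns None for unknown or expired keys
    return {"token": api_key, "models": []}

async def get_valid_token(api_key: str):
    # 1) check for a cache hit (in-memory cache)
    valid_token = user_api_key_cache.get_cache(key=api_key)
    if valid_token is None:
        # 2) cache miss: check the db
        valid_token = await fetch_token_from_db(api_key)
        # 3) save to cache for 60s so repeated requests skip the db
        user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
    return valid_token

if __name__ == "__main__":
    print(asyncio.run(get_valid_token("sk-example")))

One trade-off of the 60-second TTL: a key revoked in the database can keep authenticating for up to a minute until its cache entry expires.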
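Note on the router change: function_with_retries now makes the first call outside the retry loop and only enters the loop after an initial exception, sleeping with an exponentially growing backoff on openai.RateLimitError and retrying other exceptions immediately. Below is a simplified, standalone sketch of that retry-with-backoff pattern; it is not the Router method itself (generic names, a local RateLimitError stand-in, and it re-raises the last error once retries are exhausted).

import time

class RateLimitError(Exception):
    # stand-in for openai.RateLimitError in this sketch
    pass

def call_with_retries(fn, *args, num_retries=3, **kwargs):
    backoff_factor = 1
    try:
        # first attempt: if it succeeds we never enter the retry loop
        return fn(*args, **kwargs)
    except Exception as first_error:
        last_error = first_error
        for _ in range(num_retries):
            try:
                return fn(*args, **kwargs)
            except RateLimitError as e:
                last_error = e
                # rate limited: wait, then double the delay for the next attempt
                time.sleep(backoff_factor)
                backoff_factor *= 2
            except Exception as e:
                # any other error: retry immediately
                last_error = e
        # retries exhausted without a successful call
        raise last_error

With num_retries=3 and repeated rate limits, the sketch waits 1s, 2s, and 4s between attempts before re-raising the final error.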