From 61fc76a8c4ac8224f3122154774f8847b6c07b91 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 23 Nov 2023 11:13:24 -0800
Subject: [PATCH] fix(router.py): fix caching for tracking cooldowns + usage

---
 docs/my-website/sidebars.js  | 2 +-
 litellm/caching.py           | 110 ++++++++++++++++++++++++++---------
 litellm/proxy/proxy_cli.py   | 2 +-
 litellm/router.py            | 95 +++++++++++++++++-------------
 litellm/tests/test_router.py | 14 +++--
 5 files changed, 148 insertions(+), 75 deletions(-)

diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index ddae583c18..95797c1cd0 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -89,6 +89,7 @@ const sidebars = {
         "routing",
         "rules",
         "set_keys",
+        "budget_manager",
         "completion/token_usage",
         {
           type: 'category',
@@ -157,7 +158,6 @@ const sidebars = {
       label: 'Extras',
       items: [
         'extras/contributing',
-        "budget_manager",
         "proxy_server",
         {
           type: "category",
diff --git a/litellm/caching.py b/litellm/caching.py
index a49ed42dfb..f87c0f122a 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -23,6 +23,9 @@ def get_prompt(*args, **kwargs):
         return prompt
     return None
 
+def print_verbose(print_statement):
+    if litellm.set_verbose:
+        print(print_statement) # noqa
 
 class BaseCache:
     def set_cache(self, key, value, **kwargs):
@@ -32,6 +35,34 @@ class BaseCache:
         raise NotImplementedError
 
 
+class InMemoryCache(BaseCache):
+    def __init__(self):
+        # if users don't provide one, use the default litellm cache
+        self.cache_dict = {}
+        self.ttl_dict = {}
+
+    def set_cache(self, key, value, **kwargs):
+        self.cache_dict[key] = value
+        if "ttl" in kwargs:
+            self.ttl_dict[key] = time.time() + kwargs["ttl"]
+
+    def get_cache(self, key, **kwargs):
+        if key in self.cache_dict:
+            if key in self.ttl_dict:
+                if time.time() > self.ttl_dict[key]:
+                    self.cache_dict.pop(key, None)
+                    return None
+            original_cached_response = self.cache_dict[key]
+            try:
+                cached_response = json.loads(original_cached_response)
+            except:
+                cached_response = original_cached_response
+            if isinstance(cached_response, dict):
+                cached_response['cache'] = True # set cache-hit flag to True
+            return cached_response
+        return None
+
+
 class RedisCache(BaseCache):
     def __init__(self, host, port, password):
         import redis
@@ -65,7 +96,58 @@ class RedisCache(BaseCache):
             traceback.print_exc()
             logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
 
+class DualCache(BaseCache):
+    """
+    This updates both Redis and an in-memory cache simultaneously.
+    When data is updated or inserted, it is written to both the in-memory cache + Redis.
+    This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
+    """
+    def __init__(self, in_memory_cache: InMemoryCache =None, redis_cache: RedisCache =None) -> None:
+        super().__init__()
+        # If in_memory_cache is not provided, use the default InMemoryCache
+        self.in_memory_cache = in_memory_cache or InMemoryCache()
+        # If redis_cache is not provided, use the default RedisCache
+        self.redis_cache = redis_cache
+
+    def set_cache(self, key, value, **kwargs):
+        # Update both Redis and in-memory cache
+        try:
+            print_verbose(f"set cache: key: {key}; value: {value}")
+            if self.in_memory_cache is not None:
+                self.in_memory_cache.set_cache(key, value, **kwargs)
+            if self.redis_cache is not None:
+                self.redis_cache.set_cache(key, value, **kwargs)
+        except Exception as e:
+            print_verbose(e)
+
+    def get_cache(self, key, **kwargs):
+        # Try to fetch from in-memory cache first
+        try:
+            print_verbose(f"get cache: cache key: {key}")
+            result = None
+            if self.in_memory_cache is not None:
+                in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
+
+                if in_memory_result is not None:
+                    result = in_memory_result
+
+            if self.redis_cache is not None:
+                # If not found in in-memory cache, try fetching from Redis
+                redis_result = self.redis_cache.get_cache(key, **kwargs)
+
+                if redis_result is not None:
+                    # Update in-memory cache with the value from Redis
+                    self.in_memory_cache.set_cache(key, redis_result, **kwargs)
+
+                result = redis_result
+
+            print_verbose(f"get cache: cache result: {result}")
+            return result
+        except Exception as e:
+            traceback.print_exc()
+
+#### DEPRECATED ####
 class HostedCache(BaseCache):
     def set_cache(self, key, value, **kwargs):
         if "ttl" in kwargs:
@@ -91,33 +173,7 @@ class HostedCache(BaseCache):
         return cached_response
 
 
-class InMemoryCache(BaseCache):
-    def __init__(self):
-        # if users don't provider one, use the default litellm cache
-        self.cache_dict = {}
-        self.ttl_dict = {}
-
-    def set_cache(self, key, value, **kwargs):
-        self.cache_dict[key] = value
-        if "ttl" in kwargs:
-            self.ttl_dict[key] = time.time() + kwargs["ttl"]
-
-    def get_cache(self, key, **kwargs):
-        if key in self.cache_dict:
-            if key in self.ttl_dict:
-                if time.time() > self.ttl_dict[key]:
-                    self.cache_dict.pop(key, None)
-                    return None
-            original_cached_response = self.cache_dict[key]
-            try:
-                cached_response = json.loads(original_cached_response)
-            except:
-                cached_response = original_cached_response
-            cached_response['cache'] = True # set cache-hit flag to True
-            return cached_response
-        return None
-
-
+#### LiteLLM.Completion Cache ####
 class Cache:
     def __init__(
         self,
diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py
index 6a649f5bf1..b518675dd8 100644
--- a/litellm/proxy/proxy_cli.py
+++ b/litellm/proxy/proxy_cli.py
@@ -194,7 +194,7 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
                 "role": "user",
                 "content": "this is a test request, write a short poem"
             }
-        ])
+        ], max_tokens=256)
         click.echo(f'\nLiteLLM: response from proxy {response}')
 
         print("\n Making streaming request to proxy")
diff --git a/litellm/router.py b/litellm/router.py
index b4745aa047..935d0329cc 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -11,6 +11,7 @@ from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal
 import random, threading, time
 import litellm, openai
+from litellm.caching import RedisCache, InMemoryCache, DualCache
 import logging, asyncio
 import inspect
 from openai import AsyncOpenAI
@@ -46,6 +47,7 @@ class Router:
                  num_retries: int = 0,
                  timeout: float = 600,
                  default_litellm_params = {}, # default params for Router.chat.completion.create
+                 set_verbose: bool = False,
                  routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:
 
         if model_list:
@@ -57,7 +59,7 @@ class Router:
                 self.deployment_latency_map[m["litellm_params"]["model"]] = 0
 
         self.num_retries = num_retries
-
+        self.set_verbose = set_verbose
         self.chat = litellm.Chat(params=default_litellm_params)
 
         self.default_litellm_params = default_litellm_params
@@ -69,6 +71,7 @@ class Router:
             self._start_health_check_thread()
 
         ### CACHING ###
+        redis_cache = None
         if redis_host is not None and redis_port is not None and redis_password is not None:
             cache_config = {
                     'type': 'redis',
@@ -76,6 +79,7 @@ class Router:
                     'port': redis_port,
                     'password': redis_password
             }
+            redis_cache = RedisCache(host=redis_host, port=redis_port, password=redis_password)
         else: # use an in-memory cache
             cache_config = {
                 "type": "local"
@@ -83,7 +87,7 @@ class Router:
         if cache_responses:
             litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
             self.cache_responses = cache_responses
-        self.cache = litellm.Cache(**cache_config) # use Redis for tracking load balancing
+        self.cache = DualCache(redis_cache=redis_cache) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
         ## USAGE TRACKING ##
         if isinstance(litellm.success_callback, list):
             litellm.success_callback.append(self.deployment_callback)
@@ -155,6 +159,10 @@ class Router:
     def get_model_names(self):
         return self.model_names
 
+    def print_verbose(self, print_statement):
+        if self.set_verbose:
+            print(f"LiteLLM.Router: {print_statement}") # noqa
+
     def get_available_deployment(self,
                                model: str,
                                messages: Optional[List[Dict[str, str]]] = None,
@@ -166,19 +174,18 @@ class Router:
         ### get all deployments
         ### filter out the deployments currently cooling down
         healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
-        current_time = time.time()
-        iter = 0
         deployments_to_remove = []
         cooldown_deployments = self._get_cooldown_deployments()
+        self.print_verbose(f"cooldown deployments: {cooldown_deployments}")
         ### FIND UNHEALTHY DEPLOYMENTS
         for deployment in healthy_deployments:
             deployment_name = deployment["litellm_params"]["model"]
             if deployment_name in cooldown_deployments:
                 deployments_to_remove.append(deployment)
-            iter += 1
         ### FILTER OUT UNHEALTHY DEPLOYMENTS
         for deployment in deployments_to_remove:
             healthy_deployments.remove(deployment)
+        self.print_verbose(f"healthy deployments: {healthy_deployments}")
         if litellm.model_alias_map and model in litellm.model_alias_map:
             model = litellm.model_alias_map[
                 model
@@ -245,42 +252,56 @@ class Router:
     def function_with_retries(self, *args, **kwargs):
         # we'll backoff exponentially with each retry
         backoff_factor = 1
-        original_exception = kwargs.pop("original_exception")
         original_function = kwargs.pop("original_function")
-        for current_attempt in range(self.num_retries):
-            self.num_retries -= 1 # decrement the number of retries
+        num_retries = kwargs.pop("num_retries")
+        for current_attempt in range(num_retries):
+            self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
             try:
                 # if the function call is successful, no exception will be raised and we'll break out of the loop
                 response = original_function(*args, **kwargs)
                 return response
 
             except openai.RateLimitError as e:
-                # on RateLimitError we'll wait for an exponential time before trying again
-                time.sleep(backoff_factor)
-
-                # increase backoff factor for next run
-                backoff_factor *= 2
-
-            except openai.APIError as e:
-                # on APIError we immediately retry without any wait, change this if necessary
-                pass
+                if num_retries > 0:
+                    # on RateLimitError we'll wait for an exponential time before trying again
+                    time.sleep(backoff_factor)
+                    # increase backoff factor for next run
+                    backoff_factor *= 2
+                else:
+                    raise e
 
             except Exception as e:
-                # for any other exception types, don't retry
-                raise e
+                # for any other exception types, immediately retry
+                if num_retries > 0:
+                    pass
+                else:
+                    raise e
+            num_retries -= 1 # decrement the number of retries
 
     ### COMPLETION + EMBEDDING FUNCTIONS
 
     def completion(self,
                    model: str,
                    messages: List[Dict[str, str]],
-                   is_retry: Optional[bool] = False,
-                   is_fallback: Optional[bool] = False,
                    **kwargs):
         """
        Example usage:
        response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}]
        """
+
+        kwargs["model"] = model
+        kwargs["messages"] = messages
+        kwargs["original_function"] = self._completion
+        kwargs["num_retries"] = self.num_retries
+        return self.function_with_retries(**kwargs)
+
+    def _completion(
+            self,
+            model: str,
+            messages: List[Dict[str, str]],
+            **kwargs):
+
         try:
             # pick the one that is available (lowest TPM/RPM)
             deployment = self.get_available_deployment(model=model, messages=messages)
             data = deployment["litellm_params"]
             for k, v in self.default_litellm_params.items():
                 if k not in data: # prioritize model-specific params > default router params
                     data[k] = v
+
+            self.print_verbose(f"completion model: {data['model']}")
             return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
         except Exception as e:
-            if self.num_retries > 0:
-                kwargs["model"] = model
-                kwargs["messages"] = messages
-                kwargs["original_exception"] = e
-                kwargs["original_function"] = self.completion
-                return self.function_with_retries(**kwargs)
-            else:
-                raise e
-
-
+            raise e
 
     async def acompletion(self,
                           model: str,
                           messages: List[Dict[str, str]],
@@ -427,8 +441,9 @@ class Router:
         current_minute = datetime.now().strftime("%H-%M")
         # get the current cooldown list for that minute
         cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
-        cached_value = self.cache.get_cache(cache_key=cooldown_key)
+        cached_value = self.cache.get_cache(key=cooldown_key)
 
+        self.print_verbose(f"adding {deployment} to cooldown models")
         # update value
         try:
             if deployment in cached_value:
@@ -436,12 +451,11 @@ class Router:
             else:
                 cached_value = cached_value + [deployment]
                 # save updated value
-                self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
+                self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
         except:
             cached_value = [deployment]
-            # save updated value
-            self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
+            self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
 
     def _get_cooldown_deployments(self):
         """
@@ -454,8 +468,9 @@ class Router:
         # ----------------------
         # Return cooldown models
         # ----------------------
-        cooldown_models = self.cache.get_cache(cache_key=cooldown_key) or []
+        cooldown_models = self.cache.get_cache(key=cooldown_key) or []
 
+        self.print_verbose(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models
 
     def get_usage_based_available_deployment(self,
@@ -522,21 +537,21 @@ class Router:
         # ------------
         # Return usage
         # ------------
-        tpm = self.cache.get_cache(cache_key=tpm_key) or 0
-        rpm = self.cache.get_cache(cache_key=rpm_key) or 0
+        tpm = self.cache.get_cache(key=tpm_key) or 0
+        rpm = self.cache.get_cache(key=rpm_key) or 0
 
         return int(tpm), int(rpm)
 
     def increment(self, key: str, increment_value: int):
         # get value
-        cached_value = self.cache.get_cache(cache_key=key)
+        cached_value = self.cache.get_cache(key=key)
         # update value
         try:
             cached_value = cached_value + increment_value
         except:
             cached_value = increment_value
 
         # save updated value
-        self.cache.add_cache(result=cached_value, cache_key=key, ttl=self.default_cache_time_seconds)
+        self.cache.set_cache(value=cached_value, key=key, ttl=self.default_cache_time_seconds)
 
     def _set_deployment_usage(
         self,
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index 84765448c4..ff5bbb0cfd 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -18,7 +18,7 @@ load_dotenv()
 
 def test_multiple_deployments():
     import concurrent, time
-    litellm.set_verbose=True
+    litellm.set_verbose=False
     futures = {}
     model_list = [{ # list of model deployments
         "model_name": "gpt-3.5-turbo", # openai model name
@@ -58,6 +58,7 @@ def test_multiple_deployments():
                     redis_password=os.getenv("REDIS_PASSWORD"),
                     redis_port=int(os.getenv("REDIS_PORT")),
                     routing_strategy="simple-shuffle",
+                    set_verbose=False,
                     num_retries=1) # type: ignore
     # router = Router(model_list=model_list, redis_host=os.getenv("REDIS_HOST"), redis_password=os.getenv("REDIS_PASSWORD"), redis_port=int(os.getenv("REDIS_PORT"))) # type: ignore
     kwargs = {
@@ -81,12 +82,13 @@ Who among the mentioned figures from Ancient Greece contributed to the domain of
     }
 
     results = []
-
-    for _ in range(2):
-        print(f"starting!!!")
-        response = router.completion(**kwargs)
-        results.append(response)
+    try:
+        for _ in range(3):
+            response = router.completion(**kwargs)
+            results.append(response)
+    except Exception as e:
+        raise e
 
     # print(len(results))
 
     # with ThreadPoolExecutor(max_workers=100) as executor:
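
A minimal usage sketch of the DualCache introduced in litellm/caching.py above. The Redis host, port, and password are placeholder values and Redis is assumed to be reachable; with redis_cache=None the class falls back to the in-memory layer only. Writes go to both layers; reads consult the in-memory cache and Redis, back-filling the in-memory cache with whatever Redis returns:

    from litellm.caching import DualCache, InMemoryCache, RedisCache

    # placeholder Redis credentials - swap in real values or pass redis_cache=None
    redis_cache = RedisCache(host="localhost", port=6379, password="my-password")
    dual_cache = DualCache(in_memory_cache=InMemoryCache(), redis_cache=redis_cache)

    # set_cache writes to both layers; ttl is in seconds
    dual_cache.set_cache(key="gpt-3.5-turbo:tpm", value=1000, ttl=60)

    # get_cache checks the in-memory cache and Redis
    print(dual_cache.get_cache(key="gpt-3.5-turbo:tpm"))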
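
The cooldown bookkeeping itself now runs through this dual cache: Router._set_cooldown_deployments groups cooled-down deployments under a per-minute key with a 60-second TTL, and _get_cooldown_deployments reads the same key back. A rough standalone sketch of that pattern follows; the helper functions and the deployment name are illustrative stand-ins, not litellm APIs, and the cache here is in-memory only:

    from datetime import datetime
    from litellm.caching import DualCache

    cache = DualCache()  # the router also passes a RedisCache when Redis credentials are set

    def add_deployment_to_cooldown(deployment: str) -> None:
        # group cooldown models by minute to reduce the number of cache calls
        cooldown_key = f"{datetime.now().strftime('%H-%M')}:cooldown_models"
        cooled_down = cache.get_cache(key=cooldown_key) or []
        if deployment not in cooled_down:
            cache.set_cache(value=cooled_down + [deployment], key=cooldown_key, ttl=60)

    def get_cooldown_deployments() -> list:
        cooldown_key = f"{datetime.now().strftime('%H-%M')}:cooldown_models"
        return cache.get_cache(key=cooldown_key) or []

    add_deployment_to_cooldown("azure/chatgpt-deployment")
    print(get_cooldown_deployments())  # ['azure/chatgpt-deployment']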