From 5e1e8820b456eb980f9e88884243ae05e56a6aa0 Mon Sep 17 00:00:00 2001 From: seva Date: Mon, 30 Oct 2023 13:29:35 +0100 Subject: [PATCH 1/2] Router & Caching fixes: - Add optional TTL to Cache parameters - Fix tpm and rpm caching in Router --- litellm/caching.py | 100 +++++++++++++++++++++++++-------------------- litellm/router.py | 68 +++++++++++++----------------- 2 files changed, 85 insertions(+), 83 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index 0e508e37e..457e0c296 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -11,6 +11,7 @@ import litellm import time import json + def get_prompt(*args, **kwargs): # make this safe checks, it should not throw any exceptions if len(args) > 1: @@ -23,81 +24,98 @@ def get_prompt(*args, **kwargs): return prompt return None -class RedisCache(): + +class BaseCache: + def set_cache(self, key, value, **kwargs): + raise NotImplementedError + + def get_cache(self, key, **kwargs): + raise NotImplementedError + + +class RedisCache(BaseCache): def __init__(self, host, port, password): import redis # if users don't provider one, use the default litellm cache self.redis_client = redis.Redis(host=host, port=port, password=password) - def set_cache(self, key, value): + def set_cache(self, key, value, **kwargs): + ttl = kwargs.get("ttl", None) try: - self.redis_client.set(key, str(value)) + self.redis_client.set(name=key, value=str(value), ex=ttl) except Exception as e: # NON blocking - notify users Redis is throwing an exception print("LiteLLM Caching: Got exception from REDIS: ", e) - def get_cache(self, key): + def get_cache(self, key, **kwargs): try: # TODO convert this to a ModelResponse object cached_response = self.redis_client.get(key) - if cached_response!=None: + if cached_response != None: # cached_response is in `b{} convert it to ModelResponse cached_response = cached_response.decode("utf-8") # Convert bytes to string cached_response = json.loads(cached_response) # Convert string to dictionary - cached_response['cache'] = True # set cache-hit flag to True + cached_response['cache'] = True # set cache-hit flag to True return cached_response except Exception as e: # NON blocking - notify users Redis is throwing an exception print("LiteLLM Caching: Got exception from REDIS: ", e) -class HostedCache(): - def set_cache(self, key, value): + +class HostedCache(BaseCache): + def set_cache(self, key, value, **kwargs): + if "ttl" in kwargs: + print("LiteLLM Caching: TTL is not supported for hosted cache!") # make a post request to api.litellm.ai/set_cache import requests url = f"https://api.litellm.ai/set_cache?key={key}&value={str(value)}" - requests.request("POST", url) # post request to set this in the hosted litellm cache + requests.request("POST", url) # post request to set this in the hosted litellm cache - def get_cache(self, key): + def get_cache(self, key, **kwargs): import requests url = f"https://api.litellm.ai/get_cache?key={key}" cached_response = requests.request("GET", url) cached_response = cached_response.text - if cached_response == "NONE": # api.litellm.ai returns "NONE" if it's not a cache hit - return None - if cached_response!=None: + if cached_response == "NONE": # api.litellm.ai returns "NONE" if it's not a cache hit + return None + if cached_response != None: try: cached_response = json.loads(cached_response) # Convert string to dictionary - cached_response['cache'] = True # set cache-hit flag to True + cached_response['cache'] = True # set cache-hit flag to True return cached_response except: return cached_response -class InMemoryCache(): + +class InMemoryCache(BaseCache): def __init__(self): # if users don't provider one, use the default litellm cache self.cache_dict = {} + self.ttl_dict = {} - def set_cache(self, key, value): - #print("in set cache for inmem") + def set_cache(self, key, value, **kwargs): self.cache_dict[key] = value - #print(self.cache_dict) + if "ttl" in kwargs: + self.ttl_dict[key] = time.time() + kwargs["ttl"] - def get_cache(self, key): - #print("in get cache for inmem") + def get_cache(self, key, **kwargs): if key in self.cache_dict: - #print("got a cache hit") + if key in self.ttl_dict: + if time.time() > self.ttl_dict[key]: + self.cache_dict.pop(key, None) + return None return self.cache_dict[key] - #print("got a cache miss") return None -class Cache(): + +class Cache: def __init__( - self, - type = "local", - host = None, - port = None, - password = None - ): + self, + type="local", + host=None, + port=None, + password=None + ): """ Initializes the cache based on the given type. @@ -151,9 +169,9 @@ class Cache(): def generate_streaming_content(self, content): chunk_size = 5 # Adjust the chunk size as needed for i in range(0, len(content), chunk_size): - yield {'choices': [{'delta': {'role': 'assistant', 'content': content[i:i+chunk_size]}}]} + yield {'choices': [{'delta': {'role': 'assistant', 'content': content[i:i + chunk_size]}}]} time.sleep(0.02) - + def get_cache(self, *args, **kwargs): """ Retrieves the cached result for the given arguments. @@ -166,16 +184,16 @@ class Cache(): The cached result if it exists, otherwise None. """ try: # never block execution - if "cache_key" in kwargs: + if "cache_key" in kwargs: cache_key = kwargs["cache_key"] - else: + else: cache_key = self.get_cache_key(*args, **kwargs) if cache_key is not None: cached_result = self.cache.get_cache(cache_key) if cached_result != None and 'stream' in kwargs and kwargs['stream'] == True: # if streaming is true and we got a cache hit, return a generator - #print("cache hit and stream=True") - #print(cached_result) + # print("cache hit and stream=True") + # print(cached_result) return self.generate_streaming_content(cached_result["choices"][0]['message']['content']) return cached_result except: @@ -193,20 +211,14 @@ class Cache(): None """ try: - if "cache_key" in kwargs: + if "cache_key" in kwargs: cache_key = kwargs["cache_key"] - else: + else: cache_key = self.get_cache_key(*args, **kwargs) # print("adding to cache", cache_key, result) # print(cache_key) if cache_key is not None: # print("adding to cache", cache_key, result) - self.cache.set_cache(cache_key, result) + self.cache.set_cache(cache_key, result, **kwargs) except: pass - - - - - - diff --git a/litellm/router.py b/litellm/router.py index c4997e3d5..7e4b3d54a 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -24,6 +24,8 @@ class Router: """ model_names: List = [] cache_responses: bool = False + default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour + def __init__(self, model_list: Optional[list] = None, redis_host: Optional[str] = None, @@ -133,7 +135,10 @@ class Router: Function LiteLLM submits a callback to after a successful completion. Purpose of this is ti update TPM/RPM usage per model """ - model_name = kwargs.get('model', None) # i.e. azure/gpt35turbo + model_name = kwargs.get('model', None) # i.e. gpt35turbo + custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure + if custom_llm_provider: + model_name = f"{custom_llm_provider}/{model_name}" total_tokens = completion_response['usage']['total_tokens'] self._set_deployment_usage(model_name, total_tokens) @@ -150,17 +155,9 @@ class Router: if item["model_name"] == model: potential_deployments.append(item) - # set first model as current model + # set first model as current model to calculate token count deployment = potential_deployments[0] - - # get model tpm, rpm limits - tpm = deployment["tpm"] - rpm = deployment["rpm"] - - # get deployment current usage - current_tpm, current_rpm = self._get_deployment_usage(deployment_name=deployment["litellm_params"]["model"]) - # get encoding if messages: token_count = litellm.token_counter(model=deployment["model_name"], messages=messages) @@ -171,29 +168,27 @@ class Router: input_text = input token_count = litellm.token_counter(model=deployment["model_name"], text=input_text) - # if at model limit, return lowest used - if current_tpm + token_count > tpm or current_rpm + 1 >= rpm: - # ----------------------- - # Find lowest used model - # ---------------------- - lowest_tpm = float('inf') - deployment = None + # ----------------------- + # Find lowest used model + # ---------------------- + lowest_tpm = float("inf") + deployment = None - # Go through all the models to get tpm, rpm - for item in potential_deployments: - item_tpm, item_rpm = self._get_deployment_usage(deployment_name=item["litellm_params"]["model"]) + # Go through all the models to get tpm, rpm + for item in potential_deployments: + item_tpm, item_rpm = self._get_deployment_usage(deployment_name=item["litellm_params"]["model"]) - if item_tpm == 0: - return item - elif item_tpm + token_count > item["tpm"] or item_rpm + 1 >= item["rpm"]: - continue - elif item_tpm < lowest_tpm: - lowest_tpm = item_tpm - deployment = item + if item_tpm == 0: + return item + elif item_tpm + token_count > item["tpm"] or item_rpm + 1 >= item["rpm"]: + continue + elif item_tpm < lowest_tpm: + lowest_tpm = item_tpm + deployment = item - # if none, raise exception - if deployment is None: - raise ValueError(f"No models available.") + # if none, raise exception + if deployment is None: + raise ValueError("No models available.") # return model return deployment @@ -212,26 +207,21 @@ class Router: # ------------ # Return usage # ------------ - tpm = self.cache.get_cache(tpm_key) - rpm = self.cache.get_cache(rpm_key) - - if tpm is None: - tpm = 0 - if rpm is None: - rpm = 0 + tpm = self.cache.get_cache(cache_key=tpm_key) or 0 + rpm = self.cache.get_cache(cache_key=rpm_key) or 0 return int(tpm), int(rpm) def increment(self, key: str, increment_value: int): # get value - cached_value = self.cache.get_cache(key) + cached_value = self.cache.get_cache(cache_key=key) # update value try: cached_value = cached_value + increment_value except: cached_value = increment_value # save updated value - self.cache.add_cache(result=cached_value, cache_key=key) + self.cache.add_cache(result=cached_value, cache_key=key, ttl=self.default_cache_time_seconds) def _set_deployment_usage( self, From f4d65cf23d7ba2024d360e0f73554aaf26e39a86 Mon Sep 17 00:00:00 2001 From: seva Date: Mon, 30 Oct 2023 13:35:55 +0100 Subject: [PATCH 2/2] Add tests for cache with TTL --- litellm/tests/test_caching.py | 48 ++++++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 78c3c86a7..dd6a57dbd 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -1,4 +1,5 @@ import sys, os +import time import traceback from dotenv import load_dotenv @@ -36,7 +37,7 @@ def test_gpt_cache(): cache_key = last_content_without_prompt_val + data["model"] print("cache_key", cache_key) return cache_key - + cache.init(pre_func=pre_cache_func) cache.set_openai_key() @@ -46,12 +47,12 @@ def test_gpt_cache(): response2 = completion(model="gpt-3.5-turbo", messages=messages) response3 = completion(model="command-nightly", messages=messages) - if response1["choices"] != response2["choices"]: # same models should cache + if response1["choices"] != response2["choices"]: # same models should cache print(f"response1: {response1}") print(f"response2: {response2}") pytest.fail(f"Error occurred:") - if response3["choices"] == response2["choices"]: # different models, don't cache + if response3["choices"] == response2["choices"]: # different models, don't cache # if models are different, it should not return cached response print(f"response2: {response2}") print(f"response3: {response3}") @@ -124,9 +125,9 @@ def test_embedding_caching(): embedding2 = embedding(model="text-embedding-ada-002", input=text_to_embed, caching=True) end_time = time.time() print(f"Embedding 2 response time: {end_time - start_time} seconds") - + litellm.cache = None - assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s + assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']: print(f"embedding1: {embedding1}") print(f"embedding2: {embedding2}") @@ -178,14 +179,14 @@ def test_embedding_caching_azure(): ) end_time = time.time() print(f"Embedding 2 response time: {end_time - start_time} seconds") - + litellm.cache = None - assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s + assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']: print(f"embedding1: {embedding1}") print(f"embedding2: {embedding2}") pytest.fail("Error occurred: Embedding caching failed") - + os.environ['AZURE_API_VERSION'] = api_version os.environ['AZURE_API_BASE'] = api_base os.environ['AZURE_API_KEY'] = api_key @@ -279,11 +280,11 @@ def test_redis_cache_completion(): def set_cache(key, value): local_cache[key] = value - + def get_cache(key): if key in local_cache: return local_cache[key] - + litellm.cache.cache.set_cache = set_cache litellm.cache.cache.get_cache = get_cache @@ -322,11 +323,11 @@ def test_custom_redis_cache_with_key(): def set_cache(key, value): local_cache[key] = value - + def get_cache(key): if key in local_cache: return local_cache[key] - + litellm.cache.cache.set_cache = set_cache litellm.cache.cache.get_cache = get_cache @@ -335,16 +336,16 @@ def test_custom_redis_cache_with_key(): response1 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True) response2 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True) response3 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=False) - + print(f"response1: {response1}") print(f"response2: {response2}") print(f"response3: {response3}") if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']: - pytest.fail(f"Error occurred:") + pytest.fail(f"Error occurred:") litellm.cache = None -test_custom_redis_cache_with_key() +# test_custom_redis_cache_with_key() def test_hosted_cache(): litellm.cache = Cache(type="hosted") # use api.litellm.ai for caching @@ -364,3 +365,20 @@ def test_hosted_cache(): # test_hosted_cache() + +def test_redis_cache_with_ttl(): + cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) + cache.add_cache(cache_key="test_key", result="test_value", ttl=1) + cached_value = cache.get_cache(cache_key="test_key") + assert cached_value == "test_value" + time.sleep(2) + assert cache.get_cache(cache_key="test_key") is None + + +def test_in_memory_cache_with_ttl(): + cache = Cache(type="local") + cache.add_cache(cache_key="test_key", result="test_value", ttl=1) + cached_value = cache.get_cache(cache_key="test_key") + assert cached_value == "test_value" + time.sleep(2) + assert cache.get_cache(cache_key="test_key") is None