diff --git a/litellm/caching.py b/litellm/caching.py
index 7c04bab72..6681bde34 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -49,8 +49,10 @@ class Cache():
             self.cache = RedisCache(type, host, port, password)
         if type == "local":
             self.cache = InMemoryCache()
-        litellm.input_callback.append("cache")
-        litellm.success_callback.append("cache")
+        if "cache" not in litellm.input_callback:
+            litellm.input_callback.append("cache")
+        if "cache" not in litellm.success_callback:
+            litellm.success_callback.append("cache")
 
     def get_cache_key(self, *args, **kwargs):
         prompt = get_prompt(*args, **kwargs)
@@ -88,8 +90,9 @@ class Cache():
 
     def add_cache(self, result, *args, **kwargs):
         try:
-            # print("adding to cache", result)
+            cache_key = self.get_cache_key(*args, **kwargs)
+            # print("adding to cache", cache_key, result)
             # print(cache_key)
             if cache_key is not None:
                 # print("adding to cache", cache_key, result)
diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py
index 580c6d86b..946c94ebb 100644
--- a/litellm/tests/test_caching.py
+++ b/litellm/tests/test_caching.py
@@ -127,7 +127,7 @@ embedding_large_text = """
 small text
 """ * 5
 
-# test_caching_with_models()
+# # test_caching_with_models()
 def test_embedding_caching():
     import time
     litellm.cache = Cache()
@@ -136,7 +136,7 @@ def test_embedding_caching():
     embedding1 = embedding(model="text-embedding-ada-002", input=text_to_embed)
     end_time = time.time()
     print(f"Embedding 1 response time: {end_time - start_time} seconds")
-
+    time.sleep(1)
     start_time = time.time()
     embedding2 = embedding(model="text-embedding-ada-002", input=text_to_embed)
@@ -153,18 +153,64 @@ def test_embedding_caching():
 
 # test caching with streaming
-messages = [{"role": "user", "content": "tell me a story in 2 sentences"}]
-def test_caching_v2_stream():
+
+def test_caching_v2_stream_basic():
     try:
         litellm.cache = Cache()
         # litellm.token="ishaan@berri.ai"
+        messages = [{"role": "user", "content": "tell me a story in 2 sentences"}]
         response1 = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
+
         result_string = ""
         for chunk in response1:
             print(chunk)
             result_string+=chunk['choices'][0]['delta']['content']
         # response1_id = chunk['id']
+        print("current cache")
+        print(litellm.cache.cache.cache_dict)
+
+        result2_string=""
+        import time
+        time.sleep(1)
+        response2 = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
+        for chunk in response2:
+            print(chunk)
+            result2_string+=chunk['choices'][0]['delta']['content']
+        if result_string != result2_string:
+            print(result_string)
+            print(result2_string)
+            pytest.fail(f"Error occurred: Caching with streaming failed, strings diff")
+        litellm.cache = None
+
+    except Exception as e:
+        print(f"error occurred: {traceback.format_exc()}")
+        pytest.fail(f"Error occurred: {e}")
+
+# test_caching_v2_stream_basic()
+
+def test_caching_v2_stream():
+    try:
+        litellm.cache = Cache()
+        # litellm.token="ishaan@berri.ai"
+        messages = [{"role": "user", "content": "tell me a story in 2 sentences"}]
+        response1 = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
+
+        messages = [{"role": "user", "content": "tell me a chair"}]
+        response7 = completion(model="command-nightly", messages=messages)
+        messages = [{"role": "user", "content": "sing a song"}]
+        response8 = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
+
+        result_string = ""
+        for chunk in response1:
+            print(chunk)
+            result_string+=chunk['choices'][0]['delta']['content']
+        # response1_id = chunk['id']
+
+        print("current cache")
+        messages = [{"role": "user", "content": "tell me a story in 2 sentences"}]
+        print(litellm.cache.cache.cache_dict)
+
         result2_string=""
         response2 = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
         for chunk in response2:
@@ -174,6 +220,7 @@ def test_caching_v2_stream():
             print(result_string)
             print(result2_string)
             pytest.fail(f"Error occurred: Caching with streaming failed, strings diff")
+        litellm.cache = None
 
     except Exception as e:
         print(f"error occurred: {traceback.format_exc()}")
diff --git a/litellm/utils.py b/litellm/utils.py
index c41610f5c..0d787949d 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -87,7 +87,7 @@ class Choices(OpenAIObject):
 class ModelResponse(OpenAIObject):
     def __init__(self, choices=None, created=None, model=None, usage=None, **params):
         super(ModelResponse, self).__init__(**params)
-        self.choices = choices if choices else [Choices()]
+        self.choices = choices if choices else [Choices(message=Message())]
         self.created = created
         self.model = model
         self.usage = (
@@ -271,7 +271,7 @@ class Logging:
             print_verbose(
                 f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
             )
-            
+
         # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
         for callback in litellm.input_callback:
             try:
@@ -287,17 +287,26 @@ class Logging:
                     )
                 if callback == "cache":
                     try:
+                        # print("entering logger first time")
+                        # print(self.litellm_params["stream_response"])
                         if litellm.cache != None and self.model_call_details.get('optional_params', {}).get('stream', False) == True:
-                            if self.litellm_params["stream_response"] == None:
-                                self.litellm_params["stream_response"] = ModelResponse()
-                            else:
-                                #self.litellm_call_id["stream_response"]["id"] = self.litellm_params["litellm_call_id"]
-                                if self.litellm_params["stream_response"]["choices"][0]["message"]["content"] == "default":
-                                    self.litellm_params["stream_response"]["choices"][0]["message"]["content"] = original_response # handle first try
+                            litellm_call_id = self.litellm_params["litellm_call_id"]
+                            if litellm_call_id in self.litellm_params["stream_response"]:
+                                # append for the given call_id
+                                if self.litellm_params["stream_response"][litellm_call_id]["choices"][0]["message"]["content"] == "default":
+                                    self.litellm_params["stream_response"][litellm_call_id]["choices"][0]["message"]["content"] = original_response # handle first try
                                 else:
-                                    self.litellm_params["stream_response"]["choices"][0]["message"]["content"] += original_response
-                            litellm.cache.add_cache(self.litellm_params["stream_response"], **self.model_call_details)
+                                    self.litellm_params["stream_response"][litellm_call_id]["choices"][0]["message"]["content"] += original_response
+                            else: # init a streaming response for this call id
+                                new_model_response = ModelResponse(choices=[Choices(message=Message(content="default"))])
+                                #print("creating new model response")
+                                #print(new_model_response)
+                                self.litellm_params["stream_response"][litellm_call_id] = new_model_response
+                            #print("adding to cache for", litellm_call_id)
+                            litellm.cache.add_cache(self.litellm_params["stream_response"][litellm_call_id], **self.model_call_details)
                     except Exception as e:
+                        # print("got exception")
+                        # print(e)
                         pass
             except:
                 print_verbose(
@@ -466,7 +475,6 @@ def client(original_function):
         # CRASH REPORTING TELEMETRY
         crash_reporting(*args, **kwargs)
         # INIT LOGGER - for user-specified integrations
-        print(f"len args: {len(args)}")
         model = args[0] if len(args) > 0 else kwargs["model"]
         call_type = original_function.__name__
         if call_type == CallTypes.completion.value:
@@ -638,7 +646,7 @@ def get_litellm_params(
         "custom_api_base": custom_api_base,
         "litellm_call_id": litellm_call_id,
         "model_alias_map": model_alias_map,
-        "stream_response": None
+        "stream_response": {} # litellm_call_id: ModelResponse Dict
     }
 
     return litellm_params
diff --git a/pyproject.toml b/pyproject.toml
index bf064d5e2..9c9c36656 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.497"
+version = "0.1.498"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
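
Note: the litellm/utils.py hunks above change "stream_response" in get_litellm_params from a single ModelResponse into a dict keyed by litellm_call_id, so chunks from interleaved streaming calls are accumulated separately before being written to the cache. A minimal standalone sketch of that accumulation pattern follows; it is not litellm source, plain dicts stand in for ModelResponse/Choices/Message, and the call ids are hypothetical:

# Sketch only: per-call-id accumulation of streamed content.
stream_response = {}  # litellm_call_id -> accumulated response dict

def accumulate_chunk(litellm_call_id, chunk_text):
    if litellm_call_id not in stream_response:
        # First chunk for this call id: create a placeholder response, mirroring
        # ModelResponse(choices=[Choices(message=Message(content="default"))]) above.
        stream_response[litellm_call_id] = {"choices": [{"message": {"content": "default"}}]}
    message = stream_response[litellm_call_id]["choices"][0]["message"]
    if message["content"] == "default":
        message["content"] = chunk_text   # first chunk replaces the placeholder
    else:
        message["content"] += chunk_text  # later chunks are appended
    return stream_response[litellm_call_id]

# Chunks from two interleaved streams land in separate buffers keyed by call id.
accumulate_chunk("call-1", "Once upon ")
accumulate_chunk("call-2", "La la ")
accumulate_chunk("call-1", "a time.")
assert stream_response["call-1"]["choices"][0]["message"]["content"] == "Once upon a time."

In the real hook, litellm.cache.add_cache(...) is then called with the partially built response on each chunk, so the cached entry for that prompt converges to the full streamed text.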
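
Similarly, the litellm/caching.py hunk guards callback registration so that constructing Cache() repeatedly (as the tests now do) cannot register the "cache" callback more than once; a duplicated entry would otherwise make the loop over litellm.input_callback hit the cache hook more than once per chunk. A small sketch of the guard, using local lists as stand-ins for litellm.input_callback / litellm.success_callback:

# Sketch only: idempotent callback registration.
input_callback = []
success_callback = []

def register_cache_callbacks():
    # Append "cache" only if it is not already registered.
    if "cache" not in input_callback:
        input_callback.append("cache")
    if "cache" not in success_callback:
        success_callback.append("cache")

register_cache_callbacks()
register_cache_callbacks()  # second call is a no-op
assert input_callback == ["cache"] and success_callback == ["cache"]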