diff --git a/litellm/caching.py b/litellm/caching.py
index d6a7ddacd..923a3031b 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -1,3 +1,4 @@
+import litellm
 def get_prompt(*args, **kwargs):
     # make this safe checks, it should not throw any exceptions
     if len(args) > 1:
@@ -29,8 +30,9 @@ class InMemoryCache():
         self.cache_dict = {}
 
     def set_cache(self, key, value):
-        #print("in set cache for inmem")
+        print("in set cache for inmem")
         self.cache_dict[key] = value
+        print(self.cache_dict)
 
     def get_cache(self, key):
         #print("in get cache for inmem")
@@ -46,6 +48,8 @@ class Cache():
             self.cache = RedisCache(type, host, port, password)
         if type == "local":
             self.cache = InMemoryCache()
+        litellm.input_callback.append("cache")
+        litellm.success_callback.append("cache")
 
     def get_cache_key(self, *args, **kwargs):
         prompt = get_prompt(*args, **kwargs)
@@ -71,8 +75,11 @@ class Cache():
 
     def add_cache(self, result, *args, **kwargs):
         try:
+            # print("adding to cache", result)
             cache_key = self.get_cache_key(*args, **kwargs)
+            # print(cache_key)
             if cache_key is not None:
+                # print("adding to cache", cache_key, result)
                 self.cache.set_cache(cache_key, result)
         except:
             pass
diff --git a/litellm/main.py b/litellm/main.py
index 859486ae6..354c5591a 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -196,7 +196,7 @@ def completion(
                 engine=model, messages=messages, **optional_params
             )
             if "stream" in optional_params and optional_params["stream"] == True:
-                response = CustomStreamWrapper(response, model)
+                response = CustomStreamWrapper(response, model, logging_obj=logging)
                 return response
             ## LOGGING
             logging.post_call(
@@ -254,7 +254,7 @@ def completion(
                 model=model, messages=messages, **optional_params
             )
             if "stream" in optional_params and optional_params["stream"] == True:
-                response = CustomStreamWrapper(response, model)
+                response = CustomStreamWrapper(response, model, logging_obj=logging)
                 return response
             ## LOGGING
             logging.post_call(
diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py
index 21546e40c..949ae0c26 100644
--- a/litellm/tests/test_caching.py
+++ b/litellm/tests/test_caching.py
@@ -153,15 +153,19 @@ def test_embedding_caching():
 
 
 # # test caching with streaming
-# messages = [{"role": "user", "content": "draft a 2 pg legal document on applying to litellm"}]
+# messages = [{"role": "user", "content": "hello gm who are u"}]
 # def test_caching_v2_stream():
 #     try:
 #         litellm.cache = Cache()
+#         # litellm.token="ishaan@berri.ai"
 #         response1 = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
 #         for chunk in response1:
-#             print(chunk)
-#             response1_id = chunk['id']
-
+#             #
+#             pass
+#             # print("chunk")
+#         pass
+#         # response1_id = chunk['id']
+
 #         # response2 = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
 #         # for chunk in response2:
 #         #     #print(chunk)
diff --git a/litellm/utils.py b/litellm/utils.py
index fca845728..26b37b694 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -70,7 +70,7 @@ last_fetched_at_keys = None
 
 
 class Message(OpenAIObject):
-    def __init__(self, content="default", role="assistant", **params):
+    def __init__(self, content=" ", role="assistant", **params):
         super(Message, self).__init__(**params)
         self.content = content
         self.role = role
@@ -285,6 +285,26 @@ class Logging:
                             call_type = self.call_type,
                             stream = self.stream
                         )
+                    if callback == "cache":
+                        try:
+                            #print("in cache callback2", self.stream)
+                            #print(original_response)
+                            #print(self.model_call_details)
+
+                            if litellm.cache != None:
+                                if self.litellm_params["stream_response"] == None:
+                                    self.litellm_params["stream_response"] = ModelResponse()
+                                else:
+                                    #self.litellm_call_id["stream_response"]["id"] = self.litellm_params["litellm_call_id"]
+                                    self.litellm_params["stream_response"]["choices"][0]["message"]["content"] += original_response
+                                #print("cache is not none")
+                                # convert original_response to format of Model Object
+                                # Set the model
+                                litellm.cache.add_cache(self.litellm_params["stream_response"], **self.model_call_details)
+                                #print(self.litellm_params["stream_response"])
+                        except Exception as e:
+                            print("got exception")
+                            print(e)
             except:
                 print_verbose(
                     f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while post-call logging with integrations {traceback.format_exc()}"
@@ -624,6 +644,7 @@ def get_litellm_params(
         "custom_api_base": custom_api_base,
         "litellm_call_id": litellm_call_id,
         "model_alias_map": model_alias_map,
+        "stream_response": None
     }
 
     return litellm_params
@@ -1576,7 +1597,10 @@ class CustomStreamWrapper:
             return chunk["choices"][0]["text"]
         except:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")
-
+
+    def handle_openai_chat_completion_chunk(self, chunk):
+        return chunk["choices"][0]["delta"]["content"]
+
     def handle_baseten_chunk(self, chunk):
         chunk = chunk.decode("utf-8")
         data_json = json.loads(chunk)
@@ -1593,44 +1617,47 @@ class CustomStreamWrapper:
             return ""
 
     def __next__(self):
-        completion_obj = {"role": "assistant", "content": ""}
-        if self.model in litellm.anthropic_models:
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = self.handle_anthropic_chunk(chunk)
-        elif self.model == "replicate":
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = chunk
-        elif (
-            self.custom_llm_provider and self.custom_llm_provider == "together_ai"
-        ) or ("togethercomputer" in self.model):
-            chunk = next(self.completion_stream)
-            text_data = self.handle_together_ai_chunk(chunk)
-            if text_data == "":
-                return self.__next__()
-            completion_obj["content"] = text_data
-        elif self.model in litellm.cohere_models:
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = chunk.text
-        elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = self.handle_huggingface_chunk(chunk)
-        elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = self.handle_baseten_chunk(chunk)
-        elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = self.handle_ai21_chunk(chunk)
-        elif self.model in litellm.open_ai_text_completion_models:
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = self.handle_openai_text_completion_chunk(chunk)
-        else: # openai chat/azure models
-            chunk = next(self.completion_stream)
-            return chunk
+        try:
+            completion_obj = {"role": "assistant", "content": ""}
+            if self.model in litellm.anthropic_models:
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_anthropic_chunk(chunk)
+            elif self.model == "replicate":
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = chunk
+            elif (
+                self.custom_llm_provider and self.custom_llm_provider == "together_ai"
+            ) or ("togethercomputer" in self.model):
+                chunk = next(self.completion_stream)
+                text_data = self.handle_together_ai_chunk(chunk)
+                if text_data == "":
+                    return self.__next__()
+                completion_obj["content"] = text_data
+            elif self.model in litellm.cohere_models:
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = chunk.text
+            elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_huggingface_chunk(chunk)
+            elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_baseten_chunk(chunk)
+            elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_ai21_chunk(chunk)
+            elif self.model in litellm.open_ai_text_completion_models:
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_openai_text_completion_chunk(chunk)
            else: # openai chat/azure models
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_openai_chat_completion_chunk(chunk)
 
-        # LOGGING
-        self.logging_obj.post_call(completion_obj["content"])
-        # return this for all models
-        return {"choices": [{"delta": completion_obj}]}
+            # LOGGING
+            self.logging_obj.post_call(completion_obj["content"])
+            # return this for all models
+            return {"choices": [{"delta": completion_obj}]}
+        except:
+            raise StopIteration
 
 
 ########## Reading Config File ############################
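
For reference, here is a minimal, standalone sketch of the stream-caching flow this patch wires together: each streamed chunk is passed to `post_call`, accumulated into a single response object (the patch's `litellm_params["stream_response"]`), and written to the cache keyed on the prompt via `add_cache`. All names below (`SimpleCache`, `StreamLogger`, `fake_stream`) are hypothetical stand-ins rather than litellm's real classes, and for simplicity the first chunk's text is appended as well, whereas the patch's else-branch only appends from the second chunk onward.

```python
class SimpleCache:
    """In-memory cache keyed by the last user message (stand-in for Cache + InMemoryCache)."""
    def __init__(self):
        self.cache_dict = {}

    def get_cache_key(self, messages):
        return messages[-1]["content"]

    def add_cache(self, result, messages):
        self.cache_dict[self.get_cache_key(messages)] = result

    def get_cache(self, messages):
        return self.cache_dict.get(self.get_cache_key(messages))


class StreamLogger:
    """Mirrors the shape of the "cache" branch in Logging.post_call."""
    def __init__(self, cache, messages):
        self.cache = cache
        self.messages = messages
        self.stream_response = None  # analogous to litellm_params["stream_response"]

    def post_call(self, chunk_text):
        if self.stream_response is None:
            # first chunk: create the aggregate response object
            self.stream_response = {"choices": [{"message": {"role": "assistant", "content": ""}}]}
        # accumulate chunk text into the aggregate response
        self.stream_response["choices"][0]["message"]["content"] += chunk_text
        # re-cache the growing aggregate on every chunk, as add_cache is called per chunk
        self.cache.add_cache(self.stream_response, self.messages)


def fake_stream():
    # stand-in for the chunks CustomStreamWrapper yields
    yield from ["stream ", "caching ", "works"]


if __name__ == "__main__":
    cache = SimpleCache()
    messages = [{"role": "user", "content": "hello gm who are u"}]
    logger = StreamLogger(cache, messages)
    for text in fake_stream():
        logger.post_call(text)
    print(cache.get_cache(messages)["choices"][0]["message"]["content"])
    # -> "stream caching works"
```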