diff --git a/litellm/caching.py b/litellm/caching.py
index ceba188355..dc4c339f27 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -262,9 +262,6 @@ class Cache:
             cache_key = self.get_cache_key(*args, **kwargs)
             if cache_key is not None:
                 cached_result = self.cache.get_cache(cache_key)
-                if cached_result != None and 'stream' in kwargs and kwargs['stream'] == True:
-                    # if streaming is true and we got a cache hit, return a generator
-                    return self.generate_streaming_content(cached_result["choices"][0]['message']['content'])
                 return cached_result
         except Exception as e:
             logging.debug(f"An exception occurred: {traceback.format_exc()}")
diff --git a/litellm/utils.py b/litellm/utils.py
index 8377051952..0c1ebc6ac9 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -817,19 +817,6 @@ class Logging:
                 if callback == "api_manager":
                     print_verbose("reaches api manager for updating model cost")
                     litellm.apiManager.update_cost(completion_obj=result, user=self.user)
-                if callback == "cache":
-                    if litellm.cache != None and self.model_call_details.get('optional_params', {}).get('stream', False) == True:
-                        litellm_call_id = self.litellm_params["litellm_call_id"]
-                        if litellm_call_id in self.litellm_params["stream_response"]:
-                            # append for the given call_id
-                            if self.litellm_params["stream_response"][litellm_call_id]["choices"][0]["message"]["content"] == "default":
-                                self.litellm_params["stream_response"][litellm_call_id]["choices"][0]["message"]["content"] = result["content"] # handle first try
-                            else:
-                                self.litellm_params["stream_response"][litellm_call_id]["choices"][0]["message"]["content"] += result["content"]
-                        else: # init a streaming response for this call id
-                            new_model_response = ModelResponse(choices=[Choices(message=Message(content="default"))])
-                            self.litellm_params["stream_response"][litellm_call_id] = new_model_response
-                        litellm.cache.add_cache(self.litellm_params["stream_response"][litellm_call_id], **self.model_call_details)
                 if callback == "promptlayer":
                     print_verbose("reaches promptlayer for logging!")
                     promptLayerLogger.log_event(
@@ -937,6 +924,16 @@ class Logging:
                         end_time=end_time,
                         print_verbose=print_verbose,
                     )
+                if callback == "cache":
+                    # this only logs streaming once, complete_streaming_response exists i.e when stream ends
+                    kwargs = self.model_call_details
+                    if self.stream:
+                        if "complete_streaming_response" not in kwargs:
+                            return
+                        else:
+                            result = kwargs["complete_streaming_response"]
+                    # only add to cache once we have a complete streaming response
+                    litellm.cache.add_cache(result, **kwargs)
                 if callback == "traceloop":
                     deep_copy = {}
                     for k, v in self.model_call_details.items():
@@ -1401,7 +1398,6 @@ def client(original_function):
                 print_verbose(f"Checking Cache")
                 cached_result = litellm.cache.get_cache(*args, **kwargs)
                 if cached_result != None:
-                    print_verbose(f"Cache Hit!")
                     if "detail" in cached_result:
                         # implies an error occurred
                         pass
@@ -1409,7 +1405,7 @@ def client(original_function):
                         call_type = original_function.__name__
                         print_verbose(f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}")
                         if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
-                            return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
+                            return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse(), stream = kwargs.get("stream", False))
                         elif call_type == CallTypes.embedding.value and isinstance(cached_result, dict):
                             return convert_to_model_response_object(response_object=cached_result, response_type="embedding")
                         else:
@@ -3414,12 +3410,54 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
         exception_logging(logger_fn=user_logger_fn, exception=e)
         pass
 
+def convert_to_streaming_response(response_object: Optional[dict]=None):
+    # used for yielding Cache hits when stream == True
+    if response_object is None:
+        raise Exception("Error in response object format")
 
-def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse]]=None, response_type: Literal["completion", "embedding"] = "completion"):
+    model_response_object = ModelResponse(stream=True)
+    choice_list=[]
+    for idx, choice in enumerate(response_object["choices"]):
+        delta = Delta(
+            content=choice["message"].get("content", None),
+            role=choice["message"]["role"],
+            function_call=choice["message"].get("function_call", None),
+            tool_calls=choice["message"].get("tool_calls", None)
+        )
+        finish_reason = choice.get("finish_reason", None)
+        if finish_reason == None:
+            # gpt-4 vision can return 'finish_reason' or 'finish_details'
+            finish_reason = choice.get("finish_details")
+        choice = StreamingChoices(finish_reason=finish_reason, index=idx, delta=delta)
+        choice_list.append(choice)
+    model_response_object.choices = choice_list
+
+    if "usage" in response_object and response_object["usage"] is not None:
+        model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore
+        model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
+        model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
+
+    if "id" in response_object:
+        model_response_object.id = response_object["id"]
+
+    if "system_fingerprint" in response_object:
+        model_response_object.system_fingerprint = response_object["system_fingerprint"]
+
+    if "model" in response_object:
+        model_response_object.model = response_object["model"]
+    yield model_response_object
+
+
+def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse]]=None, response_type: Literal["completion", "embedding"] = "completion", stream = False):
     try:
         if response_type == "completion" and (model_response_object is None or isinstance(model_response_object, ModelResponse)):
             if response_object is None or model_response_object is None:
                 raise Exception("Error in response object format")
+            if stream == True:
+                # for returning cached responses, we need to yield a generator
+                return convert_to_streaming_response(
+                    response_object=response_object
+                )
             choice_list=[]
             for idx, choice in enumerate(response_object["choices"]):
                 message = Message(
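
For context (not part of the patch): a rough usage sketch of the flow this diff targets. The model name, prompt, and provider key setup are illustrative, and cache defaults may differ by version.

import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # in-memory cache exposing the add_cache/get_cache methods touched above

messages = [{"role": "user", "content": "write a one-line poem"}]

# First call streams from the provider; per the new "cache" callback, the response is
# only written to the cache once complete_streaming_response exists (i.e. the stream ended).
for chunk in litellm.completion(model="gpt-3.5-turbo", messages=messages, stream=True):
    pass

# Second identical call should be a cache hit. Because stream=True is now forwarded to
# convert_to_model_response_object, the cached dict is replayed via
# convert_to_streaming_response() as a generator, so it iterates like a live stream.
for chunk in litellm.completion(model="gpt-3.5-turbo", messages=messages, stream=True):
    print(chunk.choices[0].delta.content or "", end="")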