diff --git a/litellm/utils.py b/litellm/utils.py
index 982a4f2a80..2b23d6a789 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1056,6 +1056,19 @@ class Logging:
         start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
         for callback in litellm._async_success_callback:
             try:
+                if callback == "cache":
+                    # set_cache once complete streaming response is built
+                    print_verbose("async success_callback: reaches cache for logging!")
+                    kwargs = self.model_call_details
+                    if self.stream:
+                        if "complete_streaming_response" not in kwargs:
+                            print_verbose(f"async success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n")
+                            return
+                        else:
+                            print_verbose("async success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache")
+                            result = kwargs["complete_streaming_response"]
+                            # only add to cache once we have a complete streaming response
+                            litellm.cache.add_cache(result, **kwargs)
                 if isinstance(callback, CustomLogger): # custom logger class
                     print_verbose(f"Async success callbacks: CustomLogger")
                     if self.stream:
@@ -1599,7 +1612,12 @@ def client(original_function):
                         print_verbose(f"Cache Hit!")
                         call_type = original_function.__name__
                         if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict):
-                            return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
+                            if kwargs.get("stream", False) == True:
+                                return convert_to_streaming_response_async(
+                                    response_object=cached_result,
+                                )
+                            else:
+                                return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
                         else:
                             return cached_result
             # MODEL CALL
@@ -3494,6 +3512,69 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k
         exception_logging(logger_fn=user_logger_fn, exception=e)
         pass
 
+async def convert_to_streaming_response_async(response_object: Optional[dict]=None):
+    """
+    Asynchronously converts a response object to a streaming response.
+
+    Args:
+        response_object (Optional[dict]): The response object to be converted. Defaults to None.
+
+    Raises:
+        Exception: If the response object is None.
+
+    Yields:
+        ModelResponse: The converted streaming response object.
+
+    Returns:
+        None
+    """
+    if response_object is None:
+        raise Exception("Error in response object format")
+
+    model_response_object = ModelResponse(stream=True)
+
+    if model_response_object is None:
+        raise Exception("Error in response creating model response object")
+
+    choice_list = []
+
+    for idx, choice in enumerate(response_object["choices"]):
+        delta = Delta(
+            content=choice["message"].get("content", None),
+            role=choice["message"]["role"],
+            function_call=choice["message"].get("function_call", None),
+            tool_calls=choice["message"].get("tool_calls", None)
+        )
+        finish_reason = choice.get("finish_reason", None)
+
+        if finish_reason is None:
+            finish_reason = choice.get("finish_details")
+
+        choice = StreamingChoices(finish_reason=finish_reason, index=idx, delta=delta)
+        choice_list.append(choice)
+
+    model_response_object.choices = choice_list
+
+    if "usage" in response_object and response_object["usage"] is not None:
+        model_response_object.usage = Usage(
+            completion_tokens=response_object["usage"].get("completion_tokens", 0),
+            prompt_tokens=response_object["usage"].get("prompt_tokens", 0),
+            total_tokens=response_object["usage"].get("total_tokens", 0)
+        )
+
+
+    if "id" in response_object:
+        model_response_object.id = response_object["id"]
+
+    if "system_fingerprint" in response_object:
+        model_response_object.system_fingerprint = response_object["system_fingerprint"]
+
+    if "model" in response_object:
+        model_response_object.model = response_object["model"]
+
+    yield model_response_object
+    await asyncio.sleep(0)
+
 def convert_to_streaming_response(response_object: Optional[dict]=None):
     # used for yielding Cache hits when stream == True
     if response_object is None:
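
Rough usage sketch (not part of the patch): assuming litellm.caching.Cache is configured and the caching / stream kwargs behave as in the wrapper above, a repeated async streaming call should be served from cache via the new convert_to_streaming_response_async generator. The model name and prompt below are placeholders.

    import asyncio
    import litellm
    from litellm.caching import Cache

    litellm.cache = Cache()  # in-memory cache; setup assumed for this sketch

    async def demo():
        messages = [{"role": "user", "content": "hello"}]  # hypothetical prompt
        # First call: hits the provider; async_success_handler adds the
        # complete_streaming_response to the cache once the stream finishes.
        first = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, stream=True, caching=True)
        async for chunk in first:
            pass
        # Second call: cache hit; the client wrapper returns the async generator
        # from convert_to_streaming_response_async, yielding one cached chunk.
        second = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, stream=True, caching=True)
        async for chunk in second:
            print(chunk)

    asyncio.run(demo())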