(feat) async completion caching

ishaan-jaff 2023-12-09 14:15:18 -08:00
parent 67c730e264
commit d18d5a3133


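In short, this commit teaches litellm's async path to write finished acompletion results into litellm.cache (waiting for the assembled response when streaming) and to serve later identical calls from that cache. A minimal usage sketch of the feature, assuming an in-memory litellm.Cache() and an illustrative model name:

    import asyncio
    import litellm
    from litellm.caching import Cache

    litellm.cache = Cache()  # enable caching; async completions now read from / write to this cache

    async def main():
        messages = [{"role": "user", "content": "hello"}]
        # first call goes to the provider and, once complete, is written to the cache
        first = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages)
        # an identical second call can be answered from the cache
        second = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages)
        print(first.choices[0].message.content == second.choices[0].message.content)

    asyncio.run(main())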
@@ -1056,6 +1056,19 @@ class Logging:
         start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
         for callback in litellm._async_success_callback:
             try:
+                if callback == "cache":
+                    # set_cache once complete streaming response is built
+                    print_verbose("async success_callback: reaches cache for logging!")
+                    kwargs = self.model_call_details
+                    if self.stream:
+                        if "complete_streaming_response" not in kwargs:
+                            print_verbose(f"async success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n")
+                            return
+                        else:
+                            print_verbose("async success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache")
+                            result = kwargs["complete_streaming_response"]
+                            # only add to cache once we have a complete streaming response
+                            litellm.cache.add_cache(result, **kwargs)
                 if isinstance(callback, CustomLogger): # custom logger class
                     print_verbose(f"Async success callbacks: CustomLogger")
                     if self.stream:
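The hunk above registers a "cache" handler among the async success callbacks: non-streaming results can be cached as-is, but for streaming calls nothing is written until the assembled complete_streaming_response appears in model_call_details. A rough standalone sketch of that gating logic; the dictionary cache and make_key helper are illustrative, not litellm APIs:

    import json

    fake_cache = {}  # illustrative stand-in for litellm.cache

    def make_key(kwargs: dict) -> str:
        # hypothetical cache key derived from model + messages
        return json.dumps({"model": kwargs.get("model"), "messages": kwargs.get("messages")}, sort_keys=True)

    def cache_on_success(kwargs: dict, result):
        # mirrors the callback above: skip partial streams, cache only the complete response
        if kwargs.get("stream"):
            if "complete_streaming_response" not in kwargs:
                return  # stream still in flight; nothing cacheable yet
            result = kwargs["complete_streaming_response"]
        fake_cache[make_key(kwargs)] = result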
@@ -1599,7 +1612,12 @@ def client(original_function):
                         print_verbose(f"Cache Hit!")
                         call_type = original_function.__name__
                         if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict):
-                            return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
+                            if kwargs.get("stream", False) == True:
+                                return convert_to_streaming_response_async(
+                                    response_object=cached_result,
+                                )
+                            else:
+                                return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
                         else:
                             return cached_result
             # MODEL CALL
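With this branch, a cache hit on acompletion keeps the shape the caller asked for: stream=True returns the async generator produced by convert_to_streaming_response_async (added below), while non-streaming hits are rebuilt via convert_to_model_response_object. A hedged consumption sketch, again with an illustrative model name and assuming the cache already holds this prompt:

    import asyncio
    import litellm

    async def read_cached_stream():
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",  # illustrative model name
            messages=[{"role": "user", "content": "hello"}],
            stream=True,
        )
        # on a cache hit this is the async generator built from the cached dict
        async for chunk in response:
            print(chunk.choices[0].delta.content)

    asyncio.run(read_cached_stream())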
@@ -3494,6 +3512,69 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k
         exception_logging(logger_fn=user_logger_fn, exception=e)
         pass
+async def convert_to_streaming_response_async(response_object: Optional[dict]=None):
+    """
+    Asynchronously converts a response object to a streaming response.
+    Args:
+        response_object (Optional[dict]): The response object to be converted. Defaults to None.
+    Raises:
+        Exception: If the response object is None.
+    Yields:
+        ModelResponse: The converted streaming response object.
+    Returns:
+        None
+    """
+    if response_object is None:
+        raise Exception("Error in response object format")
+    model_response_object = ModelResponse(stream=True)
+    if model_response_object is None:
+        raise Exception("Error in response creating model response object")
+    choice_list = []
+    for idx, choice in enumerate(response_object["choices"]):
+        delta = Delta(
+            content=choice["message"].get("content", None),
+            role=choice["message"]["role"],
+            function_call=choice["message"].get("function_call", None),
+            tool_calls=choice["message"].get("tool_calls", None)
+        )
+        finish_reason = choice.get("finish_reason", None)
+        if finish_reason is None:
+            finish_reason = choice.get("finish_details")
+        choice = StreamingChoices(finish_reason=finish_reason, index=idx, delta=delta)
+        choice_list.append(choice)
+    model_response_object.choices = choice_list
+    if "usage" in response_object and response_object["usage"] is not None:
+        model_response_object.usage = Usage(
+            completion_tokens=response_object["usage"].get("completion_tokens", 0),
+            prompt_tokens=response_object["usage"].get("prompt_tokens", 0),
+            total_tokens=response_object["usage"].get("total_tokens", 0)
+        )
+    if "id" in response_object:
+        model_response_object.id = response_object["id"]
+    if "system_fingerprint" in response_object:
+        model_response_object.system_fingerprint = response_object["system_fingerprint"]
+    if "model" in response_object:
+        model_response_object.model = response_object["model"]
+    yield model_response_object
+    await asyncio.sleep(0)
 def convert_to_streaming_response(response_object: Optional[dict]=None):
     # used for yielding Cache hits when stream == True
     if response_object is None:
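For reference, the new async generator above yields the whole cached completion as one streaming-shaped chunk (a ModelResponse(stream=True) whose choices carry Delta objects) and then stops. A small direct-usage sketch with made-up cached data:

    import asyncio

    async def demo():
        cached = {
            "id": "chatcmpl-cached-123",  # illustrative values throughout
            "model": "gpt-3.5-turbo",
            "choices": [{
                "message": {"role": "assistant", "content": "Hello from the cache"},
                "finish_reason": "stop",
            }],
            "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
        }
        async for chunk in convert_to_streaming_response_async(cached):
            print(chunk.choices[0].delta.content)  # prints the full cached content once

    asyncio.run(demo())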