forked from phoenix/litellm-mirror
(feat) async completion caching
This commit is contained in:
parent 67c730e264
commit d18d5a3133
1 changed file with 82 additions and 1 deletion
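What the feature does, end to end: with caching enabled, an async completion whose request matches a cached entry is answered without a provider call, and (per the first hunk below) a streamed response is written to the cache only once its final chunk has been assembled. A minimal sketch of exercising it; the model name, prompt, and per-call caching=True flag are illustrative placeholders, and a real run needs a provider key such as OPENAI_API_KEY:

import asyncio
import litellm
from litellm.caching import Cache

# enable litellm's in-memory cache so completions are stored and reused
litellm.cache = Cache()

async def main():
    messages = [{"role": "user", "content": "Hello, world"}]
    # first call hits the provider and populates the cache
    first = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, caching=True)
    # an identical second call is answered from the cache via the new code path
    second = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, caching=True)
    print(second.choices[0].message.content)

asyncio.run(main())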
@@ -1056,6 +1056,19 @@ class Logging:
         start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
         for callback in litellm._async_success_callback:
             try:
+                if callback == "cache":
+                    # set_cache once complete streaming response is built
+                    print_verbose("async success_callback: reaches cache for logging!")
+                    kwargs = self.model_call_details
+                    if self.stream:
+                        if "complete_streaming_response" not in kwargs:
+                            print_verbose(f"async success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n")
+                            return
+                        else:
+                            print_verbose("async success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache")
+                            result = kwargs["complete_streaming_response"]
+                            # only add to cache once we have a complete streaming response
+                            litellm.cache.add_cache(result, **kwargs)
                 if isinstance(callback, CustomLogger): # custom logger class
                     print_verbose(f"Async success callbacks: CustomLogger")
                     if self.stream:
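The early return above is the important guard: a streaming request yields many partial deltas, and only the assembled complete_streaming_response is safe to store, since caching a partial chunk would poison later hits. A toy illustration of that rule, using a plain dict and string chunks as stand-ins for litellm.cache and real deltas (none of these names are litellm APIs):

# minimal sketch of "cache only the complete streaming response";
# `cache` and `consume_stream` are illustrative stand-ins, not litellm APIs
cache: dict[str, str] = {}

def consume_stream(key: str, chunks):
    pieces = []
    for chunk in chunks:          # partial deltas arrive one at a time
        pieces.append(chunk)
    complete = "".join(pieces)    # only now is the response complete
    cache[key] = complete         # safe to cache; a partial write would poison hits
    return complete

print(consume_stream("hello-prompt", iter(["Hel", "lo", "!"])))
print(cache["hello-prompt"])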
@@ -1599,7 +1612,12 @@ def client(original_function):
                     print_verbose(f"Cache Hit!")
                     call_type = original_function.__name__
                     if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict):
-                        return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
+                        if kwargs.get("stream", False) == True:
+                            return convert_to_streaming_response_async(
+                                response_object=cached_result,
+                            )
+                        else:
+                            return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
                     else:
                         return cached_result
                     # MODEL CALL
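On a cache hit for an async streaming call, the wrapper now returns convert_to_streaming_response_async(...), an async generator, so the caller consumes a hit and a live stream identically. A sketch of that calling side, assuming the per-call stream and caching flags shown; the model and prompt are placeholders, and a provider key is needed for the initial cache-filling call:

import asyncio
import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # in-memory cache, as in the earlier sketch

async def main():
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",  # placeholder model
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
        caching=True,
    )
    # a cached hit (async generator) and a live stream are consumed the same way
    async for chunk in response:
        if chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="")

asyncio.run(main())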
@@ -3494,6 +3512,69 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
         exception_logging(logger_fn=user_logger_fn, exception=e)
         pass
 
+async def convert_to_streaming_response_async(response_object: Optional[dict]=None):
+    """
+    Asynchronously converts a response object to a streaming response.
+
+    Args:
+        response_object (Optional[dict]): The response object to be converted. Defaults to None.
+
+    Raises:
+        Exception: If the response object is None.
+
+    Yields:
+        ModelResponse: The converted streaming response object.
+
+    Returns:
+        None
+    """
+    if response_object is None:
+        raise Exception("Error in response object format")
+
+    model_response_object = ModelResponse(stream=True)
+
+    if model_response_object is None:
+        raise Exception("Error in response creating model response object")
+
+    choice_list = []
+
+    for idx, choice in enumerate(response_object["choices"]):
+        delta = Delta(
+            content=choice["message"].get("content", None),
+            role=choice["message"]["role"],
+            function_call=choice["message"].get("function_call", None),
+            tool_calls=choice["message"].get("tool_calls", None)
+        )
+        finish_reason = choice.get("finish_reason", None)
+
+        if finish_reason is None:
+            finish_reason = choice.get("finish_details")
+
+        choice = StreamingChoices(finish_reason=finish_reason, index=idx, delta=delta)
+        choice_list.append(choice)
+
+    model_response_object.choices = choice_list
+
+    if "usage" in response_object and response_object["usage"] is not None:
+        model_response_object.usage = Usage(
+            completion_tokens=response_object["usage"].get("completion_tokens", 0),
+            prompt_tokens=response_object["usage"].get("prompt_tokens", 0),
+            total_tokens=response_object["usage"].get("total_tokens", 0)
+        )
+
+
+    if "id" in response_object:
+        model_response_object.id = response_object["id"]
+
+    if "system_fingerprint" in response_object:
+        model_response_object.system_fingerprint = response_object["system_fingerprint"]
+
+    if "model" in response_object:
+        model_response_object.model = response_object["model"]
+
+    yield model_response_object
+    await asyncio.sleep(0)
+
 def convert_to_streaming_response(response_object: Optional[dict]=None):
     # used for yielding Cache hits when stream == True
     if response_object is None:
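Because the new helper only reads plain dict fields, it can be driven directly with a dict shaped like a cached chat completion, which is a quick way to sanity-check the conversion. A sketch with made-up values, assuming the helper is importable from litellm.utils (the module this diff appears to touch):

import asyncio

# assumption: the diff above is litellm/utils.py, so the helper imports from there
from litellm.utils import convert_to_streaming_response_async

async def demo():
    # a made-up payload shaped like a cached chat-completion dict
    cached = {
        "id": "chatcmpl-abc123",
        "model": "gpt-3.5-turbo",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "Hello!"},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7},
    }
    # the helper yields a single ModelResponse chunk rebuilt from the dict
    async for chunk in convert_to_streaming_response_async(cached):
        print(chunk)

asyncio.run(demo())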