(feat) caching - streaming caching support

ishaan-jaff 2023-12-08 11:48:11 -08:00
parent 9b0afbe2cb
commit 6e8ad10991
2 changed files with 54 additions and 19 deletions
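
The change moves the cache write for streamed completions out of the per-chunk path and into the success handler, so a response is cached only once the stream has finished, and a later cache hit can be replayed as a stream. A rough usage sketch of what this enables, not part of the commit itself, assuming the in-memory `litellm.Cache` backend and provider credentials configured via environment variables:

```python
import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache()  # in-memory cache backend (assumption for this sketch)

messages = [{"role": "user", "content": "Write one sentence about response caching."}]

# First call streams from the provider; per this commit, the complete
# streaming response is written to the cache once the stream ends.
for chunk in completion(model="gpt-3.5-turbo", messages=messages, stream=True):
    print(chunk.choices[0].delta.content or "", end="")

# A repeat call with identical arguments is a cache hit; with stream=True the
# cached dict is converted back into a chunk generator and replayed.
for chunk in completion(model="gpt-3.5-turbo", messages=messages, stream=True):
    print(chunk.choices[0].delta.content or "", end="")
```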


@@ -817,19 +817,6 @@ class Logging:
if callback == "api_manager":
print_verbose("reaches api manager for updating model cost")
litellm.apiManager.update_cost(completion_obj=result, user=self.user)
if callback == "cache":
if litellm.cache != None and self.model_call_details.get('optional_params', {}).get('stream', False) == True:
litellm_call_id = self.litellm_params["litellm_call_id"]
if litellm_call_id in self.litellm_params["stream_response"]:
# append for the given call_id
if self.litellm_params["stream_response"][litellm_call_id]["choices"][0]["message"]["content"] == "default":
self.litellm_params["stream_response"][litellm_call_id]["choices"][0]["message"]["content"] = result["content"] # handle first try
else:
self.litellm_params["stream_response"][litellm_call_id]["choices"][0]["message"]["content"] += result["content"]
else: # init a streaming response for this call id
new_model_response = ModelResponse(choices=[Choices(message=Message(content="default"))])
self.litellm_params["stream_response"][litellm_call_id] = new_model_response
litellm.cache.add_cache(self.litellm_params["stream_response"][litellm_call_id], **self.model_call_details)
if callback == "promptlayer":
print_verbose("reaches promptlayer for logging!")
promptLayerLogger.log_event(
@@ -937,6 +924,16 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
if callback == "cache":
# this only logs streaming once, complete_streaming_response exists i.e when stream ends
kwargs = self.model_call_details
if self.stream:
if "complete_streaming_response" not in kwargs:
return
else:
result = kwargs["complete_streaming_response"]
# only add to cache once we have a complete streaming response
litellm.cache.add_cache(result, **kwargs)
if callback == "traceloop":
deep_copy = {}
for k, v in self.model_call_details.items():
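
Taken on its own, the new success-callback branch boils down to: accumulate the streamed chunks, then write to the cache exactly once, when `complete_streaming_response` is available. A minimal, standalone sketch of that pattern (the dict shapes and the `cache` mapping are illustrative stand-ins; the real write is the `litellm.cache.add_cache(result, **kwargs)` call above):

```python
from typing import Dict, Iterable

def cache_when_stream_completes(chunks: Iterable[dict], cache: Dict[str, dict], cache_key: str) -> None:
    """Illustrative only: join streamed deltas, then cache the final response once."""
    content_parts = []
    last_chunk = None
    for chunk in chunks:
        delta = chunk["choices"][0].get("delta", {})
        if delta.get("content"):
            content_parts.append(delta["content"])
        last_chunk = chunk

    if last_chunk is None:
        return  # nothing streamed, so there is no complete response to cache

    complete_response = {
        "id": last_chunk.get("id"),
        "model": last_chunk.get("model"),
        "choices": [{
            "index": 0,
            "finish_reason": last_chunk["choices"][0].get("finish_reason"),
            "message": {"role": "assistant", "content": "".join(content_parts)},
        }],
    }
    # one cache write, only after the stream has fully completed -- mirroring
    # the "complete_streaming_response" check in the callback above
    cache[cache_key] = complete_response
```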
@@ -1401,7 +1398,6 @@ def client(original_function):
print_verbose(f"Checking Cache")
cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result != None:
print_verbose(f"Cache Hit!")
if "detail" in cached_result:
# implies an error occurred
pass
@@ -1409,7 +1405,7 @@ def client(original_function):
call_type = original_function.__name__
print_verbose(f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}")
if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse(), stream = kwargs.get("stream", False))
elif call_type == CallTypes.embedding.value and isinstance(cached_result, dict):
return convert_to_model_response_object(response_object=cached_result, response_type="embedding")
else:
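
On the read path, the `stream` kwarg now decides what shape a cache hit takes: without it, the cached dict is rebuilt into a plain `ModelResponse`; with it, the dict is handed to `convert_to_streaming_response` (added below) and comes back as a chunk generator. A caller-side sketch, assuming `litellm.cache` is configured as above and each call has already run once with identical arguments so the cache is warm:

```python
from litellm import completion

messages = [{"role": "user", "content": "Summarize unified diffs in one sentence."}]

# Repeat non-streaming call: the cache hit is returned as a regular
# ModelResponse, so the full message is available directly.
resp = completion(model="gpt-3.5-turbo", messages=messages)
print(resp.choices[0].message.content)

# Repeat streaming call: the cached dict is wrapped in a generator, so the
# caller iterates it exactly like a live stream.
for chunk in completion(model="gpt-3.5-turbo", messages=messages, stream=True):
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="")
```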
@@ -3414,12 +3410,54 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k
exception_logging(logger_fn=user_logger_fn, exception=e)
pass
def convert_to_streaming_response(response_object: Optional[dict]=None):
# used for yielding Cache hits when stream == True
if response_object is None:
raise Exception("Error in response object format")
def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse]]=None, response_type: Literal["completion", "embedding"] = "completion"):
model_response_object = ModelResponse(stream=True)
choice_list=[]
for idx, choice in enumerate(response_object["choices"]):
delta = Delta(
content=choice["message"].get("content", None),
role=choice["message"]["role"],
function_call=choice["message"].get("function_call", None),
tool_calls=choice["message"].get("tool_calls", None)
)
finish_reason = choice.get("finish_reason", None)
if finish_reason == None:
# gpt-4 vision can return 'finish_reason' or 'finish_details'
finish_reason = choice.get("finish_details")
choice = StreamingChoices(finish_reason=finish_reason, index=idx, delta=delta)
choice_list.append(choice)
model_response_object.choices = choice_list
if "usage" in response_object and response_object["usage"] is not None:
model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore
model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
if "id" in response_object:
model_response_object.id = response_object["id"]
if "system_fingerprint" in response_object:
model_response_object.system_fingerprint = response_object["system_fingerprint"]
if "model" in response_object:
model_response_object.model = response_object["model"]
yield model_response_object
def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse]]=None, response_type: Literal["completion", "embedding"] = "completion", stream = False):
try:
if response_type == "completion" and (model_response_object is None or isinstance(model_response_object, ModelResponse)):
if response_object is None or model_response_object is None:
raise Exception("Error in response object format")
if stream == True:
# for returning cached responses, we need to yield a generator
return convert_to_streaming_response(
response_object=response_object
)
choice_list=[]
for idx, choice in enumerate(response_object["choices"]):
message = Message(