fix(utils.py): support streaming cached response logging

Krrish Dholakia 2024-02-21 17:53:14 -08:00
parent 0733bf1e7a
commit f8b233b653
3 changed files with 114 additions and 33 deletions

litellm/utils.py

@@ -2328,6 +2328,13 @@ def client(original_function):
                             model_response_object=ModelResponse(),
                             stream=kwargs.get("stream", False),
                         )
+                        if kwargs.get("stream", False) == True:
+                            cached_result = CustomStreamWrapper(
+                                completion_stream=cached_result,
+                                model=model,
+                                custom_llm_provider="cached_response",
+                                logging_obj=logging_obj,
+                            )
                     elif call_type == CallTypes.embedding.value and isinstance(
                         cached_result, dict
                     ):
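
Note on this hunk: on a cache hit for a streaming completion call, the cached result is now rewrapped in CustomStreamWrapper (with custom_llm_provider="cached_response" and the call's logging_obj), so a cached hit can be iterated exactly like a live stream. A minimal sketch of the resulting call-site behavior, assuming litellm caching is enabled; the model and messages are illustrative, not from this commit:

    import litellm

    # With stream=True, a cache hit now comes back as a CustomStreamWrapper
    # over the cached chunks, so the same loop handles hits and fresh calls.
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
        caching=True,
        stream=True,
    )
    for chunk in response:
        print(chunk.choices[0].delta.content)
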
@@ -2624,28 +2631,6 @@ def client(original_function):
                         cached_result, list
                     ):
                         print_verbose(f"Cache Hit!")
-                        call_type = original_function.__name__
-                        if call_type == CallTypes.acompletion.value and isinstance(
-                            cached_result, dict
-                        ):
-                            if kwargs.get("stream", False) == True:
-                                cached_result = convert_to_streaming_response_async(
-                                    response_object=cached_result,
-                                )
-                            else:
-                                cached_result = convert_to_model_response_object(
-                                    response_object=cached_result,
-                                    model_response_object=ModelResponse(),
-                                )
-                        elif call_type == CallTypes.aembedding.value and isinstance(
-                            cached_result, dict
-                        ):
-                            cached_result = convert_to_model_response_object(
-                                response_object=cached_result,
-                                model_response_object=EmbeddingResponse(),
-                                response_type="embedding",
-                            )
-                        # LOG SUCCESS
                         cache_hit = True
                         end_time = datetime.datetime.now()
                         (
@@ -2685,15 +2670,44 @@ def client(original_function):
                             additional_args=None,
                             stream=kwargs.get("stream", False),
                         )
-                        asyncio.create_task(
-                            logging_obj.async_success_handler(
-                                cached_result, start_time, end_time, cache_hit
+                        call_type = original_function.__name__
+                        if call_type == CallTypes.acompletion.value and isinstance(
+                            cached_result, dict
+                        ):
+                            if kwargs.get("stream", False) == True:
+                                cached_result = convert_to_streaming_response_async(
+                                    response_object=cached_result,
+                                )
+                                cached_result = CustomStreamWrapper(
+                                    completion_stream=cached_result,
+                                    model=model,
+                                    custom_llm_provider="cached_response",
+                                    logging_obj=logging_obj,
+                                )
+                            else:
+                                cached_result = convert_to_model_response_object(
+                                    response_object=cached_result,
+                                    model_response_object=ModelResponse(),
+                                )
+                        elif call_type == CallTypes.aembedding.value and isinstance(
+                            cached_result, dict
+                        ):
+                            cached_result = convert_to_model_response_object(
+                                response_object=cached_result,
+                                model_response_object=EmbeddingResponse(),
+                                response_type="embedding",
+                            )
-                            )
-                        )
-                        threading.Thread(
-                            target=logging_obj.success_handler,
-                            args=(cached_result, start_time, end_time, cache_hit),
-                        ).start()
+                        if kwargs.get("stream", False) == False:
+                            # LOG SUCCESS
+                            asyncio.create_task(
+                                logging_obj.async_success_handler(
+                                    cached_result, start_time, end_time, cache_hit
+                                )
+                            )
+                            threading.Thread(
+                                target=logging_obj.success_handler,
+                                args=(cached_result, start_time, end_time, cache_hit),
+                            ).start()
                         return cached_result
                     elif (
                         call_type == CallTypes.aembedding.value
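
The async path above mirrors the sync hunk: a streaming cache hit is first converted by convert_to_streaming_response_async into an async generator, then wrapped, and the immediate success logging is skipped whenever stream=True (the wrapper logs instead, as the stream drains). A rough, hypothetical re-implementation of the conversion idea, only to show the shape involved; litellm's real helper handles many more fields:

    import asyncio
    from typing import Any, AsyncIterator, Dict

    async def to_streaming_response(
        response_object: Dict[str, Any]
    ) -> AsyncIterator[Dict[str, Any]]:
        # Hypothetical: replay a cached completion dict as a single
        # streaming-style chunk whose delta carries the full cached text.
        for choice in response_object.get("choices", []):
            yield {
                "choices": [
                    {
                        "delta": {"content": choice["message"]["content"]},
                        "finish_reason": choice.get("finish_reason"),
                    }
                ]
            }

    async def main() -> None:
        cached = {
            "choices": [
                {"message": {"content": "hello from cache"}, "finish_reason": "stop"}
            ]
        }
        async for chunk in to_streaming_response(cached):
            print(chunk["choices"][0]["delta"]["content"])

    asyncio.run(main())
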
@@ -4296,7 +4310,9 @@ def get_optional_params(
                    parameters=tool["function"].get("parameters", {}),
                )
                gtool_func_declarations.append(gtool_func_declaration)
-            optional_params["tools"] = [generative_models.Tool(function_declarations=gtool_func_declarations)]
+            optional_params["tools"] = [
+                generative_models.Tool(function_declarations=gtool_func_declarations)
+            ]
     elif custom_llm_provider == "sagemaker":
         ## check if unsupported param passed in
         supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
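
For context on this cosmetic hunk: litellm converts OpenAI-style tool specs into the Vertex AI SDK's FunctionDeclaration objects and hands them over as a single Tool; the change only wraps the list literal for readability. A hedged sketch of that mapping, assuming the vertexai preview SDK is installed; the weather tool spec is illustrative:

    from vertexai.preview import generative_models

    # OpenAI-style tool spec (illustrative)
    tool = {
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
            },
        }
    }

    gtool_func_declaration = generative_models.FunctionDeclaration(
        name=tool["function"]["name"],
        description=tool["function"].get("description", ""),
        parameters=tool["function"].get("parameters", {}),
    )
    optional_params = {
        "tools": [
            generative_models.Tool(function_declarations=[gtool_func_declaration])
        ]
    }
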
@@ -8524,6 +8540,19 @@ class CustomStreamWrapper:
                     ]
             elif self.custom_llm_provider == "text-completion-openai":
                 response_obj = self.handle_openai_text_completion_chunk(chunk)
+                completion_obj["content"] = response_obj["text"]
+                print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj["is_finished"]:
+                    model_response.choices[0].finish_reason = response_obj[
+                        "finish_reason"
+                    ]
+            elif self.custom_llm_provider == "cached_response":
+                response_obj = {
+                    "text": chunk.choices[0].delta.content,
+                    "is_finished": True,
+                    "finish_reason": chunk.choices[0].finish_reason,
+                }
+
                 completion_obj["content"] = response_obj["text"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
                 if response_obj["is_finished"]:
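
The new "cached_response" branch treats every replayed chunk as already complete: the text is read straight off choices[0].delta and is_finished is hard-coded to True, since a cached chunk arrives in its final form together with its finish_reason. A standalone sketch of that extraction, with hypothetical dataclasses standing in for litellm's streaming chunk objects:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Delta:
        content: Optional[str]

    @dataclass
    class Choice:
        delta: Delta
        finish_reason: Optional[str]

    @dataclass
    class Chunk:
        choices: List[Choice]

    chunk = Chunk(
        choices=[Choice(delta=Delta(content="cached text"), finish_reason="stop")]
    )

    # Mirrors the new branch: lift text and finish_reason off the chunk and
    # mark it finished, because a cached chunk is replayed in final form.
    response_obj = {
        "text": chunk.choices[0].delta.content,
        "is_finished": True,
        "finish_reason": chunk.choices[0].finish_reason,
    }
    print(response_obj)
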
@@ -8732,6 +8761,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "vertex_ai"
or self.custom_llm_provider == "sagemaker"
or self.custom_llm_provider == "gemini"
or self.custom_llm_provider == "cached_response"
or self.custom_llm_provider in litellm.openai_compatible_endpoints
):
async for chunk in self.completion_stream:
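
Lastly, "cached_response" joins the providers whose completion_stream is drained with `async for` in the async iterator path: the stream produced for a cache hit is an async generator, not a sync iterator. A small sketch of why that distinction matters, with a hypothetical stand-in generator:

    import asyncio
    from typing import AsyncIterator

    async def completion_stream() -> AsyncIterator[str]:
        # Stand-in for the async generator a streaming cache hit now produces.
        yield "hello from cache"

    async def consume() -> None:
        # An async generator implements __anext__, not __next__, so it must
        # be consumed with `async for` rather than a plain for-loop.
        async for chunk in completion_stream():
            print(chunk)

    asyncio.run(consume())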