Mirror of https://github.com/BerriAI/litellm.git
fix(utils.py): support streaming cached response logging
Commit f8b233b653 (parent 0733bf1e7a)
3 changed files with 114 additions and 33 deletions
@@ -2328,6 +2328,13 @@ def client(original_function):
                            model_response_object=ModelResponse(),
                            stream=kwargs.get("stream", False),
                        )

                        cached_result = CustomStreamWrapper(
                            completion_stream=cached_result,
                            model=model,
                            custom_llm_provider="cached_response",
                            logging_obj=logging_obj,
                        )
                    elif call_type == CallTypes.embedding.value and isinstance(
                        cached_result, dict
                    ):
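The hunk above wraps a cache hit in `CustomStreamWrapper` with `custom_llm_provider="cached_response"`, so a cached completion is handed back to the caller as a stream rather than a plain `ModelResponse`. A hedged usage sketch of the sync path it supports (assumes litellm's in-memory `litellm.Cache`, the per-call `caching=True` kwarg, and the documented custom success-callback signature; the model name is a placeholder, and an API key is only needed for the first, uncached call):

```python
import litellm

litellm.cache = litellm.Cache()  # in-memory cache


def log_success(kwargs, completion_response, start_time, end_time):
    # Called on success, including for cache hits replayed as streams.
    print("success logged for model:", kwargs.get("model"))


litellm.success_callback = [log_success]

messages = [{"role": "user", "content": "Say hi"}]

# First call reaches the provider and populates the cache (needs OPENAI_API_KEY).
for _ in litellm.completion(
    model="gpt-3.5-turbo", messages=messages, stream=True, caching=True
):
    pass

# Second call is a cache hit but is still delivered chunk by chunk
# through the "cached_response" stream wrapper.
for chunk in litellm.completion(
    model="gpt-3.5-turbo", messages=messages, stream=True, caching=True
):
    print(chunk.choices[0].delta.content)
```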
@@ -2624,28 +2631,6 @@ def client(original_function):
                    cached_result, list
                ):
                    print_verbose(f"Cache Hit!")
                    call_type = original_function.__name__
                    if call_type == CallTypes.acompletion.value and isinstance(
                        cached_result, dict
                    ):
                        if kwargs.get("stream", False) == True:
                            cached_result = convert_to_streaming_response_async(
                                response_object=cached_result,
                            )
                        else:
                            cached_result = convert_to_model_response_object(
                                response_object=cached_result,
                                model_response_object=ModelResponse(),
                            )
                    elif call_type == CallTypes.aembedding.value and isinstance(
                        cached_result, dict
                    ):
                        cached_result = convert_to_model_response_object(
                            response_object=cached_result,
                            model_response_object=EmbeddingResponse(),
                            response_type="embedding",
                        )
                    # LOG SUCCESS
                    cache_hit = True
                    end_time = datetime.datetime.now()
                    (
@@ -2685,15 +2670,44 @@ def client(original_function):
                        additional_args=None,
                        stream=kwargs.get("stream", False),
                    )
                    asyncio.create_task(
                        logging_obj.async_success_handler(
                            cached_result, start_time, end_time, cache_hit
                    call_type = original_function.__name__
                    if call_type == CallTypes.acompletion.value and isinstance(
                        cached_result, dict
                    ):
                        if kwargs.get("stream", False) == True:
                            cached_result = convert_to_streaming_response_async(
                                response_object=cached_result,
                            )
                            cached_result = CustomStreamWrapper(
                                completion_stream=cached_result,
                                model=model,
                                custom_llm_provider="cached_response",
                                logging_obj=logging_obj,
                            )
                        else:
                            cached_result = convert_to_model_response_object(
                                response_object=cached_result,
                                model_response_object=ModelResponse(),
                            )
                    elif call_type == CallTypes.aembedding.value and isinstance(
                        cached_result, dict
                    ):
                        cached_result = convert_to_model_response_object(
                            response_object=cached_result,
                            model_response_object=EmbeddingResponse(),
                            response_type="embedding",
                        )
                        )
                    threading.Thread(
                        target=logging_obj.success_handler,
                        args=(cached_result, start_time, end_time, cache_hit),
                    ).start()
                    if kwargs.get("stream", False) == False:
                        # LOG SUCCESS
                        asyncio.create_task(
                            logging_obj.async_success_handler(
                                cached_result, start_time, end_time, cache_hit
                            )
                        )
                        threading.Thread(
                            target=logging_obj.success_handler,
                            args=(cached_result, start_time, end_time, cache_hit),
                        ).start()
                    return cached_result
                elif (
                    call_type == CallTypes.aembedding.value
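The async hunk above moves the cache-hit conversion ahead of the success-logging calls and only fires the immediate log path when `stream` is false; a streamed cache hit is instead wrapped in `CustomStreamWrapper` (which carries `logging_obj`), so logging can happen as the replayed chunks are consumed rather than twice up front. A hedged sketch of the async usage this is meant to support (same assumptions as the sync example above):

```python
import asyncio

import litellm

litellm.cache = litellm.Cache()  # in-memory cache


async def main():
    messages = [{"role": "user", "content": "Say hi"}]

    # First call hits the provider and populates the cache (needs OPENAI_API_KEY).
    first = await litellm.acompletion(
        model="gpt-3.5-turbo", messages=messages, stream=True, caching=True
    )
    async for _ in first:
        pass

    # Second call is a cache hit, replayed through the stream wrapper,
    # so the response still arrives chunk by chunk.
    second = await litellm.acompletion(
        model="gpt-3.5-turbo", messages=messages, stream=True, caching=True
    )
    async for chunk in second:
        print(chunk.choices[0].delta.content)


asyncio.run(main())
```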
@@ -4296,7 +4310,9 @@ def get_optional_params(
                        parameters=tool["function"].get("parameters", {}),
                    )
                    gtool_func_declarations.append(gtool_func_declaration)
                optional_params["tools"] = [generative_models.Tool(function_declarations=gtool_func_declarations)]
                optional_params["tools"] = [
                    generative_models.Tool(function_declarations=gtool_func_declarations)
                ]
            elif custom_llm_provider == "sagemaker":
                ## check if unsupported param passed in
                supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
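For context, the lines above sit in the `vertex_ai` branch of `get_optional_params`, which converts OpenAI-format tool specs into Vertex AI declarations; the two adjacent `optional_params["tools"]` assignments are the before/after of a line-length reflow. A hedged sketch of that mapping, assuming `vertexai.preview.generative_models` is available and that `tools` uses the OpenAI `{"type": "function", "function": {...}}` shape (the example tool is illustrative):

```python
from vertexai.preview import generative_models

# OpenAI-style tool definitions (illustrative input shape).
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

gtool_func_declarations = []
for tool in tools:
    gtool_func_declarations.append(
        generative_models.FunctionDeclaration(
            name=tool["function"]["name"],
            description=tool["function"].get("description", ""),
            parameters=tool["function"].get("parameters", {}),
        )
    )

optional_params = {}
optional_params["tools"] = [
    generative_models.Tool(function_declarations=gtool_func_declarations)
]
```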
@@ -8524,6 +8540,19 @@ class CustomStreamWrapper:
                    ]
            elif self.custom_llm_provider == "text-completion-openai":
                response_obj = self.handle_openai_text_completion_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                print_verbose(f"completion obj content: {completion_obj['content']}")
                if response_obj["is_finished"]:
                    model_response.choices[0].finish_reason = response_obj[
                        "finish_reason"
                    ]
            elif self.custom_llm_provider == "cached_response":
                response_obj = {
                    "text": chunk.choices[0].delta.content,
                    "is_finished": True,
                    "finish_reason": chunk.choices[0].finish_reason,
                }

                completion_obj["content"] = response_obj["text"]
                print_verbose(f"completion obj content: {completion_obj['content']}")
                if response_obj["is_finished"]:
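The new `cached_response` branch above treats each replayed chunk as already complete: the whole cached delta becomes the text and the chunk is immediately marked finished. A standalone sketch of that mapping with stand-in chunk types (the dataclasses below are illustrative, not the OpenAI SDK or litellm classes):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Delta:
    content: Optional[str] = None


@dataclass
class Choice:
    delta: Delta
    finish_reason: Optional[str] = None


@dataclass
class Chunk:
    choices: List[Choice]


def handle_cached_response_chunk(chunk: Chunk) -> dict:
    """Mirror the branch above: one replayed chunk carries the full cached text."""
    return {
        "text": chunk.choices[0].delta.content,
        "is_finished": True,
        "finish_reason": chunk.choices[0].finish_reason,
    }


chunk = Chunk(choices=[Choice(delta=Delta(content="cached text"), finish_reason="stop")])
print(handle_cached_response_chunk(chunk))
# {'text': 'cached text', 'is_finished': True, 'finish_reason': 'stop'}
```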
@@ -8732,6 +8761,7 @@ class CustomStreamWrapper:
                or self.custom_llm_provider == "vertex_ai"
                or self.custom_llm_provider == "sagemaker"
                or self.custom_llm_provider == "gemini"
                or self.custom_llm_provider == "cached_response"
                or self.custom_llm_provider in litellm.openai_compatible_endpoints
            ):
                async for chunk in self.completion_stream:
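Adding `cached_response` to the provider list above lets the async side of the wrapper drive the replayed cache stream with `async for` directly, instead of the path used for providers whose `completion_stream` is a plain sync iterator. A minimal standalone sketch of that kind of dispatch (the provider set and the sync fallback are illustrative, not litellm's exact implementation):

```python
import asyncio
from typing import AsyncIterator, Iterator, Union

# Providers assumed to hand back async iterators (per the list above).
ASYNC_PROVIDERS = {"openai", "azure", "vertex_ai", "sagemaker", "gemini", "cached_response"}

_SENTINEL = object()


async def drive_stream(provider: str, completion_stream: Union[AsyncIterator, Iterator]):
    """Consume the stream natively when it is async; otherwise pull the sync
    iterator on a worker thread so the event loop is not blocked."""
    if provider in ASYNC_PROVIDERS:
        async for chunk in completion_stream:
            yield chunk
    else:
        loop = asyncio.get_running_loop()
        it = iter(completion_stream)
        while True:
            chunk = await loop.run_in_executor(None, lambda: next(it, _SENTINEL))
            if chunk is _SENTINEL:
                break
            yield chunk


async def main():
    async def cached_replay():
        # Stand-in for a cached response replayed as a one-chunk stream.
        yield {"text": "hello from cache", "is_finished": True}

    async for chunk in drive_stream("cached_response", cached_replay()):
        print(chunk)


asyncio.run(main())
```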