fix(utils.py): support streaming cached response logging

Krrish Dholakia 2024-02-21 17:53:14 -08:00
parent 0733bf1e7a
commit f8b233b653
3 changed files with 114 additions and 33 deletions

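For context, this is the behaviour the commit enables, sketched against litellm's public API. The Azure deployment name and RedIS environment variables mirror the new test below; the Cache import path is an assumption taken from litellm's test suite:

    import asyncio, os
    import litellm
    from litellm.caching import Cache  # assumed import path, as used in litellm's tests

    # Assumption: Redis credentials are set in the environment, as in the new test.
    litellm.cache = Cache(
        type="redis",
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
    )

    async def main():
        # First streaming call hits the provider and populates the cache.
        first = await litellm.acompletion(
            model="azure/chatgpt-v-2",  # illustrative deployment name from the test
            messages=[{"role": "user", "content": "Hi"}],
            caching=True,
            stream=True,
        )
        async for _ in first:
            pass

        # Second identical call is a cache hit. With this commit the cached dict is
        # replayed as a stream and success callbacks/logging still run.
        second = await litellm.acompletion(
            model="azure/chatgpt-v-2",
            messages=[{"role": "user", "content": "Hi"}],
            caching=True,
            stream=True,
        )
        async for chunk in second:
            print(chunk.choices[0].delta.content)

    asyncio.run(main())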

@@ -936,7 +936,14 @@
"mode": "chat"
},
"openrouter/mistralai/mistral-7b-instruct": {
"max_tokens": 4096,
"max_tokens": 8192,
"input_cost_per_token": 0.00000013,
"output_cost_per_token": 0.00000013,
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/mistralai/mistral-7b-instruct:free": {
"max_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "openrouter",

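The hunk above bumps the reported context window for openrouter/mistralai/mistral-7b-instruct and adds a :free variant. A quick way to confirm what litellm reports, assuming litellm.model_cost is the in-memory view of the model cost JSON shown above:

    import litellm

    # Assumption: the installed litellm ships the updated model cost map.
    entry = litellm.model_cost["openrouter/mistralai/mistral-7b-instruct"]
    print(entry["max_tokens"])            # 8192 after this commit
    print(entry["input_cost_per_token"])  # 0.00000013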

@@ -2,7 +2,7 @@
## This test asserts the type of data passed into each method of the custom callback handler
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
import pytest, uuid
from pydantic import BaseModel
sys.path.insert(0, os.path.abspath("../.."))
@@ -795,6 +795,50 @@ async def test_async_completion_azure_caching():
assert len(customHandler_caching.states) == 4 # pre, post, success, success
@pytest.mark.asyncio
async def test_async_completion_azure_caching_streaming():
litellm.set_verbose = True
customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.callbacks = [customHandler_caching]
unique_time = uuid.uuid4()
response1 = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[
{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
],
caching=True,
stream=True,
)
async for chunk in response1:
continue
await asyncio.sleep(1)
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
response2 = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[
{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
],
caching=True,
stream=True,
)
async for chunk in response2:
continue
await asyncio.sleep(1) # success callbacks are done in parallel
print(
f"customHandler_caching.states post-cache hit: {customHandler_caching.states}"
)
assert len(customHandler_caching.errors) == 0
assert (
len(customHandler_caching.states) > 4
) # pre, post, streaming .., success, success
@pytest.mark.asyncio
async def test_async_embedding_azure_caching():
print("Testing custom callback input - Azure Caching")

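The new streaming-cache test relies on the CompletionCustomHandler defined earlier in that test module. As a rough stand-in, a handler built on litellm's CustomLogger hooks could track the same kind of state transitions; the hook names below follow the CustomLogger base class, but the class itself is illustrative, not the test's actual implementation:

    from litellm.integrations.custom_logger import CustomLogger

    class MinimalStateTracker(CustomLogger):
        # Records one entry per callback event so a test can count them.
        def __init__(self):
            self.states = []
            self.errors = []

        def log_pre_api_call(self, model, messages, kwargs):
            self.states.append("pre_api_call")

        def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
            self.states.append("post_api_call")

        async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
            self.states.append("async_stream")

        async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
            self.states.append("async_success")

        async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
            self.errors.append("async_failure")

Because each streamed chunk adds a stream event on top of the pre/post/success entries, the new test asserts len(states) > 4 rather than an exact count.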

@@ -2328,6 +2328,13 @@ def client(original_function):
model_response_object=ModelResponse(),
stream=kwargs.get("stream", False),
)
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
elif call_type == CallTypes.embedding.value and isinstance(
cached_result, dict
):
@@ -2624,28 +2631,6 @@ def client(original_function):
cached_result, list
):
print_verbose(f"Cache Hit!")
call_type = original_function.__name__
if call_type == CallTypes.acompletion.value and isinstance(
cached_result, dict
):
if kwargs.get("stream", False) == True:
cached_result = convert_to_streaming_response_async(
response_object=cached_result,
)
else:
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=ModelResponse(),
)
elif call_type == CallTypes.aembedding.value and isinstance(
cached_result, dict
):
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=EmbeddingResponse(),
response_type="embedding",
)
# LOG SUCCESS
cache_hit = True
end_time = datetime.datetime.now()
(
@@ -2685,15 +2670,44 @@ def client(original_function):
additional_args=None,
stream=kwargs.get("stream", False),
)
asyncio.create_task(
logging_obj.async_success_handler(
cached_result, start_time, end_time, cache_hit
call_type = original_function.__name__
if call_type == CallTypes.acompletion.value and isinstance(
cached_result, dict
):
if kwargs.get("stream", False) == True:
cached_result = convert_to_streaming_response_async(
response_object=cached_result,
)
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
else:
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=ModelResponse(),
)
elif call_type == CallTypes.aembedding.value and isinstance(
cached_result, dict
):
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=EmbeddingResponse(),
response_type="embedding",
)
)
threading.Thread(
target=logging_obj.success_handler,
args=(cached_result, start_time, end_time, cache_hit),
).start()
if kwargs.get("stream", False) == False:
# LOG SUCCESS
asyncio.create_task(
logging_obj.async_success_handler(
cached_result, start_time, end_time, cache_hit
)
)
threading.Thread(
target=logging_obj.success_handler,
args=(cached_result, start_time, end_time, cache_hit),
).start()
return cached_result
elif (
call_type == CallTypes.aembedding.value
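Net effect of the two hunks above: the cached dict is converted back into a response object first, and the inline success logging (asyncio task plus background thread) now only fires for non-streaming calls; for stream=True, the CustomStreamWrapper built around the cached chunks owns logging once the stream is consumed. A stripped-down sketch of that branching, with wrap_as_stream standing in for the convert-and-wrap step (hypothetical helper, not litellm's internals):

    import asyncio
    import threading

    def wrap_as_stream(cached_result, logging_obj):
        # Placeholder for: convert_to_streaming_response_async(...) followed by
        # CustomStreamWrapper(..., custom_llm_provider="cached_response").
        ...

    async def return_cached(cached_result, stream, logging_obj, start_time, end_time):
        if stream:
            # Replay the cached payload as a stream; the wrapper reports success
            # after the caller finishes iterating the chunks.
            return wrap_as_stream(cached_result, logging_obj)
        # Non-streaming cache hit: log success immediately, async and sync handlers.
        asyncio.create_task(
            logging_obj.async_success_handler(cached_result, start_time, end_time, True)
        )
        threading.Thread(
            target=logging_obj.success_handler,
            args=(cached_result, start_time, end_time, True),
        ).start()
        return cached_result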
@@ -4296,7 +4310,9 @@ def get_optional_params(
parameters=tool["function"].get("parameters", {}),
)
gtool_func_declarations.append(gtool_func_declaration)
optional_params["tools"] = [generative_models.Tool(function_declarations=gtool_func_declarations)]
optional_params["tools"] = [
generative_models.Tool(function_declarations=gtool_func_declarations)
]
elif custom_llm_provider == "sagemaker":
## check if unsupported param passed in
supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
@@ -8524,6 +8540,19 @@ class CustomStreamWrapper:
]
elif self.custom_llm_provider == "text-completion-openai":
response_obj = self.handle_openai_text_completion_chunk(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
elif self.custom_llm_provider == "cached_response":
response_obj = {
"text": chunk.choices[0].delta.content,
"is_finished": True,
"finish_reason": chunk.choices[0].finish_reason,
}
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
@@ -8732,6 +8761,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "vertex_ai"
or self.custom_llm_provider == "sagemaker"
or self.custom_llm_provider == "gemini"
or self.custom_llm_provider == "cached_response"
or self.custom_llm_provider in litellm.openai_compatible_endpoints
):
async for chunk in self.completion_stream:
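The last two hunks teach CustomStreamWrapper to treat "cached_response" like any provider whose completion_stream is an async iterable of chunks exposing choices[0].delta.content and a finish_reason. A self-contained toy of that consumption pattern, with plain dataclasses standing in for litellm's chunk objects:

    import asyncio
    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Delta:
        content: Optional[str] = None

    @dataclass
    class Choice:
        delta: Delta
        finish_reason: Optional[str] = None

    @dataclass
    class Chunk:
        choices: List[Choice]

    async def cached_chunks(text: str):
        # Stand-in for replaying a cached completion as a short stream of chunks.
        words = text.split()
        for i, word in enumerate(words):
            yield Chunk(choices=[Choice(
                delta=Delta(content=word + " "),
                finish_reason="stop" if i == len(words) - 1 else None,
            )])

    async def consume():
        async for chunk in cached_chunks("hello from the cache"):
            print(chunk.choices[0].delta.content, chunk.choices[0].finish_reason)

    asyncio.run(consume())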