fix(utils.py): support streaming cached response logging

Krrish Dholakia 2024-02-21 17:53:14 -08:00
parent 0733bf1e7a
commit f8b233b653
3 changed files with 114 additions and 33 deletions
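This change makes cache hits on streamed completions behave like live streams: the cached response is replayed through CustomStreamWrapper with custom_llm_provider="cached_response", so chunk iteration and success callbacks still run. A minimal caller-side sketch of the intended behavior, assuming an in-memory cache, a placeholder model name, and valid credentials (none of which are part of this commit); the Cache import path is taken from the litellm docs and may differ by version:

import asyncio
import litellm
from litellm.caching import Cache  # import path assumed from litellm docs

async def main():
    litellm.cache = Cache()  # in-memory cache for illustration; the test below uses Redis
    messages = [{"role": "user", "content": "hi there"}]

    for attempt in ("cold call", "cache hit"):
        # On the second pass the cached response is replayed as a stream,
        # which is the path this commit fixes logging for.
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",  # placeholder model
            messages=messages,
            caching=True,
            stream=True,
        )
        async for chunk in response:
            print(attempt, chunk.choices[0].delta.content)
        await asyncio.sleep(1)  # give async success callbacks time to run

asyncio.run(main())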

@@ -936,7 +936,14 @@
         "mode": "chat"
     },
     "openrouter/mistralai/mistral-7b-instruct": {
-        "max_tokens": 4096,
+        "max_tokens": 8192,
+        "input_cost_per_token": 0.00000013,
+        "output_cost_per_token": 0.00000013,
+        "litellm_provider": "openrouter",
+        "mode": "chat"
+    },
+    "openrouter/mistralai/mistral-7b-instruct:free": {
+        "max_tokens": 8192,
         "input_cost_per_token": 0.0,
         "output_cost_per_token": 0.0,
         "litellm_provider": "openrouter",

@@ -2,7 +2,7 @@
 ## This test asserts the type of data passed into each method of the custom callback handler
 import sys, os, time, inspect, asyncio, traceback
 from datetime import datetime
-import pytest
+import pytest, uuid
 from pydantic import BaseModel
 sys.path.insert(0, os.path.abspath("../.."))
@@ -795,6 +795,50 @@ async def test_async_completion_azure_caching():
     assert len(customHandler_caching.states) == 4  # pre, post, success, success
+
+
+@pytest.mark.asyncio
+async def test_async_completion_azure_caching_streaming():
+    litellm.set_verbose = True
+    customHandler_caching = CompletionCustomHandler()
+    litellm.cache = Cache(
+        type="redis",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+    )
+    litellm.callbacks = [customHandler_caching]
+    unique_time = uuid.uuid4()
+    response1 = await litellm.acompletion(
+        model="azure/chatgpt-v-2",
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
+        caching=True,
+        stream=True,
+    )
+    async for chunk in response1:
+        continue
+    await asyncio.sleep(1)
+    print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
+    response2 = await litellm.acompletion(
+        model="azure/chatgpt-v-2",
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
+        caching=True,
+        stream=True,
+    )
+    async for chunk in response2:
+        continue
+    await asyncio.sleep(1)  # success callbacks are done in parallel
+    print(
+        f"customHandler_caching.states post-cache hit: {customHandler_caching.states}"
+    )
+    assert len(customHandler_caching.errors) == 0
+    assert (
+        len(customHandler_caching.states) > 4
+    )  # pre, post, streaming .., success, success
 
 
 @pytest.mark.asyncio
 async def test_async_embedding_azure_caching():
     print("Testing custom callback input - Azure Caching")

@@ -2328,6 +2328,13 @@ def client(original_function):
                             model_response_object=ModelResponse(),
                             stream=kwargs.get("stream", False),
                         )
+                        cached_result = CustomStreamWrapper(
+                            completion_stream=cached_result,
+                            model=model,
+                            custom_llm_provider="cached_response",
+                            logging_obj=logging_obj,
+                        )
                     elif call_type == CallTypes.embedding.value and isinstance(
                         cached_result, dict
                     ):
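The hunk above wraps a synchronous cache hit in CustomStreamWrapper with custom_llm_provider="cached_response", so a fully materialized cached response can be consumed with the same chunk loop used for a live stream. A stripped-down, litellm-independent sketch of that idea (the function and field names here are illustrative only):

from types import SimpleNamespace

# A cached, fully materialized response (stand-in for a serialized ModelResponse).
cached = {"content": "Hello from the cache", "finish_reason": "stop"}

def replay_as_stream(cached_response):
    """Yield the cached response as a single streaming-style chunk."""
    yield SimpleNamespace(
        choices=[
            SimpleNamespace(
                delta=SimpleNamespace(content=cached_response["content"]),
                finish_reason=cached_response["finish_reason"],
            )
        ]
    )

# Downstream code can now consume the cache hit with the same loop it uses
# for real streamed responses.
for chunk in replay_as_stream(cached):
    print(chunk.choices[0].delta.content, chunk.choices[0].finish_reason)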
@@ -2624,28 +2631,6 @@ def client(original_function):
                     cached_result, list
                 ):
                     print_verbose(f"Cache Hit!")
-                    call_type = original_function.__name__
-                    if call_type == CallTypes.acompletion.value and isinstance(
-                        cached_result, dict
-                    ):
-                        if kwargs.get("stream", False) == True:
-                            cached_result = convert_to_streaming_response_async(
-                                response_object=cached_result,
-                            )
-                        else:
-                            cached_result = convert_to_model_response_object(
-                                response_object=cached_result,
-                                model_response_object=ModelResponse(),
-                            )
-                    elif call_type == CallTypes.aembedding.value and isinstance(
-                        cached_result, dict
-                    ):
-                        cached_result = convert_to_model_response_object(
-                            response_object=cached_result,
-                            model_response_object=EmbeddingResponse(),
-                            response_type="embedding",
-                        )
-                    # LOG SUCCESS
                     cache_hit = True
                     end_time = datetime.datetime.now()
                     (
@@ -2685,6 +2670,35 @@ def client(original_function):
                         additional_args=None,
                         stream=kwargs.get("stream", False),
                     )
+                    call_type = original_function.__name__
+                    if call_type == CallTypes.acompletion.value and isinstance(
+                        cached_result, dict
+                    ):
+                        if kwargs.get("stream", False) == True:
+                            cached_result = convert_to_streaming_response_async(
+                                response_object=cached_result,
+                            )
+                            cached_result = CustomStreamWrapper(
+                                completion_stream=cached_result,
+                                model=model,
+                                custom_llm_provider="cached_response",
+                                logging_obj=logging_obj,
+                            )
+                        else:
+                            cached_result = convert_to_model_response_object(
+                                response_object=cached_result,
+                                model_response_object=ModelResponse(),
+                            )
+                    elif call_type == CallTypes.aembedding.value and isinstance(
+                        cached_result, dict
+                    ):
+                        cached_result = convert_to_model_response_object(
+                            response_object=cached_result,
+                            model_response_object=EmbeddingResponse(),
+                            response_type="embedding",
+                        )
+                    if kwargs.get("stream", False) == False:
+                        # LOG SUCCESS
                         asyncio.create_task(
                             logging_obj.async_success_handler(
                                 cached_result, start_time, end_time, cache_hit
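Worth noting in the hunk above: the conversion and wrapping of the cached result now happen after the cache-hit logging input is built, and the immediate async_success_handler task is scheduled only when stream is False. For streaming cache hits, the wrapper fires the success callback once the replayed stream is exhausted. A self-contained sketch of that callback ordering (cache_hit and replay are illustrative helpers, not litellm APIs):

import asyncio

async def replay(text):
    yield text  # single cached chunk

async def cache_hit(stream: bool, events: list):
    if stream:
        async def wrapped():
            async for chunk in replay("hi"):
                yield chunk
            events.append("success (from stream wrapper)")
        return wrapped()
    events.append("success (immediate)")
    return "hi"

async def main():
    events = []
    result = await cache_hit(stream=True, events=events)
    print(events)  # [] -- nothing logged yet for a streaming hit
    async for _ in result:
        pass
    print(events)  # ['success (from stream wrapper)']

asyncio.run(main())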
@@ -4296,7 +4310,9 @@ def get_optional_params(
                    parameters=tool["function"].get("parameters", {}),
                )
                gtool_func_declarations.append(gtool_func_declaration)
-            optional_params["tools"] = [generative_models.Tool(function_declarations=gtool_func_declarations)]
+            optional_params["tools"] = [
+                generative_models.Tool(function_declarations=gtool_func_declarations)
+            ]
        elif custom_llm_provider == "sagemaker":
            ## check if unsupported param passed in
            supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
@@ -8524,6 +8540,19 @@ class CustomStreamWrapper:
                    ]
            elif self.custom_llm_provider == "text-completion-openai":
                response_obj = self.handle_openai_text_completion_chunk(chunk)
+                completion_obj["content"] = response_obj["text"]
+                print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj["is_finished"]:
+                    model_response.choices[0].finish_reason = response_obj[
+                        "finish_reason"
+                    ]
+            elif self.custom_llm_provider == "cached_response":
+                response_obj = {
+                    "text": chunk.choices[0].delta.content,
+                    "is_finished": True,
+                    "finish_reason": chunk.choices[0].finish_reason,
+                }
                completion_obj["content"] = response_obj["text"]
                print_verbose(f"completion obj content: {completion_obj['content']}")
                if response_obj["is_finished"]:
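The new "cached_response" branch treats each replayed chunk as already being in OpenAI delta format, so it only lifts out the text and the finish reason. A standalone illustration of that mapping, using SimpleNamespace to stand in for the real chunk object:

from types import SimpleNamespace

# Shape of a chunk replayed from cache: delta content plus a finish reason.
chunk = SimpleNamespace(
    choices=[
        SimpleNamespace(
            delta=SimpleNamespace(content="Hello from cache"),
            finish_reason="stop",
        )
    ]
)

# Mirror of the mapping done in the "cached_response" branch above.
response_obj = {
    "text": chunk.choices[0].delta.content,
    "is_finished": True,
    "finish_reason": chunk.choices[0].finish_reason,
}
print(response_obj)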
@@ -8732,6 +8761,7 @@ class CustomStreamWrapper:
                or self.custom_llm_provider == "vertex_ai"
                or self.custom_llm_provider == "sagemaker"
                or self.custom_llm_provider == "gemini"
+                or self.custom_llm_provider == "cached_response"
                or self.custom_llm_provider in litellm.openai_compatible_endpoints
            ):
                async for chunk in self.completion_stream: