fix(utils.py): support returning cached streaming responses for function-calling streaming calls

This commit is contained in:
Krrish Dholakia 2024-02-26 12:31:00 -08:00
parent 9439ec0a61
commit dcca55159b
2 changed files with 134 additions and 80 deletions

View file

@ -1907,6 +1907,8 @@ def test_azure_streaming_and_function_calling():
@pytest.mark.asyncio
async def test_azure_astreaming_and_function_calling():
import uuid
tools = [
{
"type": "function",
@ -1927,7 +1929,20 @@ async def test_azure_astreaming_and_function_calling():
},
}
]
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
messages = [
{
"role": "user",
"content": f"What is the weather like in Boston? {uuid.uuid4()}",
}
]
from litellm.caching import Cache
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
try:
response = await litellm.acompletion(
model="azure/gpt-4-nov-release",
@ -1938,6 +1953,7 @@ async def test_azure_astreaming_and_function_calling():
api_base=os.getenv("AZURE_FRANCE_API_BASE"),
api_key=os.getenv("AZURE_FRANCE_API_KEY"),
api_version="2024-02-15-preview",
caching=True,
)
# Add any assertions here to check the response
idx = 0
@ -1957,6 +1973,36 @@ async def test_azure_astreaming_and_function_calling():
validate_final_streaming_function_calling_chunk(chunk=chunk)
idx += 1
## CACHING TEST
print("\n\nCACHING TESTS\n\n")
response = await litellm.acompletion(
model="azure/gpt-4-nov-release",
tools=tools,
tool_choice="auto",
messages=messages,
stream=True,
api_base=os.getenv("AZURE_FRANCE_API_BASE"),
api_key=os.getenv("AZURE_FRANCE_API_KEY"),
api_version="2024-02-15-preview",
caching=True,
)
# Add any assertions here to check the response
idx = 0
async for chunk in response:
print(f"chunk: {chunk}")
if idx == 0:
assert (
chunk.choices[0].delta.tool_calls[0].function.arguments is not None
)
assert isinstance(
chunk.choices[0].delta.tool_calls[0].function.arguments, str
)
validate_first_streaming_function_calling_chunk(chunk=chunk)
elif idx == 1:
validate_second_streaming_function_calling_chunk(chunk=chunk)
elif chunk.choices[0].finish_reason is not None: # last chunk
validate_final_streaming_function_calling_chunk(chunk=chunk)
idx += 1
except Exception as e:
pytest.fail(f"Error occurred: {e}")
raise e