fix(vertex_httpx.py): support tool calling w/ streaming for vertex ai + gemini

Krrish Dholakia 2024-07-06 14:02:25 -07:00
parent 2452753e08
commit faa88a1ab1
6 changed files with 111 additions and 16 deletions
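The interesting part of this change is translating Gemini's streaming functionCall parts into OpenAI-style tool_calls deltas, so downstream consumers see one uniform chunk shape. Below is a minimal sketch of that translation, assuming the documented Vertex AI candidate/parts chunk format; the helper name _gemini_part_to_tool_call_delta and the id generation are illustrative, not the actual vertex_httpx.py implementation:

    import json
    import uuid
    from typing import Optional


    def _gemini_part_to_tool_call_delta(part: dict, index: int) -> Optional[dict]:
        """Map a Gemini 'functionCall' part to an OpenAI-style tool_calls delta entry.

        Illustrative helper only -- not the actual vertex_httpx.py code.
        """
        function_call = part.get("functionCall")
        if function_call is None:
            return None
        return {
            "index": index,
            "id": f"call_{uuid.uuid4().hex[:8]}",  # hypothetical id scheme
            "type": "function",
            "function": {
                "name": function_call.get("name"),
                # Gemini streams args as a JSON object; OpenAI deltas carry a JSON string.
                "arguments": json.dumps(function_call.get("args", {})),
            },
        }


    # Example Gemini streaming chunk carrying a function call:
    gemini_chunk = {
        "candidates": [
            {
                "content": {
                    "role": "model",
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_current_weather",
                                "args": {"location": "Boston"},
                            }
                        }
                    ],
                },
                "finishReason": "STOP",
            }
        ]
    }

    parts = gemini_chunk["candidates"][0]["content"]["parts"]
    tool_calls = [
        d for i, p in enumerate(parts) if (d := _gemini_part_to_tool_call_delta(p, i))
    ]
    # When tool_calls is non-empty, the stream's final finish_reason should be
    # reported as "tool_calls" rather than "stop".
    assert tool_calls[0]["function"]["name"] == "get_current_weather"

The key detail is the serialization step: Gemini sends arguments as a structured JSON object, while the OpenAI delta format carries them as a JSON string.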


@@ -12,6 +12,9 @@ from typing import Tuple
 import pytest
 from pydantic import BaseModel

+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
+
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
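The two litellm_core_utils imports added here support the new unit test at the end of this file (see the final hunk below), which constructs a litellm_logging.Logging object directly when building the CustomStreamWrapper.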
@@ -3034,8 +3037,11 @@ def test_completion_claude_3_function_call_with_streaming():
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.parametrize(
+    "model", ["gemini/gemini-1.5-flash"]
+)  # "claude-3-opus-20240229",
 @pytest.mark.asyncio
-async def test_acompletion_claude_3_function_call_with_streaming():
+async def test_acompletion_claude_3_function_call_with_streaming(model):
     litellm.set_verbose = True
     tools = [
         {
@@ -3066,7 +3072,7 @@ async def test_acompletion_claude_3_function_call_with_streaming():
     try:
         # test without max tokens
         response = await acompletion(
-            model="claude-3-opus-20240229",
+            model=model,
             messages=messages,
             tools=tools,
             tool_choice="required",
@@ -3453,3 +3459,55 @@ def test_aamazing_unit_test_custom_stream_wrapper_n():
         assert (
             chunk_dict == chunks[idx]
         ), f"idx={idx} translated chunk = {chunk_dict} != openai chunk = {chunks[idx]}"
+
+
+def test_unit_test_custom_stream_wrapper_function_call():
+    """
+    Test that, if the model returns a tool call, the finish reason is correctly set to 'tool_calls'.
+    """
+    from litellm.types.llms.openai import ChatCompletionDeltaChunk
+
+    litellm.set_verbose = False
+    delta: ChatCompletionDeltaChunk = {
+        "content": None,
+        "role": "assistant",
+        "tool_calls": [
+            {
+                "function": {"arguments": '"}'},
+                "type": "function",
+                "index": 0,
+            }
+        ],
+    }
+    chunk = {
+        "id": "chatcmpl-123",
+        "object": "chat.completion.chunk",
+        "created": 1694268190,
+        "model": "gpt-3.5-turbo-0125",
+        "system_fingerprint": "fp_44709d6fcb",
+        "choices": [{"index": 0, "delta": delta, "finish_reason": "stop"}],
+    }
+    chunk = litellm.ModelResponse(**chunk, stream=True)
+    completion_stream = ModelResponseIterator(model_response=chunk)
+
+    response = litellm.CustomStreamWrapper(
+        completion_stream=completion_stream,
+        model="gpt-3.5-turbo",
+        custom_llm_provider="cached_response",
+        logging_obj=litellm.litellm_core_utils.litellm_logging.Logging(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey"}],
+            stream=True,
+            call_type="completion",
+            start_time=time.time(),
+            litellm_call_id="12345",
+            function_id="1245",
+        ),
+    )
+
+    finish_reason: Optional[str] = None
+    for chunk in response:
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason = chunk.choices[0].finish_reason
+
+    assert finish_reason == "tool_calls"
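What this test pins down: the incoming chunk arrives with finish_reason "stop", but because its delta carries tool_calls, the wrapper is expected to report "tool_calls" as the final finish reason. A minimal sketch of that remapping logic, assuming a simple state flag; this is illustrative, not CustomStreamWrapper's actual implementation:

    from typing import Optional


    class FinishReasonFixer:
        """Illustrative only: rewrite 'stop' to 'tool_calls' once tool calls were streamed."""

        def __init__(self) -> None:
            self.sent_tool_call = False

        def process(
            self, delta_tool_calls: Optional[list], finish_reason: Optional[str]
        ) -> Optional[str]:
            if delta_tool_calls:
                # Remember that this stream emitted at least one tool-call delta.
                self.sent_tool_call = True
            if finish_reason == "stop" and self.sent_tool_call:
                return "tool_calls"
            return finish_reason


    fixer = FinishReasonFixer()
    assert fixer.process([{"index": 0}], None) is None  # tool-call delta seen, no finish yet
    assert fixer.process(None, "stop") == "tool_calls"  # final chunk gets remapped

Keeping this as a one-way flag means a single tool-call delta anywhere in the stream is enough to flip the final finish reason, which matches what the assertion above demands.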