Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)
fix(vertex_httpx.py): support tool calling w/ streaming for vertex ai + gemini
This commit is contained in:
parent 2452753e08
commit faa88a1ab1
6 changed files with 111 additions and 16 deletions
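For context, a minimal sketch of the behavior this commit targets: streaming a tool-calling completion from Gemini through litellm and reading tool-call deltas off each chunk. The tool schema and prompt below are illustrative and not taken from the commit; only the model name matches the one the updated test parametrizes over.

import litellm

# Hypothetical OpenAI-style tool schema for this sketch.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

# Stream a completion from Gemini; chunks follow the OpenAI streaming shape.
response = litellm.completion(
    model="gemini/gemini-1.5-flash",
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=tools,
    stream=True,
)

for chunk in response:
    choice = chunk.choices[0]
    # Tool-call argument fragments arrive incrementally in delta.tool_calls.
    if choice.delta is not None and choice.delta.tool_calls:
        print(choice.delta.tool_calls[0].function.arguments or "", end="")
    if choice.finish_reason is not None:
        print("\nfinish_reason:", choice.finish_reason)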
@@ -12,6 +12,9 @@ from typing import Tuple
 import pytest
 from pydantic import BaseModel
 
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
+
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
@@ -3034,8 +3037,11 @@ def test_completion_claude_3_function_call_with_streaming():
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.mark.parametrize(
+    "model", ["gemini/gemini-1.5-flash"]
+)  # "claude-3-opus-20240229",
 @pytest.mark.asyncio
-async def test_acompletion_claude_3_function_call_with_streaming():
+async def test_acompletion_claude_3_function_call_with_streaming(model):
     litellm.set_verbose = True
     tools = [
         {
@@ -3066,7 +3072,7 @@ async def test_acompletion_claude_3_function_call_with_streaming():
     try:
         # test without max tokens
         response = await acompletion(
-            model="claude-3-opus-20240229",
+            model=model,
             messages=messages,
             tools=tools,
             tool_choice="required",
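The parametrized test above streams a required tool call from Gemini. A common way to assert on the complete call afterwards is to collect the chunks and rebuild a full response. A rough sketch of that pattern, assuming litellm.stream_chunk_builder behaves as documented; the tool schema and prompt are placeholders, not the test's own:

import asyncio

import litellm

# Hypothetical tool schema for this sketch; the test defines its own tools.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
            },
        },
    }
]


async def main():
    messages = [{"role": "user", "content": "What's the weather in Boston today?"}]
    response = await litellm.acompletion(
        model="gemini/gemini-1.5-flash",
        messages=messages,
        tools=tools,
        tool_choice="required",
        stream=True,
    )

    # Collect the streamed chunks, then rebuild a non-streaming ModelResponse
    # so the reassembled tool call can be inspected in one place.
    chunks = [chunk async for chunk in response]
    rebuilt = litellm.stream_chunk_builder(chunks, messages=messages)
    tool_call = rebuilt.choices[0].message.tool_calls[0]
    print(tool_call.function.name, tool_call.function.arguments)


asyncio.run(main())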
@@ -3453,3 +3459,55 @@ def test_aamazing_unit_test_custom_stream_wrapper_n():
         assert (
             chunk_dict == chunks[idx]
         ), f"idx={idx} translated chunk = {chunk_dict} != openai chunk = {chunks[idx]}"
+
+
+def test_unit_test_custom_stream_wrapper_function_call():
+    """
+    Test if model returns a tool call, the finish reason is correctly set to 'tool_calls'
+    """
+    from litellm.types.llms.openai import ChatCompletionDeltaChunk
+
+    litellm.set_verbose = False
+    delta: ChatCompletionDeltaChunk = {
+        "content": None,
+        "role": "assistant",
+        "tool_calls": [
+            {
+                "function": {"arguments": '"}'},
+                "type": "function",
+                "index": 0,
+            }
+        ],
+    }
+    chunk = {
+        "id": "chatcmpl-123",
+        "object": "chat.completion.chunk",
+        "created": 1694268190,
+        "model": "gpt-3.5-turbo-0125",
+        "system_fingerprint": "fp_44709d6fcb",
+        "choices": [{"index": 0, "delta": delta, "finish_reason": "stop"}],
+    }
+    chunk = litellm.ModelResponse(**chunk, stream=True)
+
+    completion_stream = ModelResponseIterator(model_response=chunk)
+
+    response = litellm.CustomStreamWrapper(
+        completion_stream=completion_stream,
+        model="gpt-3.5-turbo",
+        custom_llm_provider="cached_response",
+        logging_obj=litellm.litellm_core_utils.litellm_logging.Logging(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey"}],
+            stream=True,
+            call_type="completion",
+            start_time=time.time(),
+            litellm_call_id="12345",
+            function_id="1245",
+        ),
+    )
+
+    finish_reason: Optional[str] = None
+    for chunk in response:
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason = chunk.choices[0].finish_reason
+    assert finish_reason == "tool_calls"
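Downstream, the point of remapping finish_reason is that a streaming consumer can buffer tool-call argument fragments and only act once the wrapper signals "tool_calls". A hedged sketch of that consumer pattern follows; it is not part of the commit, and the helper name and JSON-parsing step are assumptions for illustration:

import json
from collections import defaultdict


def collect_tool_calls(stream):
    """Buffer streamed tool-call fragments; return them only if the stream ends with finish_reason == 'tool_calls'."""
    args_by_index = defaultdict(str)
    names_by_index = {}
    finish_reason = None
    for chunk in stream:
        choice = chunk.choices[0]
        delta = choice.delta
        # Argument fragments arrive incrementally, keyed by tool-call index.
        if delta is not None and delta.tool_calls:
            for tc in delta.tool_calls:
                if tc.function.name:
                    names_by_index[tc.index] = tc.function.name
                args_by_index[tc.index] += tc.function.arguments or ""
        if choice.finish_reason is not None:
            finish_reason = choice.finish_reason
    if finish_reason != "tool_calls":
        return []
    # Parsing assumes the full arguments string was streamed before the stream ended.
    return [
        {"name": names_by_index.get(i), "arguments": json.loads(args or "{}")}
        for i, args in sorted(args_by_index.items())
    ]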