diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index ce9e6286f..c58541a7d 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -638,73 +638,66 @@ async def test_gemini_pro_function_calling(sync_mode):
 # gemini_pro_function_calling()
 
 
-@pytest.mark.parametrize("stream", [False, True])
 @pytest.mark.parametrize("sync_mode", [False, True])
 @pytest.mark.asyncio
-async def test_gemini_pro_function_calling_streaming(stream, sync_mode):
+async def test_gemini_pro_function_calling_streaming(sync_mode):
     load_vertex_ai_credentials()
     litellm.set_verbose = True
-    tools = [
-        {
-            "type": "function",
-            "function": {
-                "name": "get_current_weather",
-                "description": "Get the current weather in a given location",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "location": {
-                            "type": "string",
-                            "description": "The city and state, e.g. San Francisco, CA",
+    data = {
+        "model": "vertex_ai/gemini-pro",
+        "messages": [
+            {
+                "role": "user",
+                "content": "Call the submit_cities function with San Francisco and New York",
+            }
+        ],
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "submit_cities",
+                    "description": "Submits a list of cities",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "cities": {"type": "array", "items": {"type": "string"}}
                         },
-                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                        "required": ["cities"],
                     },
-                    "required": ["location"],
                 },
-            },
-        }
-    ]
-    messages = [
-        {
-            "role": "user",
-            "content": "What's the weather like in Boston today in fahrenheit?",
-        }
-    ]
-    optional_params = {
-        "tools": tools,
+            }
+        ],
         "tool_choice": "auto",
         "n": 1,
-        "stream": stream,
+        "stream": True,
         "temperature": 0.1,
     }
+    chunks = []
     try:
         if sync_mode == True:
-            response = litellm.completion(
-                model="gemini-pro", messages=messages, **optional_params
-            )
+            response = litellm.completion(**data)
             print(f"completion: {response}")
 
-            if stream == True:
-                # assert completion.choices[0].message.content is None
-                # assert len(completion.choices[0].message.tool_calls) == 1
-                for chunk in response:
-                    assert isinstance(chunk, litellm.ModelResponse)
-            else:
-                assert isinstance(response, litellm.ModelResponse)
+            for chunk in response:
+                chunks.append(chunk)
+                assert isinstance(chunk, litellm.ModelResponse)
         else:
-            response = await litellm.acompletion(
-                model="gemini-pro", messages=messages, **optional_params
-            )
+            response = await litellm.acompletion(**data)
             print(f"completion: {response}")
 
-            if stream == True:
-                # assert completion.choices[0].message.content is None
-                # assert len(completion.choices[0].message.tool_calls) == 1
-                async for chunk in response:
-                    print(f"chunk: {chunk}")
-                    assert isinstance(chunk, litellm.ModelResponse)
-            else:
-                assert isinstance(response, litellm.ModelResponse)
+            assert isinstance(response, litellm.CustomStreamWrapper)
+
+            async for chunk in response:
+                print(f"chunk: {chunk}")
+                chunks.append(chunk)
+                assert isinstance(chunk, litellm.ModelResponse)
+
+        complete_response = litellm.stream_chunk_builder(chunks=chunks)
+        assert (
+            complete_response.choices[0].message.content is not None
+            or len(complete_response.choices[0].message.tool_calls) > 0
+        )
+        print(f"complete_response: {complete_response}")
     except litellm.APIError as e:
         pass
     except litellm.RateLimitError as e:
diff --git a/litellm/utils.py b/litellm/utils.py
index f77baf8bd..cd21b390b 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -10761,6 +10761,8 @@ class CustomStreamWrapper:
                 else:
                     completion_obj["content"] = str(chunk)
             elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
+                import proto  # type: ignore
+
                 if self.model.startswith("claude-3"):
                     response_obj = self.handle_vertexai_anthropic_chunk(chunk=chunk)
                     if response_obj is None:
@@ -10798,10 +10800,24 @@ class CustomStreamWrapper:
                                 function_call = (
                                     chunk.candidates[0].content.parts[0].function_call
                                 )
+
                                 args_dict = {}
-                                for k, v in function_call.args.items():
-                                    args_dict[k] = v
-                                args_str = json.dumps(args_dict)
+
+                                # Check if it's a RepeatedComposite instance
+                                for key, val in function_call.args.items():
+                                    if isinstance(
+                                        val,
+                                        proto.marshal.collections.repeated.RepeatedComposite,
+                                    ):
+                                        # If so, convert to list
+                                        args_dict[key] = [v for v in val]
+                                    else:
+                                        args_dict[key] = val
+
+                                try:
+                                    args_str = json.dumps(args_dict)
+                                except Exception as e:
+                                    raise e
                                 _delta_obj = litellm.utils.Delta(
                                     content=None,
                                     tool_calls=[