diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 67db9b61ca..eb197cc01c 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -416,6 +416,46 @@ def test_gemini_pro_function_calling():
 # gemini_pro_function_calling()
 
 
+def test_gemini_pro_function_calling_streaming():
+    load_vertex_ai_credentials()
+    litellm.set_verbose = True
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
+    completion = litellm.completion(
+        model="gemini-pro",
+        messages=messages,
+        tools=tools,
+        tool_choice="auto",
+        stream=True,
+    )
+    print(f"completion: {completion}")
+    # assert completion.choices[0].message.content is None
+    # assert len(completion.choices[0].message.tool_calls) == 1
+    for chunk in completion:
+        print(f"chunk: {chunk}")
+
+    raise Exception("it worked!")
+
+
 @pytest.mark.asyncio
 async def test_gemini_pro_async_function_calling():
     load_vertex_ai_credentials()
diff --git a/litellm/utils.py b/litellm/utils.py
index d36ba4e1a4..e0b1f42172 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -9193,7 +9193,41 @@ class CustomStreamWrapper:
             try:
                 if hasattr(chunk, "candidates") == True:
                     try:
-                        completion_obj["content"] = chunk.text
+                        try:
+                            completion_obj["content"] = chunk.text
+                        except Exception as e:
+                            if "Part has no text." in str(e):
+                                ## check for function calling
+                                function_call = (
+                                    chunk.candidates[0]
+                                    .content.parts[0]
+                                    .function_call
+                                )
+                                args_dict = {}
+                                for k, v in function_call.args.items():
+                                    args_dict[k] = v
+                                args_str = json.dumps(args_dict)
+                                _delta_obj = litellm.utils.Delta(
+                                    content=None,
+                                    tool_calls=[
+                                        {
+                                            "id": f"call_{str(uuid.uuid4())}",
+                                            "function": {
+                                                "arguments": args_str,
+                                                "name": function_call.name,
+                                            },
+                                            "type": "function",
+                                        }
+                                    ],
+                                )
+                                _streaming_response = StreamingChoices(
+                                    delta=_delta_obj
+                                )
+                                _model_response = ModelResponse(stream=True)
+                                _model_response.choices = [_streaming_response]
+                                response_obj = {"original_chunk": _model_response}
+                            else:
+                                raise e
                         if (
                             hasattr(chunk.candidates[0], "finish_reason")
                             and chunk.candidates[0].finish_reason.name
@@ -9204,7 +9238,7 @@ class CustomStreamWrapper:
                                         chunk.candidates[0].finish_reason.name
                                     )
                                 )
-                        except:
+                        except Exception as e:
                             if chunk.candidates[0].finish_reason.name == "SAFETY":
                                 raise Exception(
                                     f"The response was blocked by VertexAI. {str(chunk)}"
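
Usage note (not part of the patch): a minimal sketch of how a caller might consume the streamed tool-call deltas that the change above emits. The tool schema and prompt mirror the new test; the accumulation loop is an illustrative assumption rather than a litellm helper, and the attribute access assumes the OpenAI-style delta shape that the Delta object built in the diff is meant to match (on some versions the entries may surface as plain dicts).

import json

import litellm

# Same tool schema as the new test above (trimmed for brevity).
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

# Requires Vertex AI credentials, e.g. via load_vertex_ai_credentials() as in the test.
response = litellm.completion(
    model="gemini-pro",
    messages=[{"role": "user", "content": "What's the weather like in Boston today?"}],
    tools=tools,
    tool_choice="auto",
    stream=True,
)

# Assemble the function call from the per-chunk deltas (illustrative accumulation).
tool_name, tool_args = None, ""
for chunk in response:
    delta = chunk.choices[0].delta
    if delta and delta.tool_calls:
        call = delta.tool_calls[0]
        if call.function.name:
            tool_name = call.function.name
        tool_args += call.function.arguments or ""

if tool_name:
    print(tool_name, json.loads(tool_args))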