diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index 603bd3c22..fdbc1625e 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -559,8 +559,7 @@ def completion(
                 f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
             )
             response = llm_model.predict(
-                endpoint=endpoint_path,
-                instances=instances
+                endpoint=endpoint_path, instances=instances
             ).predictions
 
             completion_response = response[0]
@@ -585,12 +584,8 @@ def completion(
                     "request_str": request_str,
                 },
             )
-            request_str += (
-                f"llm_model.predict(instances={instances})\n"
-            )
-            response = llm_model.predict(
-                instances=instances
-            ).predictions
+            request_str += f"llm_model.predict(instances={instances})\n"
+            response = llm_model.predict(instances=instances).predictions
 
             completion_response = response[0]
             if (
@@ -614,7 +609,6 @@ def completion(
             model_response["choices"][0]["message"]["content"] = str(
                 completion_response
             )
-        model_response["choices"][0]["message"]["content"] = str(completion_response)
         model_response["created"] = int(time.time())
         model_response["model"] = model
         ## CALCULATING USAGE
@@ -766,6 +760,7 @@ async def async_completion(
             Vertex AI Model Garden
             """
             from google.cloud import aiplatform
+
             ## LOGGING
             logging_obj.pre_call(
                 input=prompt,
@@ -797,11 +792,9 @@ async def async_completion(
                 and "\nOutput:\n" in completion_response
             ):
                 completion_response = completion_response.split("\nOutput:\n", 1)[1]
-
+
         elif mode == "private":
-            request_str += (
-                f"llm_model.predict_async(instances={instances})\n"
-            )
+            request_str += f"llm_model.predict_async(instances={instances})\n"
             response_obj = await llm_model.predict_async(
                 instances=instances,
             )
@@ -826,7 +819,6 @@ async def async_completion(
             model_response["choices"][0]["message"]["content"] = str(
                 completion_response
             )
-        model_response["choices"][0]["message"]["content"] = str(completion_response)
         model_response["created"] = int(time.time())
         model_response["model"] = model
         ## CALCULATING USAGE
@@ -954,6 +946,7 @@ async def async_streaming(
         response = llm_model.predict_streaming_async(prompt, **optional_params)
     elif mode == "custom":
         from google.cloud import aiplatform
+
         stream = optional_params.pop("stream", None)
 
         ## LOGGING
@@ -972,7 +965,9 @@ async def async_streaming(
         endpoint_path = llm_model.endpoint_path(
             project=vertex_project, location=vertex_location, endpoint=model
         )
-        request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+        request_str += (
+            f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+        )
         response_obj = await llm_model.predict(
             endpoint=endpoint_path,
             instances=instances,
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 9b7473ea2..76ebde7ae 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -318,7 +318,7 @@ def test_gemini_pro_vision():
 # test_gemini_pro_vision()
 
 
-def gemini_pro_function_calling():
+def test_gemini_pro_function_calling():
     load_vertex_ai_credentials()
     tools = [
         {
@@ -345,12 +345,15 @@ def gemini_pro_function_calling():
         model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
     )
     print(f"completion: {completion}")
+    assert completion.choices[0].message.content is None
+    assert len(completion.choices[0].message.tool_calls) == 1
 
 
 # gemini_pro_function_calling()
 
 
-async def gemini_pro_async_function_calling():
+@pytest.mark.asyncio
+async def test_gemini_pro_async_function_calling():
     load_vertex_ai_credentials()
     tools = [
         {
@@ -377,6 +380,9 @@ async def gemini_pro_async_function_calling():
         model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
     )
     print(f"completion: {completion}")
+    assert completion.choices[0].message.content is None
+    assert len(completion.choices[0].message.tool_calls) == 1
+    # raise Exception("it worked!")
 
 
 # asyncio.run(gemini_pro_async_function_calling())
diff --git a/litellm/utils.py b/litellm/utils.py
index 4260ee6e1..21677890e 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4274,8 +4274,8 @@ def get_optional_params(
             optional_params["stop_sequences"] = stop
         if max_tokens is not None:
             optional_params["max_output_tokens"] = max_tokens
-    elif custom_llm_provider == "vertex_ai" and model in (
-        litellm.vertex_chat_models
+    elif custom_llm_provider == "vertex_ai" and (
+        model in litellm.vertex_chat_models
         or model in litellm.vertex_code_chat_models
         or model in litellm.vertex_text_models
         or model in litellm.vertex_code_text_models