diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index f98698df4..b52e8689f 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -336,7 +336,7 @@ def _process_gemini_image(image_url: str) -> PartType:
         raise e
 
 
-def _gemini_convert_messages_text(messages: list) -> List[ContentType]:
+def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
     """
     Converts given messages from OpenAI format to Gemini format
 
@@ -680,6 +680,7 @@ def completion(
                 "model_response": model_response,
                 "encoding": encoding,
                 "messages": messages,
+                "request_str": request_str,
                 "print_verbose": print_verbose,
                 "client_options": client_options,
                 "instances": instances,
@@ -698,7 +699,7 @@ def completion(
             print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
             print_verbose(f"\nProcessing input messages = {messages}")
             tools = optional_params.pop("tools", None)
-            content = _gemini_convert_messages_text(messages=messages)
+            content = _gemini_convert_messages_with_history(messages=messages)
             stream = optional_params.pop("stream", False)
             if stream == True:
                 request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
@@ -972,12 +973,12 @@ async def async_completion(
     mode: str,
     prompt: str,
     model: str,
+    messages: list,
     model_response: ModelResponse,
-    logging_obj=None,
-    request_str=None,
+    request_str: str,
+    print_verbose: Callable,
+    logging_obj,
     encoding=None,
-    messages=None,
-    print_verbose=None,
     client_options=None,
     instances=None,
     vertex_project=None,
@@ -997,8 +998,7 @@ async def async_completion(
             tools = optional_params.pop("tools", None)
             stream = optional_params.pop("stream", False)
 
-            prompt, images = _gemini_vision_convert_messages(messages=messages)
-            content = [prompt] + images
+            content = _gemini_convert_messages_with_history(messages=messages)
 
             request_str += f"response = llm_model.generate_content({content})\n"
             ## LOGGING
@@ -1198,11 +1198,11 @@ async def async_streaming(
     prompt: str,
     model: str,
     model_response: ModelResponse,
-    logging_obj=None,
-    request_str=None,
+    messages: list,
+    print_verbose: Callable,
+    logging_obj,
+    request_str: str,
     encoding=None,
-    messages=None,
-    print_verbose=None,
     client_options=None,
     instances=None,
     vertex_project=None,
@@ -1219,8 +1219,8 @@ async def async_streaming(
         print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
         print_verbose(f"\nProcessing input messages = {messages}")
 
-        prompt, images = _gemini_vision_convert_messages(messages=messages)
-        content = [prompt] + images
+        content = _gemini_convert_messages_with_history(messages=messages)
+
         request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
         logging_obj.pre_call(
             input=prompt,
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index d8d2c727d..d7b2dc2d4 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -16,7 +16,7 @@ from litellm.tests.test_streaming import streaming_format_tests
 import json
 import os
 import tempfile
-from litellm.llms.vertex_ai import _gemini_convert_messages_text
+from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
 
 litellm.num_retries = 3
 litellm.cache = None
@@ -590,7 +590,7 @@ def test_gemini_pro_vision_base64():
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
-@pytest.mark.parametrize("sync_mode", [True])
+@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
 async def test_gemini_pro_function_calling(sync_mode):
     try:
@@ -972,6 +972,6 @@ def test_prompt_factory():
         # Now the assistant can reply with the result of the tool call.
     ]
 
-    translated_messages = _gemini_convert_messages_text(messages=messages)
+    translated_messages = _gemini_convert_messages_with_history(messages=messages)
 
     print(f"\n\ntranslated_messages: {translated_messages}\ntranslated_messages")
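
Reviewer note (not part of the diff): a minimal usage sketch of the renamed `_gemini_convert_messages_with_history` helper as the updated call sites invoke it. The output shape shown in the comments is an assumption based on litellm's `ContentType`/`PartType` typed dicts, not something this diff asserts; the message values are made up for illustration.

# Illustrative sketch only; assumes litellm is installed and importable.
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history

messages = [
    {"role": "user", "content": "What is the weather in Boston?"},
    {"role": "assistant", "content": "Let me check that for you."},
    {"role": "user", "content": "Thanks!"},
]

content = _gemini_convert_messages_with_history(messages=messages)

# Expected shape (assumption): one ContentType entry per turn, with the
# OpenAI "assistant" role mapped to Gemini's "model" role, e.g.
# [
#     {"role": "user", "parts": [{"text": "What is the weather in Boston?"}]},
#     {"role": "model", "parts": [{"text": "Let me check that for you."}]},
#     {"role": "user", "parts": [{"text": "Thanks!"}]},
# ]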