fix(vertex_ai.py): use chat_messages_with_history for async + streaming calls

Krrish Dholakia 2024-05-19 12:30:24 -07:00
parent f9ab72841a
commit 65aacc0c1a
2 changed files with 17 additions and 17 deletions

litellm/llms/vertex_ai.py

@@ -336,7 +336,7 @@ def _process_gemini_image(image_url: str) -> PartType:
         raise e
-def _gemini_convert_messages_text(messages: list) -> List[ContentType]:
+def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
     """
     Converts given messages from OpenAI format to Gemini format
@@ -680,6 +680,7 @@ def completion(
                 "model_response": model_response,
                 "encoding": encoding,
                 "messages": messages,
+                "request_str": request_str,
                 "print_verbose": print_verbose,
                 "client_options": client_options,
                 "instances": instances,
@@ -698,7 +699,7 @@ def completion(
         print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
         print_verbose(f"\nProcessing input messages = {messages}")
         tools = optional_params.pop("tools", None)
-        content = _gemini_convert_messages_text(messages=messages)
+        content = _gemini_convert_messages_with_history(messages=messages)
         stream = optional_params.pop("stream", False)
         if stream == True:
             request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
@@ -972,12 +973,12 @@ async def async_completion(
     mode: str,
     prompt: str,
     model: str,
+    messages: list,
     model_response: ModelResponse,
-    logging_obj=None,
-    request_str=None,
+    request_str: str,
+    print_verbose: Callable,
+    logging_obj,
     encoding=None,
-    messages=None,
-    print_verbose=None,
     client_options=None,
     instances=None,
     vertex_project=None,
@@ -997,8 +998,7 @@ async def async_completion(
         tools = optional_params.pop("tools", None)
         stream = optional_params.pop("stream", False)
-        prompt, images = _gemini_vision_convert_messages(messages=messages)
-        content = [prompt] + images
+        content = _gemini_convert_messages_with_history(messages=messages)
         request_str += f"response = llm_model.generate_content({content})\n"
         ## LOGGING
@@ -1198,11 +1198,11 @@ async def async_streaming(
     prompt: str,
     model: str,
     model_response: ModelResponse,
-    logging_obj=None,
-    request_str=None,
+    messages: list,
+    print_verbose: Callable,
+    logging_obj,
+    request_str: str,
     encoding=None,
-    messages=None,
-    print_verbose=None,
     client_options=None,
     instances=None,
     vertex_project=None,
@@ -1219,8 +1219,8 @@
         print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
         print_verbose(f"\nProcessing input messages = {messages}")
-        prompt, images = _gemini_vision_convert_messages(messages=messages)
-        content = [prompt] + images
+        content = _gemini_convert_messages_with_history(messages=messages)
         request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
         logging_obj.pre_call(
             input=prompt,
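
For context, here is a minimal sketch of the kind of conversion the renamed _gemini_convert_messages_with_history helper is expected to perform (an illustration, not litellm's implementation): OpenAI-format chat history is mapped onto Gemini-style content entries, with "assistant" turns becoming Gemini's "model" role, so the async and streaming paths can hand the whole conversation to generate_content instead of only the final prompt and images.

from typing import Dict, List

def convert_openai_history_to_gemini(messages: List[Dict]) -> List[Dict]:
    # Illustrative sketch only: map each OpenAI-format message onto a
    # Gemini-style {"role": ..., "parts": [...]} entry, keeping every turn.
    contents = []
    for msg in messages:
        role = "model" if msg["role"] == "assistant" else "user"
        contents.append({"role": role, "parts": [{"text": msg["content"]}]})
    return contents

history = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
    {"role": "user", "content": "How large is its population?"},
]
print(convert_openai_history_to_gemini(history))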

litellm/tests/… (Vertex AI tests)

@@ -16,7 +16,7 @@ from litellm.tests.test_streaming import streaming_format_tests
 import json
 import os
 import tempfile
-from litellm.llms.vertex_ai import _gemini_convert_messages_text
+from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
 litellm.num_retries = 3
 litellm.cache = None
@@ -590,7 +590,7 @@ def test_gemini_pro_vision_base64():
         pytest.fail(f"An exception occurred - {str(e)}")
-@pytest.mark.parametrize("sync_mode", [True])
+@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
 async def test_gemini_pro_function_calling(sync_mode):
     try:
@@ -972,6 +972,6 @@ def test_prompt_factory():
         # Now the assistant can reply with the result of the tool call.
     ]
-    translated_messages = _gemini_convert_messages_text(messages=messages)
+    translated_messages = _gemini_convert_messages_with_history(messages=messages)
     print(f"\n\ntranslated_messages: {translated_messages}\ntranslated_messages")
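
A minimal usage sketch of what the fix enables (the model name is an assumption and Vertex AI credentials must already be configured): async + streaming Gemini calls now carry the full converted chat history rather than only the last turn.

import asyncio
import litellm

async def main():
    # Multi-turn history; with this commit the async streaming path converts
    # the whole list via _gemini_convert_messages_with_history.
    response = await litellm.acompletion(
        model="vertex_ai/gemini-pro",  # assumed model name for illustration
        messages=[
            {"role": "user", "content": "My name is Ada."},
            {"role": "assistant", "content": "Nice to meet you, Ada!"},
            {"role": "user", "content": "What is my name?"},
        ],
        stream=True,
    )
    async for chunk in response:
        print(chunk)

# asyncio.run(main())  # requires a Vertex AI project/credentials to be set up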