fix(vertex_ai.py): use chat_messages_with_history for async + streaming calls

Krrish Dholakia 2024-05-19 12:30:24 -07:00
parent f9ab72841a
commit 65aacc0c1a
2 changed files with 17 additions and 17 deletions

litellm/llms/vertex_ai.py

@@ -336,7 +336,7 @@ def _process_gemini_image(image_url: str) -> PartType:
         raise e
-def _gemini_convert_messages_text(messages: list) -> List[ContentType]:
+def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
     """
     Converts given messages from OpenAI format to Gemini format
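For readers unfamiliar with the helper, here is a minimal sketch of what a history-preserving OpenAI-to-Gemini message conversion looks like. This is illustrative only, not the body of _gemini_convert_messages_with_history: it builds plain dicts in the Gemini REST "contents" shape, whereas the real helper returns the SDK's ContentType objects.

# Illustrative sketch only -- not litellm's actual implementation.
# Builds Gemini-style "contents" dicts from OpenAI-style chat messages.
def convert_openai_messages_to_gemini(messages: list) -> list:
    contents: list = []
    for msg in messages:
        # Gemini only knows "user" and "model" roles; map assistant -> model
        # and fold system prompts into a user turn (one possible choice).
        role = "model" if msg["role"] == "assistant" else "user"
        part = {"text": msg["content"]}
        if contents and contents[-1]["role"] == role:
            # Gemini expects alternating roles, so merge consecutive
            # same-role turns into a single content block.
            contents[-1]["parts"].append(part)
        else:
            contents.append({"role": role, "parts": [part]})
    return contents

history = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
    {"role": "user", "content": "And roughly how many people live there?"},
]
print(convert_openai_messages_to_gemini(history))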
@@ -680,6 +680,7 @@ def completion(
             "model_response": model_response,
             "encoding": encoding,
             "messages": messages,
+            "request_str": request_str,
             "print_verbose": print_verbose,
             "client_options": client_options,
             "instances": instances,
@@ -698,7 +699,7 @@ def completion(
             print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
             print_verbose(f"\nProcessing input messages = {messages}")
             tools = optional_params.pop("tools", None)
-            content = _gemini_convert_messages_text(messages=messages)
+            content = _gemini_convert_messages_with_history(messages=messages)
             stream = optional_params.pop("stream", False)
             if stream == True:
                 request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
@@ -972,12 +973,12 @@ async def async_completion(
     mode: str,
     prompt: str,
     model: str,
+    messages: list,
     model_response: ModelResponse,
-    logging_obj=None,
-    request_str=None,
+    request_str: str,
+    print_verbose: Callable,
+    logging_obj,
     encoding=None,
-    messages=None,
-    print_verbose=None,
     client_options=None,
     instances=None,
     vertex_project=None,
@@ -997,8 +998,7 @@ async def async_completion(
         tools = optional_params.pop("tools", None)
         stream = optional_params.pop("stream", False)
-        prompt, images = _gemini_vision_convert_messages(messages=messages)
-        content = [prompt] + images
+        content = _gemini_convert_messages_with_history(messages=messages)
         request_str += f"response = llm_model.generate_content({content})\n"
         ## LOGGING
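With the history-aware converter now used in async_completion, a multi-turn async call through litellm keeps every turn instead of collapsing the conversation into a single prompt plus images. A usage sketch, assuming Vertex AI credentials and project settings are already configured for litellm:

import asyncio
import litellm

async def main():
    # Multi-turn history; before this fix the async Gemini path flattened
    # these messages via _gemini_vision_convert_messages.
    messages = [
        {"role": "user", "content": "My favorite color is teal."},
        {"role": "assistant", "content": "Noted!"},
        {"role": "user", "content": "What is my favorite color?"},
    ]
    response = await litellm.acompletion(
        model="vertex_ai/gemini-pro",
        messages=messages,
    )
    print(response.choices[0].message.content)

asyncio.run(main())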
@@ -1198,11 +1198,11 @@ async def async_streaming(
     prompt: str,
     model: str,
     model_response: ModelResponse,
-    logging_obj=None,
-    request_str=None,
+    messages: list,
+    print_verbose: Callable,
+    logging_obj,
+    request_str: str,
     encoding=None,
-    messages=None,
-    print_verbose=None,
     client_options=None,
     instances=None,
     vertex_project=None,
@@ -1219,8 +1219,8 @@ async def async_streaming(
         print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
         print_verbose(f"\nProcessing input messages = {messages}")
-        prompt, images = _gemini_vision_convert_messages(messages=messages)
-        content = [prompt] + images
+        content = _gemini_convert_messages_with_history(messages=messages)
         request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
         logging_obj.pre_call(
             input=prompt,
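The async streaming path gets the same treatment, so streamed replies are also grounded in the full conversation. A sketch of consuming such a stream (same configuration assumptions as above):

import asyncio
import litellm

async def stream_reply():
    messages = [
        {"role": "user", "content": "Let's plan a three-day trip to Kyoto."},
        {"role": "assistant", "content": "Sure! Day one could cover Fushimi Inari."},
        {"role": "user", "content": "Continue with day two."},
    ]
    response = await litellm.acompletion(
        model="vertex_ai/gemini-pro",
        messages=messages,
        stream=True,
    )
    # litellm yields OpenAI-style chunks with a .delta on each choice.
    async for chunk in response:
        print(chunk.choices[0].delta.content or "", end="", flush=True)

asyncio.run(stream_reply())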

litellm/tests/ (Vertex AI completion tests; filename not shown)

@@ -16,7 +16,7 @@ from litellm.tests.test_streaming import streaming_format_tests
 import json
 import os
 import tempfile
-from litellm.llms.vertex_ai import _gemini_convert_messages_text
+from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
 litellm.num_retries = 3
 litellm.cache = None
@@ -590,7 +590,7 @@ def test_gemini_pro_vision_base64():
         pytest.fail(f"An exception occurred - {str(e)}")
-@pytest.mark.parametrize("sync_mode", [True])
+@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
 async def test_gemini_pro_function_calling(sync_mode):
     try:
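Adding False to the sync_mode parametrization makes the test also exercise the async path fixed above. The usual branching pattern for such a test looks roughly like the sketch below; the real test_gemini_pro_function_calling additionally passes tools and inspects the function-calling response.

import pytest
import litellm

@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_gemini_multiturn(sync_mode):
    # Sketch of the sync/async branching pattern only.
    messages = [
        {"role": "user", "content": "Hi, I'm planning a trip."},
        {"role": "assistant", "content": "Great! Where would you like to go?"},
        {"role": "user", "content": "What's the weather like in Boston?"},
    ]
    if sync_mode:
        response = litellm.completion(model="vertex_ai/gemini-pro", messages=messages)
    else:
        response = await litellm.acompletion(model="vertex_ai/gemini-pro", messages=messages)
    assert response.choices[0].message.content is not None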
@@ -972,6 +972,6 @@ def test_prompt_factory():
         # Now the assistant can reply with the result of the tool call.
     ]
-    translated_messages = _gemini_convert_messages_text(messages=messages)
+    translated_messages = _gemini_convert_messages_with_history(messages=messages)
     print(f"\n\ntranslated_messages: {translated_messages}\ntranslated_messages")
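The prompt-factory test converts a conversation that includes an assistant tool call and its result. A trimmed example of that OpenAI-format shape (the exact messages in the test differ):

# Trimmed, illustrative OpenAI-format conversation with a tool call;
# the actual messages used in test_prompt_factory differ.
messages = [
    {"role": "user", "content": "What's the weather in Boston today?"},
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_abc123",
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "arguments": '{"location": "Boston, MA"}',
                },
            }
        ],
    },
    {
        "role": "tool",
        "tool_call_id": "call_abc123",
        "content": '{"temperature": "72F"}',
    },
    # Now the assistant can reply with the result of the tool call.
]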