fix(vertex_ai.py): use chat_messages_with_history for async + streaming calls

Author: Krrish Dholakia
Date:   2024-05-19 12:30:24 -07:00
Commit: 65aacc0c1a
Parent: f9ab72841a

2 changed files with 17 additions and 17 deletions

--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -336,7 +336,7 @@ def _process_gemini_image(image_url: str) -> PartType:
         raise e


-def _gemini_convert_messages_text(messages: list) -> List[ContentType]:
+def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
     """
     Converts given messages from OpenAI format to Gemini format
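
Note: the renamed helper keeps the whole multi-turn conversation instead of flattening it. A minimal dict-based sketch of the idea follows; the real _gemini_convert_messages_with_history builds vertexai ContentType/PartType objects and also handles images and tool calls, so convert_messages_with_history and the plain-dict shapes here are illustrative assumptions only.

from typing import Dict, List


def convert_messages_with_history(messages: List[Dict]) -> List[Dict]:
    """Map OpenAI-style chat messages to Gemini-style contents."""
    # Gemini expects alternating "user"/"model" turns, so consecutive
    # messages that map to the same role are merged into one turn.
    role_map = {"user": "user", "system": "user", "assistant": "model"}
    contents: List[Dict] = []
    for msg in messages:
        role = role_map.get(msg["role"], "user")
        part = {"text": msg["content"]}
        if contents and contents[-1]["role"] == role:
            contents[-1]["parts"].append(part)  # merge same-role turns
        else:
            contents.append({"role": role, "parts": [part]})
    return contents


# Example: a multi-turn conversation keeps its full history.
history = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! How can I help?"},
    {"role": "user", "content": "Summarize our chat."},
]
print(convert_messages_with_history(history))
# [{'role': 'user', ...}, {'role': 'model', ...}, {'role': 'user', ...}]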
@@ -680,6 +680,7 @@ def completion(
            "model_response": model_response,
            "encoding": encoding,
            "messages": messages,
+           "request_str": request_str,
            "print_verbose": print_verbose,
            "client_options": client_options,
            "instances": instances,
@@ -698,7 +699,7 @@ def completion(
        print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
        print_verbose(f"\nProcessing input messages = {messages}")
        tools = optional_params.pop("tools", None)
-       content = _gemini_convert_messages_text(messages=messages)
+       content = _gemini_convert_messages_with_history(messages=messages)
        stream = optional_params.pop("stream", False)
        if stream == True:
            request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
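
Note: after this change the sync Gemini path sends the converted history directly to generate_content. A rough usage sketch, assuming the vertexai SDK, valid GCP credentials, and a hypothetical model name; the dict-based contents mirror what the converter produces.

from vertexai.preview.generative_models import GenerationConfig, GenerativeModel

llm_model = GenerativeModel("gemini-1.0-pro")  # hypothetical model choice
content = [  # history-aware contents, as built by the converter
    {"role": "user", "parts": [{"text": "Hi"}]},
    {"role": "model", "parts": [{"text": "Hello! How can I help?"}]},
    {"role": "user", "parts": [{"text": "Summarize our chat."}]},
]
response = llm_model.generate_content(
    content,
    generation_config=GenerationConfig(temperature=0.0),
    stream=True,  # yields partial response chunks instead of one response
)
for chunk in response:
    print(chunk.text, end="")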
@@ -972,12 +973,12 @@ async def async_completion(
    mode: str,
    prompt: str,
    model: str,
+   messages: list,
    model_response: ModelResponse,
-   logging_obj=None,
-   request_str=None,
+   request_str: str,
+   print_verbose: Callable,
+   logging_obj,
    encoding=None,
-   messages=None,
-   print_verbose=None,
    client_options=None,
    instances=None,
    vertex_project=None,
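
Note: tightening the async_completion signature pairs with the new "request_str": request_str entry added to the kwargs dict above. With the old default of request_str=None, the first += append inside the async path could crash before anything was logged. A tiny self-contained illustration of that failure mode:

# Appending to a None request_str raises immediately; making it a
# required str parameter prevents this class of bug.
request_str = None
try:
    request_str += "response = llm_model.generate_content(...)\n"
except TypeError as err:
    print(err)  # unsupported operand type(s) for +=: 'NoneType' and 'str'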
@@ -997,8 +998,7 @@ async def async_completion(
            tools = optional_params.pop("tools", None)
            stream = optional_params.pop("stream", False)

-           prompt, images = _gemini_vision_convert_messages(messages=messages)
-           content = [prompt] + images
+           content = _gemini_convert_messages_with_history(messages=messages)
            request_str += f"response = llm_model.generate_content({content})\n"

            ## LOGGING
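
Note: this hunk is the behavioral fix for the async path. The old vision helper collapsed the conversation into one prompt string plus an image list, dropping turn boundaries; the new converter preserves each role-tagged turn. Illustrative shapes only, not the helpers' exact output:

# Old: one flattened blob; Gemini sees a single user turn, no history.
prompt, images = "Hi Hello! How can I help? Summarize our chat.", []
old_content = [prompt] + images

# New: role-tagged turns survive, so the model can use the chat history.
new_content = [
    {"role": "user", "parts": [{"text": "Hi"}]},
    {"role": "model", "parts": [{"text": "Hello! How can I help?"}]},
    {"role": "user", "parts": [{"text": "Summarize our chat."}]},
]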
@@ -1198,11 +1198,11 @@ async def async_streaming(
    prompt: str,
    model: str,
    model_response: ModelResponse,
-   logging_obj=None,
-   request_str=None,
+   messages: list,
+   print_verbose: Callable,
+   logging_obj,
+   request_str: str,
    encoding=None,
-   messages=None,
-   print_verbose=None,
    client_options=None,
    instances=None,
    vertex_project=None,
@@ -1219,8 +1219,8 @@ async def async_streaming(
        print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
        print_verbose(f"\nProcessing input messages = {messages}")
-       prompt, images = _gemini_vision_convert_messages(messages=messages)
-       content = [prompt] + images
+       content = _gemini_convert_messages_with_history(messages=messages)
        request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"

        logging_obj.pre_call(
            input=prompt,
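
Note: for the async + streaming path named in the commit title, the same history-aware contents flow through the SDK's async API. A rough sketch, again assuming the vertexai SDK, valid GCP credentials, and a hypothetical model name:

import asyncio

from vertexai.preview.generative_models import GenerationConfig, GenerativeModel


async def main() -> None:
    llm_model = GenerativeModel("gemini-1.0-pro")  # hypothetical model choice
    content = [  # history-aware contents, as built by the converter
        {"role": "user", "parts": [{"text": "Hi"}]},
        {"role": "model", "parts": [{"text": "Hello! How can I help?"}]},
        {"role": "user", "parts": [{"text": "Summarize our chat."}]},
    ]
    # With stream=True, awaiting generate_content_async returns an async
    # iterator of partial responses.
    stream = await llm_model.generate_content_async(
        content,
        generation_config=GenerationConfig(temperature=0.0),
        stream=True,
    )
    async for chunk in stream:
        print(chunk.text, end="")


asyncio.run(main())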