From 6e9267ca66e52bf0d6aeb250b0dc7506072f4fbb Mon Sep 17 00:00:00 2001
From: Graham Neubig
Date: Wed, 20 Dec 2023 15:32:44 -0500
Subject: [PATCH 1/2] Make vertex_chat work with generate_content

---
 litellm/llms/vertex_ai.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index 7cfc91701..7601d0d4c 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -244,21 +244,19 @@ def completion(
             return async_completion(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, encoding=encoding, messages=messages,print_verbose=print_verbose,**optional_params)
 
         if mode == "":
-            chat = llm_model.start_chat()
-            request_str+= f"chat = llm_model.start_chat()\n"
             if "stream" in optional_params and optional_params["stream"] == True:
                 stream = optional_params.pop("stream")
                 request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
                 ## LOGGING
                 logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-                model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, stream=stream)
+                model_response = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, stream=stream)
                 optional_params["stream"] = True
                 return model_response
 
             request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
             ## LOGGING
             logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-            response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings)
+            response_obj = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings)
             completion_response = response_obj.text
             response_obj = response_obj._raw_response
         elif mode == "vision":

From 2362544344c19b254051ae29c288cb74728ea9d0 Mon Sep 17 00:00:00 2001
From: Graham Neubig
Date: Thu, 21 Dec 2023 09:58:06 -0500
Subject: [PATCH 2/2] Update the request_str

---
 litellm/llms/vertex_ai.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index 7601d0d4c..e25a0b925 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -247,13 +247,13 @@ def completion(
             if "stream" in optional_params and optional_params["stream"] == True:
                 stream = optional_params.pop("stream")
-                request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
+                request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
                 ## LOGGING
                 logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
                 model_response = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, stream=stream)
                 optional_params["stream"] = True
                 return model_response
 
-            request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
+            request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
             ## LOGGING
             logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
             response_obj = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings)