From 86403cd14eb2247b8f4d3576e3b6f0660a67a79d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 28 Dec 2023 19:06:49 +0530 Subject: [PATCH] fix(vertex_ai.py): support function calling for gemini --- litellm/llms/vertex_ai.py | 171 ++++++++---------- .../tests/test_amazing_vertex_completion.py | 66 ++++++- litellm/utils.py | 25 ++- 3 files changed, 167 insertions(+), 95 deletions(-) diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index 0c3b6ff8c..b1cb6035d 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -5,7 +5,7 @@ import requests import time from typing import Callable, Optional from litellm.utils import ModelResponse, Usage, CustomStreamWrapper -import litellm +import litellm, uuid import httpx @@ -264,14 +264,13 @@ def completion( request_str = "" response_obj = None - if model in litellm.vertex_language_models: + if ( + model in litellm.vertex_language_models + or model in litellm.vertex_vision_models + ): llm_model = GenerativeModel(model) - mode = "" - request_str += f"llm_model = GenerativeModel({model})\n" - elif model in litellm.vertex_vision_models: - llm_model = GenerativeModel(model) - request_str += f"llm_model = GenerativeModel({model})\n" mode = "vision" + request_str += f"llm_model = GenerativeModel({model})\n" elif model in litellm.vertex_chat_models: llm_model = ChatModel.from_pretrained(model) mode = "chat" @@ -318,48 +317,10 @@ def completion( **optional_params, ) - if mode == "": - if "stream" in optional_params and optional_params["stream"] == True: - stream = optional_params.pop("stream") - request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - model_response = llm_model.generate_content( - prompt, - generation_config=GenerationConfig(**optional_params), - safety_settings=safety_settings, - stream=stream, - ) - optional_params["stream"] = True - return model_response - request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - response_obj = llm_model.generate_content( - prompt, - generation_config=GenerationConfig(**optional_params), - safety_settings=safety_settings, - ) - completion_response = response_obj.text - response_obj = response_obj._raw_response - elif mode == "vision": + if mode == "vision": print_verbose("\nMaking VertexAI Gemini Pro Vision Call") print_verbose(f"\nProcessing input messages = {messages}") - + tools = optional_params.pop("tools", None) prompt, images = _gemini_vision_convert_messages(messages=messages) content = [prompt] + images if "stream" in optional_params and optional_params["stream"] == True: @@ -379,6 +340,7 @@ def completion( generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, stream=True, + tools=tools, ) optional_params["stream"] = True return model_response @@ -399,9 +361,35 @@ def completion( contents=content, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, + tools=tools, ) - completion_response = response.text + + if tools is not None and hasattr( + response.candidates[0].content.parts[0], "function_call" + ): + function_call = response.candidates[0].content.parts[0].function_call + args_dict = {} + for k, v in function_call.args.items(): + args_dict[k] = v + args_str = json.dumps(args_dict) + message = litellm.Message( + content=None, + tool_calls=[ + { + "id": f"call_{str(uuid.uuid4())}", + "function": { + "arguments": args_str, + "name": function_call.name, + }, + "type": "function", + } + ], + ) + completion_response = message + else: + completion_response = response.text response_obj = response._raw_response + optional_params["tools"] = tools elif mode == "chat": chat = llm_model.start_chat() request_str += f"chat = llm_model.start_chat()\n" @@ -479,7 +467,9 @@ def completion( ) ## RESPONSE OBJECT - if len(str(completion_response)) > 0: + if isinstance(completion_response, litellm.Message): + model_response["choices"][0]["message"] = completion_response + elif len(str(completion_response)) > 0: model_response["choices"][0]["message"]["content"] = str( completion_response ) @@ -533,26 +523,10 @@ async def async_completion( try: from vertexai.preview.generative_models import GenerationConfig - if mode == "": - # gemini-pro - chat = llm_model.start_chat() - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - response_obj = await chat.send_message_async( - prompt, generation_config=GenerationConfig(**optional_params) - ) - completion_response = response_obj.text - response_obj = response_obj._raw_response - elif mode == "vision": + if mode == "vision": print_verbose("\nMaking VertexAI Gemini Pro Vision Call") print_verbose(f"\nProcessing input messages = {messages}") + tools = optional_params.pop("tools", None) prompt, images = _gemini_vision_convert_messages(messages=messages) content = [prompt] + images @@ -570,10 +544,37 @@ async def async_completion( ## LLM Call response = await llm_model._generate_content_async( - contents=content, generation_config=GenerationConfig(**optional_params) + contents=content, + generation_config=GenerationConfig(**optional_params), + tools=tools, ) - completion_response = response.text + + if tools is not None and hasattr( + response.candidates[0].content.parts[0], "function_call" + ): + function_call = response.candidates[0].content.parts[0].function_call + args_dict = {} + for k, v in function_call.args.items(): + args_dict[k] = v + args_str = json.dumps(args_dict) + message = litellm.Message( + content=None, + tool_calls=[ + { + "id": f"call_{str(uuid.uuid4())}", + "function": { + "arguments": args_str, + "name": function_call.name, + }, + "type": "function", + } + ], + ) + completion_response = message + else: + completion_response = response.text response_obj = response._raw_response + optional_params["tools"] = tools elif mode == "chat": # chat-bison etc. chat = llm_model.start_chat() @@ -609,7 +610,9 @@ async def async_completion( ) ## RESPONSE OBJECT - if len(str(completion_response)) > 0: + if isinstance(completion_response, litellm.Message): + model_response["choices"][0]["message"] = completion_response + elif len(str(completion_response)) > 0: model_response["choices"][0]["message"]["content"] = str( completion_response ) @@ -661,33 +664,14 @@ async def async_streaming( """ from vertexai.preview.generative_models import GenerationConfig - if mode == "": - # gemini-pro - chat = llm_model.start_chat() + if mode == "vision": stream = optional_params.pop("stream") - request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - response = await chat.send_message_async( - prompt, generation_config=GenerationConfig(**optional_params), stream=stream - ) - optional_params["stream"] = True - elif mode == "vision": - stream = optional_params.pop("stream") - + tools = optional_params.pop("tools", None) print_verbose("\nMaking VertexAI Gemini Pro Vision Call") print_verbose(f"\nProcessing input messages = {messages}") prompt, images = _gemini_vision_convert_messages(messages=messages) content = [prompt] + images - stream = optional_params.pop("stream") request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n" logging_obj.pre_call( input=prompt, @@ -698,12 +682,13 @@ async def async_streaming( }, ) - response = llm_model._generate_content_streaming_async( + response = await llm_model._generate_content_streaming_async( contents=content, generation_config=GenerationConfig(**optional_params), - stream=True, + tools=tools, ) optional_params["stream"] = True + optional_params["tools"] = tools elif mode == "chat": chat = llm_model.start_chat() optional_params.pop( diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index eb620273f..755788be4 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -98,7 +98,8 @@ def test_vertex_ai(): litellm.vertex_project = "reliablekeys" test_models = random.sample(test_models, 1) - test_models += litellm.vertex_language_models # always test gemini-pro + # test_models += litellm.vertex_language_models # always test gemini-pro + test_models = litellm.vertex_language_models # always test gemini-pro for model in test_models: try: if model in [ @@ -303,6 +304,69 @@ def test_gemini_pro_vision(): # test_gemini_pro_vision() +def gemini_pro_function_calling(): + load_vertex_ai_credentials() + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + messages = [{"role": "user", "content": "What's the weather like in Boston today?"}] + completion = litellm.completion( + model="gemini-pro", messages=messages, tools=tools, tool_choice="auto" + ) + print(f"completion: {completion}") + + +# gemini_pro_function_calling() + + +async def gemini_pro_async_function_calling(): + load_vertex_ai_credentials() + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + messages = [{"role": "user", "content": "What's the weather like in Boston today?"}] + completion = await litellm.acompletion( + model="gemini-pro", messages=messages, tools=tools, tool_choice="auto" + ) + print(f"completion: {completion}") + + +asyncio.run(gemini_pro_async_function_calling()) + # Extra gemini Vision tests for completion + stream, async, async + stream # if we run into issues with gemini, we will also add these to our ci/cd pipeline # def test_gemini_pro_vision_stream(): diff --git a/litellm/utils.py b/litellm/utils.py index 8f73105f7..af926d362 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2939,6 +2939,7 @@ def get_optional_params( custom_llm_provider != "openai" and custom_llm_provider != "text-completion-openai" and custom_llm_provider != "azure" + and custom_llm_provider != "vertex_ai" ): if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat": # ollama actually supports json output @@ -3238,7 +3239,14 @@ def get_optional_params( optional_params["max_output_tokens"] = max_tokens elif custom_llm_provider == "vertex_ai": ## check if unsupported param passed in - supported_params = ["temperature", "top_p", "max_tokens", "stream"] + supported_params = [ + "temperature", + "top_p", + "max_tokens", + "stream", + "tools", + "tool_choice", + ] _check_valid_arg(supported_params=supported_params) if temperature is not None: @@ -3249,6 +3257,21 @@ def get_optional_params( optional_params["stream"] = stream if max_tokens is not None: optional_params["max_output_tokens"] = max_tokens + if tools is not None and isinstance(tools, list): + from vertexai.preview import generative_models + + gtools = [] + for tool in tools: + gtool = generative_models.FunctionDeclaration( + name=tool["function"]["name"], + description=tool["function"].get("description", ""), + parameters=tool["function"].get("parameters", {}), + ) + gtool_func_declaration = generative_models.Tool( + function_declarations=[gtool] + ) + gtools.append(gtool_func_declaration) + optional_params["tools"] = gtools elif custom_llm_provider == "sagemaker": ## check if unsupported param passed in supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]