Fix the issue when using multiple tools with Gemini

kan-bayashi 2024-02-14 13:09:35 +09:00
parent fa74adb041
commit a4e33c8c67


@@ -4215,18 +4215,15 @@ def get_optional_params(
         if tools is not None and isinstance(tools, list):
             from vertexai.preview import generative_models
 
-            gtools = []
+            gtool_func_declarations = []
             for tool in tools:
-                gtool = generative_models.FunctionDeclaration(
+                gtool_func_declaration = generative_models.FunctionDeclaration(
                     name=tool["function"]["name"],
                     description=tool["function"].get("description", ""),
                     parameters=tool["function"].get("parameters", {}),
                 )
-                gtool_func_declaration = generative_models.Tool(
-                    function_declarations=[gtool]
-                )
-                gtools.append(gtool_func_declaration)
-            optional_params["tools"] = gtools
+                gtool_func_declarations.append(gtool_func_declaration)
+            optional_params["tools"] = [generative_models.Tool(function_declarations=gtool_func_declarations)]
     elif custom_llm_provider == "sagemaker":
         ## check if unsupported param passed in
         supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
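The change collects every OpenAI-style tool into a single generative_models.Tool containing all FunctionDeclarations, instead of wrapping each declaration in its own Tool. Below is a minimal standalone sketch of that pattern, assuming the vertexai preview SDK is installed; the get_current_weather / get_current_time tool definitions are hypothetical examples and are not part of this commit.

# Sketch of the pattern this fix adopts: convert all OpenAI-style tool
# definitions into FunctionDeclarations, then wrap them in ONE Tool.
from vertexai.preview import generative_models

# Hypothetical OpenAI-style tool definitions (not from this commit).
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_current_time",
            "description": "Get the current local time for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    },
]

# Build one FunctionDeclaration per tool.
gtool_func_declarations = [
    generative_models.FunctionDeclaration(
        name=tool["function"]["name"],
        description=tool["function"].get("description", ""),
        parameters=tool["function"].get("parameters", {}),
    )
    for tool in tools
]

# A single Tool holding every declaration, as Gemini expects when several
# functions are offered in the same request.
optional_params = {}
optional_params["tools"] = [
    generative_models.Tool(function_declarations=gtool_func_declarations)
]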