fix(vertex_ai.py): support function calling for gemini

This commit is contained in:
Krrish Dholakia 2023-12-28 19:06:49 +05:30
parent a1484171b5
commit 86403cd14e
3 changed files with 167 additions and 95 deletions

View file

@ -2939,6 +2939,7 @@ def get_optional_params(
custom_llm_provider != "openai"
and custom_llm_provider != "text-completion-openai"
and custom_llm_provider != "azure"
and custom_llm_provider != "vertex_ai"
):
if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
# ollama actually supports json output
@ -3238,7 +3239,14 @@ def get_optional_params(
optional_params["max_output_tokens"] = max_tokens
elif custom_llm_provider == "vertex_ai":
## check if unsupported param passed in
supported_params = ["temperature", "top_p", "max_tokens", "stream"]
supported_params = [
"temperature",
"top_p",
"max_tokens",
"stream",
"tools",
"tool_choice",
]
_check_valid_arg(supported_params=supported_params)
if temperature is not None:
@ -3249,6 +3257,21 @@ def get_optional_params(
optional_params["stream"] = stream
if max_tokens is not None:
optional_params["max_output_tokens"] = max_tokens
if tools is not None and isinstance(tools, list):
from vertexai.preview import generative_models
gtools = []
for tool in tools:
gtool = generative_models.FunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"].get("description", ""),
parameters=tool["function"].get("parameters", {}),
)
gtool_func_declaration = generative_models.Tool(
function_declarations=[gtool]
)
gtools.append(gtool_func_declaration)
optional_params["tools"] = gtools
elif custom_llm_provider == "sagemaker":
## check if unsupported param passed in
supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]