fix(vertex_httpx.py): fix gtool handling

This commit is contained in:
Krrish Dholakia 2024-07-09 08:01:49 -07:00
parent 54c31e5af7
commit a784f7d8df
2 changed files with 9 additions and 18 deletions

View file

@ -394,16 +394,19 @@ class VertexGeminiConfig:
google_search_tool: Optional[dict] = None google_search_tool: Optional[dict] = None
for tool in value: for tool in value:
# check if grounding # check if grounding
_search_tool = tool.get("googleSearchRetrieval", None) try:
if google_search_tool is not None:
google_search_tool = _search_tool
else:
gtool_func_declaration = FunctionDeclaration( gtool_func_declaration = FunctionDeclaration(
name=tool["function"]["name"], name=tool["function"]["name"],
description=tool["function"].get("description", ""), description=tool["function"].get("description", ""),
parameters=tool["function"].get("parameters", {}), parameters=tool["function"].get("parameters", {}),
) )
gtool_func_declarations.append(gtool_func_declaration) gtool_func_declarations.append(gtool_func_declaration)
except KeyError:
# assume it's a provider-specific param
verbose_logger.warning(
"Got KeyError parsing tool={}. Assuming it's a provider-specific param. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
)
google_search_tool = tool
_tools = Tools(function_declarations=gtool_func_declarations) _tools = Tools(function_declarations=gtool_func_declarations)
if google_search_tool is not None: if google_search_tool is not None:
_tools["googleSearchRetrieval"] = google_search_tool _tools["googleSearchRetrieval"] = google_search_tool

View file

@ -682,18 +682,6 @@ def test_gemini_pro_grounding():
load_vertex_ai_credentials() load_vertex_ai_credentials()
litellm.set_verbose = True litellm.set_verbose = True
messages = [
{
"role": "system",
"content": "Your name is Litellm Bot, you are a helpful assistant",
},
# User asks for their name and weather in San Francisco
{
"role": "user",
"content": "Hello, what is your name and can you tell me the weather?",
},
]
tools = [{"googleSearchRetrieval": {}}] tools = [{"googleSearchRetrieval": {}}]
litellm.set_verbose = True litellm.set_verbose = True
@ -710,8 +698,8 @@ def test_gemini_pro_grounding():
tools=tools, tools=tools,
client=client, client=client,
) )
except Exception: except Exception as e:
pass print("Received Exception - {}".format(str(e)))
mock_call.assert_called_once() mock_call.assert_called_once()