feat(vertex_httpx.py): add support for gemini 'grounding'

Adds support for https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/grounding#rest
Krrish Dholakia 2024-07-08 21:36:47 -07:00
parent a986413df3
commit 7541478459
3 changed files with 78 additions and 10 deletions
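For context, a minimal usage sketch (not part of this commit) of the grounding flow this change enables, mirroring the new test below; the model name and tool shape are taken from that test, and working Vertex AI credentials are assumed:

import litellm

# Pass the Vertex AI grounding tool through litellm; the
# "googleSearchRetrieval" entry is forwarded into the Vertex request body.
response = litellm.completion(
    model="vertex_ai_beta/gemini-1.0-pro-001",
    messages=[{"role": "user", "content": "Who won the world cup?"}],
    tools=[{"googleSearchRetrieval": {}}],
)
print(response.choices[0].message.content)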


@@ -391,16 +391,23 @@ class VertexGeminiConfig:
                 optional_params["presence_penalty"] = value
             if param == "tools" and isinstance(value, list):
                 gtool_func_declarations = []
+                google_search_tool: Optional[dict] = None
                 for tool in value:
-                    gtool_func_declaration = FunctionDeclaration(
-                        name=tool["function"]["name"],
-                        description=tool["function"].get("description", ""),
-                        parameters=tool["function"].get("parameters", {}),
-                    )
-                    gtool_func_declarations.append(gtool_func_declaration)
-                optional_params["tools"] = [
-                    Tools(function_declarations=gtool_func_declarations)
-                ]
+                    # check if grounding
+                    _search_tool = tool.get("googleSearchRetrieval", None)
+                    if _search_tool is not None:
+                        google_search_tool = _search_tool
+                    else:
+                        gtool_func_declaration = FunctionDeclaration(
+                            name=tool["function"]["name"],
+                            description=tool["function"].get("description", ""),
+                            parameters=tool["function"].get("parameters", {}),
+                        )
+                        gtool_func_declarations.append(gtool_func_declaration)
+                _tools = Tools(function_declarations=gtool_func_declarations)
+                if google_search_tool is not None:
+                    _tools["googleSearchRetrieval"] = google_search_tool
+                optional_params["tools"] = [_tools]
             if param == "tool_choice" and (
                 isinstance(value, str) or isinstance(value, dict)
             ):
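The mapping above can be read as the standalone sketch below. It is a simplification, not the actual litellm implementation: plain dicts stand in for the FunctionDeclaration and Tools TypedDicts, and map_openai_tools_to_vertex is a hypothetical helper name. Function tools are collected into function_declarations, while a googleSearchRetrieval entry is attached to the same tools dict.

from typing import Optional


def map_openai_tools_to_vertex(value: list) -> list:
    # Simplified sketch of the loop in the hunk above, using plain dicts.
    gtool_func_declarations = []
    google_search_tool: Optional[dict] = None
    for tool in value:
        # Grounding tools carry a "googleSearchRetrieval" key instead of "function".
        _search_tool = tool.get("googleSearchRetrieval", None)
        if _search_tool is not None:
            google_search_tool = _search_tool
        else:
            gtool_func_declarations.append(
                {
                    "name": tool["function"]["name"],
                    "description": tool["function"].get("description", ""),
                    "parameters": tool["function"].get("parameters", {}),
                }
            )
    # Both kinds of tool end up in a single entry of the "tools" list.
    _tools: dict = {"function_declarations": gtool_func_declarations}
    if google_search_tool is not None:
        _tools["googleSearchRetrieval"] = google_search_tool
    return [_tools]


print(
    map_openai_tools_to_vertex(
        [
            {"googleSearchRetrieval": {}},
            {"type": "function", "function": {"name": "get_weather", "parameters": {}}},
        ]
    )
)
# -> [{'function_declarations': [{'name': 'get_weather', 'description': '', 'parameters': {}}],
#      'googleSearchRetrieval': {}}]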


@@ -677,6 +677,57 @@ def test_gemini_pro_vision_base64():
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
+def test_gemini_pro_grounding():
+    try:
+        load_vertex_ai_credentials()
+        litellm.set_verbose = True
+
+        messages = [
+            {
+                "role": "system",
+                "content": "Your name is Litellm Bot, you are a helpful assistant",
+            },
+            # User asks for their name and weather in San Francisco
+            {
+                "role": "user",
+                "content": "Hello, what is your name and can you tell me the weather?",
+            },
+        ]
+
+        tools = [{"googleSearchRetrieval": {}}]
+
+        litellm.set_verbose = True
+
+        from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+        client = HTTPHandler()
+
+        with patch.object(client, "post", new=MagicMock()) as mock_call:
+            try:
+                resp = litellm.completion(
+                    model="vertex_ai_beta/gemini-1.0-pro-001",
+                    messages=[{"role": "user", "content": "Who won the world cup?"}],
+                    tools=tools,
+                    client=client,
+                )
+            except Exception:
+                pass
+
+            mock_call.assert_called_once()
+
+            print(mock_call.call_args.kwargs["json"]["tools"][0])
+
+            assert (
+                "googleSearchRetrieval"
+                in mock_call.call_args.kwargs["json"]["tools"][0]
+            )
+
+    except litellm.InternalServerError:
+        pass
+    except litellm.RateLimitError:
+        pass
+
+
 # @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
 @pytest.mark.parametrize(
     "model", ["vertex_ai_beta/gemini-1.5-pro", "vertex_ai/claude-3-sonnet@20240229"]


@@ -94,6 +94,14 @@ class FunctionDeclaration(TypedDict, total=False):
     response: Schema
 
 
+class VertexAISearch(TypedDict, total=False):
+    datastore: Required[str]
+
+
+class Retrieval(TypedDict):
+    source: VertexAISearch
+
+
 class FunctionCallingConfig(TypedDict, total=False):
     mode: Literal["ANY", "AUTO", "NONE"]
     allowed_function_names: List[str]
@@ -147,8 +155,10 @@ class GenerationConfig(TypedDict, total=False):
     response_mime_type: Literal["text/plain", "application/json"]
 
 
-class Tools(TypedDict):
+class Tools(TypedDict, total=False):
     function_declarations: List[FunctionDeclaration]
+    googleSearchRetrieval: dict
+    retrieval: Retrieval
 
 
 class ToolConfig(TypedDict):
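With Tools now declared total=False, either kind of tool can be constructed independently. A minimal sketch follows; the import path and datastore resource name are assumptions for illustration:

# Assumed import path for the types above; adjust to where they actually live.
from litellm.types.llms.vertex_ai import Retrieval, Tools, VertexAISearch

# Grounding via Google Search retrieval.
grounding_tools: Tools = {"googleSearchRetrieval": {}}

# Grounding via a Vertex AI Search datastore (placeholder resource name).
search_tools: Tools = {
    "retrieval": Retrieval(
        source=VertexAISearch(
            datastore="projects/my-project/locations/global/collections/default_collection/dataStores/my-datastore"
        )
    ),
}

# Regular function calling is unchanged.
function_tools: Tools = {"function_declarations": []}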