forked from phoenix/litellm-mirror
feat(vertex_httpx.py): add support for gemini 'grounding'
Adds support for https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/grounding#rest
This commit is contained in:
parent
a986413df3
commit
7541478459
3 changed files with 78 additions and 10 deletions
|
@ -391,16 +391,23 @@ class VertexGeminiConfig:
|
|||
optional_params["presence_penalty"] = value
|
||||
if param == "tools" and isinstance(value, list):
|
||||
gtool_func_declarations = []
|
||||
google_search_tool: Optional[dict] = None
|
||||
for tool in value:
|
||||
gtool_func_declaration = FunctionDeclaration(
|
||||
name=tool["function"]["name"],
|
||||
description=tool["function"].get("description", ""),
|
||||
parameters=tool["function"].get("parameters", {}),
|
||||
)
|
||||
gtool_func_declarations.append(gtool_func_declaration)
|
||||
optional_params["tools"] = [
|
||||
Tools(function_declarations=gtool_func_declarations)
|
||||
]
|
||||
# check if grounding
|
||||
_search_tool = tool.get("googleSearchRetrieval", None)
|
||||
if google_search_tool is not None:
|
||||
google_search_tool = _search_tool
|
||||
else:
|
||||
gtool_func_declaration = FunctionDeclaration(
|
||||
name=tool["function"]["name"],
|
||||
description=tool["function"].get("description", ""),
|
||||
parameters=tool["function"].get("parameters", {}),
|
||||
)
|
||||
gtool_func_declarations.append(gtool_func_declaration)
|
||||
_tools = Tools(function_declarations=gtool_func_declarations)
|
||||
if google_search_tool is not None:
|
||||
_tools["googleSearchRetrieval"] = google_search_tool
|
||||
optional_params["tools"] = [_tools]
|
||||
if param == "tool_choice" and (
|
||||
isinstance(value, str) or isinstance(value, dict)
|
||||
):
|
||||
|
|
|
@ -677,6 +677,57 @@ def test_gemini_pro_vision_base64():
|
|||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
def test_gemini_pro_grounding():
|
||||
try:
|
||||
load_vertex_ai_credentials()
|
||||
litellm.set_verbose = True
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Your name is Litellm Bot, you are a helpful assistant",
|
||||
},
|
||||
# User asks for their name and weather in San Francisco
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, what is your name and can you tell me the weather?",
|
||||
},
|
||||
]
|
||||
|
||||
tools = [{"googleSearchRetrieval": {}}]
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post", new=MagicMock()) as mock_call:
|
||||
try:
|
||||
resp = litellm.completion(
|
||||
model="vertex_ai_beta/gemini-1.0-pro-001",
|
||||
messages=[{"role": "user", "content": "Who won the world cup?"}],
|
||||
tools=tools,
|
||||
client=client,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
mock_call.assert_called_once()
|
||||
|
||||
print(mock_call.call_args.kwargs["json"]["tools"][0])
|
||||
|
||||
assert (
|
||||
"googleSearchRetrieval"
|
||||
in mock_call.call_args.kwargs["json"]["tools"][0]
|
||||
)
|
||||
|
||||
except litellm.InternalServerError:
|
||||
pass
|
||||
except litellm.RateLimitError:
|
||||
pass
|
||||
|
||||
|
||||
# @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
|
||||
@pytest.mark.parametrize(
|
||||
"model", ["vertex_ai_beta/gemini-1.5-pro", "vertex_ai/claude-3-sonnet@20240229"]
|
||||
|
|
|
@ -94,6 +94,14 @@ class FunctionDeclaration(TypedDict, total=False):
|
|||
response: Schema
|
||||
|
||||
|
||||
class VertexAISearch(TypedDict, total=False):
|
||||
datastore: Required[str]
|
||||
|
||||
|
||||
class Retrieval(TypedDict):
|
||||
source: VertexAISearch
|
||||
|
||||
|
||||
class FunctionCallingConfig(TypedDict, total=False):
|
||||
mode: Literal["ANY", "AUTO", "NONE"]
|
||||
allowed_function_names: List[str]
|
||||
|
@ -147,8 +155,10 @@ class GenerationConfig(TypedDict, total=False):
|
|||
response_mime_type: Literal["text/plain", "application/json"]
|
||||
|
||||
|
||||
class Tools(TypedDict):
|
||||
class Tools(TypedDict, total=False):
|
||||
function_declarations: List[FunctionDeclaration]
|
||||
googleSearchRetrieval: dict
|
||||
retrieval: Retrieval
|
||||
|
||||
|
||||
class ToolConfig(TypedDict):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue