fix(vertex_httpx.py): google search grounding fix

Krrish Dholakia 2024-07-14 08:06:17 -07:00
parent 385da04d72
commit 82ca7af6df
3 changed files with 66 additions and 9 deletions


@@ -10,7 +10,7 @@ import TabItem from '@theme/TabItem';
 ## 🆕 `vertex_ai_beta/` route
-New `vertex_ai_beta/` route. Adds support for system messages, tool_choice params, etc. by moving to httpx client (instead of vertex sdk).
+New `vertex_ai_beta/` route. Adds support for system messages, tool_choice params, etc. by moving to httpx client (instead of vertex sdk). This implementation uses [VertexAI's REST API](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#syntax).
 ```python
 from litellm import completion
@@ -377,6 +377,54 @@ curl http://0.0.0.0:4000/v1/chat/completions \
 </TabItem>
 </Tabs>
+#### **Moving from Vertex AI SDK to LiteLLM (GROUNDING)**
+If this was your original VertexAI grounding code,
+```python
+import vertexai
+from vertexai.generative_models import GenerationConfig, GenerativeModel, Tool, grounding
+project_id = "your-project-id"  # 👈 your GCP project
+vertexai.init(project=project_id, location="us-central1")
+model = GenerativeModel("gemini-1.5-flash-001")
+# Use Google Search for grounding
+tool = Tool.from_google_search_retrieval(grounding.GoogleSearchRetrieval(disable_attribution=False))
+prompt = "When is the next total solar eclipse in US?"
+response = model.generate_content(
+    prompt,
+    tools=[tool],
+    generation_config=GenerationConfig(
+        temperature=0.0,
+    ),
+)
+print(response)
+```
+then this is what it looks like now:
+```python
+from litellm import completion
+# !gcloud auth application-default login - run this to add vertex credentials to your env
+tools = [{"googleSearchRetrieval": {"disable_attribution": False}}]  # 👈 ADD GOOGLE SEARCH
+resp = completion(
+    model="vertex_ai_beta/gemini-1.0-pro-001",
+    messages=[{"role": "user", "content": "Who won the world cup?"}],
+    tools=tools,
+    vertex_project="project-id",
+)
+print(resp)
+```
 ## Pre-requisites
 * `pip install google-cloud-aiplatform` (pre-installed on proxy docker image)
 * Authentication:
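For orientation (not part of the diff itself): a minimal sketch of the `generateContent` request body this maps to on Vertex's REST side, per the REST reference linked earlier in the doc. The exact payload litellm builds may include more fields; the point is that the dict given under `googleSearchRetrieval` is forwarded as-is, which the updated test below asserts.

```python
# Sketch of the grounding request body sent to Vertex's generateContent
# REST endpoint (illustrative only; litellm's real payload may add fields
# such as generationConfig or safetySettings).
request_body = {
    "contents": [
        {"role": "user", "parts": [{"text": "Who won the world cup?"}]}
    ],
    # the user-supplied dict is passed through verbatim under googleSearchRetrieval
    "tools": [{"googleSearchRetrieval": {"disable_attribution": False}}],
}
print(request_body)
```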


@@ -396,7 +396,8 @@ class VertexGeminiConfig:
                 optional_params["presence_penalty"] = value
             if param == "tools" and isinstance(value, list):
                 gtool_func_declarations = []
-                google_search_tool: Optional[dict] = None
+                googleSearchRetrieval: Optional[dict] = None
+                provider_specific_tools: List[dict] = []
                 for tool in value:
                     # check if grounding
                     try:
@@ -411,11 +412,14 @@ class VertexGeminiConfig:
                         verbose_logger.warning(
                             "Got KeyError parsing tool={}. Assuming it's a provider-specific param. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
                         )
-                        google_search_tool = tool
-                _tools = Tools(function_declarations=gtool_func_declarations)
-                if google_search_tool is not None:
-                    _tools["googleSearchRetrieval"] = google_search_tool
-                optional_params["tools"] = [_tools]
+                        if tool.get("googleSearchRetrieval", None) is not None:
+                            googleSearchRetrieval = tool["googleSearchRetrieval"]
+                _tools = Tools(
+                    function_declarations=gtool_func_declarations,
+                )
+                if googleSearchRetrieval is not None:
+                    _tools["googleSearchRetrieval"] = googleSearchRetrieval
+                optional_params["tools"] = [_tools] + provider_specific_tools
             if param == "tool_choice" and (
                 isinstance(value, str) or isinstance(value, dict)
             ):
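A self-contained sketch of what this hunk does (not litellm's actual code): grounding tools are now captured under `googleSearchRetrieval` and merged into the same `Tools` entry as the function declarations, with any remaining provider-specific tools appended to the final list. The helper name and plain-dict types below are illustrative substitutes for litellm's `Tools`/`FunctionDeclaration` TypedDicts, and the real code performs the grounding check inside the `KeyError` branch shown above.

```python
from typing import List, Optional


def map_tools_sketch(value: List[dict]) -> List[dict]:
    """Simplified illustration of the tool mapping; hypothetical helper name."""
    gtool_func_declarations: List[dict] = []
    googleSearchRetrieval: Optional[dict] = None
    provider_specific_tools: List[dict] = []

    for tool in value:
        if tool.get("googleSearchRetrieval") is not None:
            # grounding tool: captured separately instead of being lost
            googleSearchRetrieval = tool["googleSearchRetrieval"]
        elif "function" in tool:
            # OpenAI-style function tool -> Gemini function declaration fields
            gtool_func_declarations.append(
                {
                    "name": tool["function"]["name"],
                    "description": tool["function"].get("description", ""),
                    "parameters": tool["function"].get("parameters", {}),
                }
            )
        else:
            # anything else is treated as provider-specific and passed through
            provider_specific_tools.append(tool)

    _tools: dict = {"function_declarations": gtool_func_declarations}
    if googleSearchRetrieval is not None:
        _tools["googleSearchRetrieval"] = googleSearchRetrieval
    return [_tools] + provider_specific_tools


print(map_tools_sketch([{"googleSearchRetrieval": {}}]))
# -> [{'function_declarations': [], 'googleSearchRetrieval': {}}]
```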


@@ -677,12 +677,13 @@ def test_gemini_pro_vision_base64():
         pytest.fail(f"An exception occurred - {str(e)}")
-def test_gemini_pro_grounding():
+@pytest.mark.parametrize("value_in_dict", [{}, {"disable_attribution": False}])
+def test_gemini_pro_grounding(value_in_dict):
     try:
         load_vertex_ai_credentials()
         litellm.set_verbose = True
-        tools = [{"googleSearchRetrieval": {}}]
+        tools = [{"googleSearchRetrieval": value_in_dict}]
         litellm.set_verbose = True
@@ -709,6 +710,10 @@ def test_gemini_pro_grounding():
                 "googleSearchRetrieval"
                 in mock_call.call_args.kwargs["json"]["tools"][0]
             )
+            assert (
+                mock_call.call_args.kwargs["json"]["tools"][0]["googleSearchRetrieval"]
+                == value_in_dict
+            )
     except litellm.InternalServerError:
         pass
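For completeness (again outside the diff): the parametrization runs the test twice, and the new assertion checks that whatever dict the caller supplies under `googleSearchRetrieval` reaches the mocked request body unchanged.

```python
# The two cases the parametrized test exercises; in both, the same dict is
# expected back at json["tools"][0]["googleSearchRetrieval"] in the mocked request.
for value_in_dict in [{}, {"disable_attribution": False}]:
    tools = [{"googleSearchRetrieval": value_in_dict}]
    print(tools)
```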