forked from phoenix/litellm-mirror
fix(vertex_httpx.py): google search grounding fix
This commit is contained in:
parent
385da04d72
commit
82ca7af6df
3 changed files with 66 additions and 9 deletions
|
@ -10,7 +10,7 @@ import TabItem from '@theme/TabItem';
|
|||
|
||||
## 🆕 `vertex_ai_beta/` route
|
||||
|
||||
New `vertex_ai_beta/` route. Adds support for system messages, tool_choice params, etc. by moving to httpx client (instead of vertex sdk).
|
||||
New `vertex_ai_beta/` route. Adds support for system messages, tool_choice params, etc. by moving to httpx client (instead of vertex sdk). This implementation uses [VertexAI's REST API](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#syntax).
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
|
@ -377,6 +377,54 @@ curl http://0.0.0.0:4000/v1/chat/completions \
|
|||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
#### **Moving from Vertex AI SDK to LiteLLM (GROUNDING)**
|
||||
|
||||
|
||||
If this was your initial VertexAI Grounding code,
|
||||
|
||||
```python
|
||||
import vertexai
|
||||
|
||||
vertexai.init(project=project_id, location="us-central1")
|
||||
|
||||
model = GenerativeModel("gemini-1.5-flash-001")
|
||||
|
||||
# Use Google Search for grounding
|
||||
tool = Tool.from_google_search_retrieval(grounding.GoogleSearchRetrieval(disable_attribution=False))
|
||||
|
||||
prompt = "When is the next total solar eclipse in US?"
|
||||
response = model.generate_content(
|
||||
prompt,
|
||||
tools=[tool],
|
||||
generation_config=GenerationConfig(
|
||||
temperature=0.0,
|
||||
),
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
then, this is what it looks like now
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
|
||||
|
||||
# !gcloud auth application-default login - run this to add vertex credentials to your env
|
||||
|
||||
tools = [{"googleSearchRetrieval": {"disable_attribution": False}}] # 👈 ADD GOOGLE SEARCH
|
||||
|
||||
resp = litellm.completion(
|
||||
model="vertex_ai_beta/gemini-1.0-pro-001",
|
||||
messages=[{"role": "user", "content": "Who won the world cup?"}],
|
||||
tools=tools,
|
||||
vertex_project="project-id"
|
||||
)
|
||||
|
||||
print(resp)
|
||||
```
|
||||
|
||||
|
||||
## Pre-requisites
|
||||
* `pip install google-cloud-aiplatform` (pre-installed on proxy docker image)
|
||||
* Authentication:
|
||||
|
|
|
@ -396,7 +396,8 @@ class VertexGeminiConfig:
|
|||
optional_params["presence_penalty"] = value
|
||||
if param == "tools" and isinstance(value, list):
|
||||
gtool_func_declarations = []
|
||||
google_search_tool: Optional[dict] = None
|
||||
googleSearchRetrieval: Optional[dict] = None
|
||||
provider_specific_tools: List[dict] = []
|
||||
for tool in value:
|
||||
# check if grounding
|
||||
try:
|
||||
|
@ -411,11 +412,14 @@ class VertexGeminiConfig:
|
|||
verbose_logger.warning(
|
||||
"Got KeyError parsing tool={}. Assuming it's a provider-specific param. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
|
||||
)
|
||||
google_search_tool = tool
|
||||
_tools = Tools(function_declarations=gtool_func_declarations)
|
||||
if google_search_tool is not None:
|
||||
_tools["googleSearchRetrieval"] = google_search_tool
|
||||
optional_params["tools"] = [_tools]
|
||||
if tool.get("googleSearchRetrieval", None) is not None:
|
||||
googleSearchRetrieval = tool["googleSearchRetrieval"]
|
||||
_tools = Tools(
|
||||
function_declarations=gtool_func_declarations,
|
||||
)
|
||||
if googleSearchRetrieval is not None:
|
||||
_tools["googleSearchRetrieval"] = googleSearchRetrieval
|
||||
optional_params["tools"] = [_tools] + provider_specific_tools
|
||||
if param == "tool_choice" and (
|
||||
isinstance(value, str) or isinstance(value, dict)
|
||||
):
|
||||
|
|
|
@ -677,12 +677,13 @@ def test_gemini_pro_vision_base64():
|
|||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
def test_gemini_pro_grounding():
|
||||
@pytest.mark.parametrize("value_in_dict", [{}, {"disable_attribution": False}]) #
|
||||
def test_gemini_pro_grounding(value_in_dict):
|
||||
try:
|
||||
load_vertex_ai_credentials()
|
||||
litellm.set_verbose = True
|
||||
|
||||
tools = [{"googleSearchRetrieval": {}}]
|
||||
tools = [{"googleSearchRetrieval": value_in_dict}]
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
|
@ -709,6 +710,10 @@ def test_gemini_pro_grounding():
|
|||
"googleSearchRetrieval"
|
||||
in mock_call.call_args.kwargs["json"]["tools"][0]
|
||||
)
|
||||
assert (
|
||||
mock_call.call_args.kwargs["json"]["tools"][0]["googleSearchRetrieval"]
|
||||
== value_in_dict
|
||||
)
|
||||
|
||||
except litellm.InternalServerError:
|
||||
pass
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue