forked from phoenix/litellm-mirror
fix(vertex_httpx.py): google search grounding fix
This commit is contained in:
parent
385da04d72
commit
82ca7af6df
3 changed files with 66 additions and 9 deletions
|
@ -10,7 +10,7 @@ import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
## 🆕 `vertex_ai_beta/` route
|
## 🆕 `vertex_ai_beta/` route
|
||||||
|
|
||||||
New `vertex_ai_beta/` route. Adds support for system messages, tool_choice params, etc. by moving to httpx client (instead of vertex sdk).
|
New `vertex_ai_beta/` route. Adds support for system messages, tool_choice params, etc. by moving to httpx client (instead of vertex sdk). This implementation uses [VertexAI's REST API](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#syntax).
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from litellm import completion
|
from litellm import completion
|
||||||
|
@ -377,6 +377,54 @@ curl http://0.0.0.0:4000/v1/chat/completions \
|
||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
#### **Moving from Vertex AI SDK to LiteLLM (GROUNDING)**
|
||||||
|
|
||||||
|
|
||||||
|
If this was your initial VertexAI Grounding code:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import vertexai
|
||||||
|
|
||||||
|
vertexai.init(project=project_id, location="us-central1")
|
||||||
|
|
||||||
|
model = GenerativeModel("gemini-1.5-flash-001")
|
||||||
|
|
||||||
|
# Use Google Search for grounding
|
||||||
|
tool = Tool.from_google_search_retrieval(grounding.GoogleSearchRetrieval(disable_attribution=False))
|
||||||
|
|
||||||
|
prompt = "When is the next total solar eclipse in US?"
|
||||||
|
response = model.generate_content(
|
||||||
|
prompt,
|
||||||
|
tools=[tool],
|
||||||
|
generation_config=GenerationConfig(
|
||||||
|
temperature=0.0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, this is what it looks like now:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
|
||||||
|
|
||||||
|
# !gcloud auth application-default login - run this to add vertex credentials to your env
|
||||||
|
|
||||||
|
tools = [{"googleSearchRetrieval": {"disable_attribution": False}}] # 👈 ADD GOOGLE SEARCH
|
||||||
|
|
||||||
|
resp = litellm.completion(
|
||||||
|
model="vertex_ai_beta/gemini-1.0-pro-001",
|
||||||
|
messages=[{"role": "user", "content": "Who won the world cup?"}],
|
||||||
|
tools=tools,
|
||||||
|
vertex_project="project-id"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(resp)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## Pre-requisites
|
## Pre-requisites
|
||||||
* `pip install google-cloud-aiplatform` (pre-installed on proxy docker image)
|
* `pip install google-cloud-aiplatform` (pre-installed on proxy docker image)
|
||||||
* Authentication:
|
* Authentication:
|
||||||
|
|
|
@ -396,7 +396,8 @@ class VertexGeminiConfig:
|
||||||
optional_params["presence_penalty"] = value
|
optional_params["presence_penalty"] = value
|
||||||
if param == "tools" and isinstance(value, list):
|
if param == "tools" and isinstance(value, list):
|
||||||
gtool_func_declarations = []
|
gtool_func_declarations = []
|
||||||
google_search_tool: Optional[dict] = None
|
googleSearchRetrieval: Optional[dict] = None
|
||||||
|
provider_specific_tools: List[dict] = []
|
||||||
for tool in value:
|
for tool in value:
|
||||||
# check if grounding
|
# check if grounding
|
||||||
try:
|
try:
|
||||||
|
@ -411,11 +412,14 @@ class VertexGeminiConfig:
|
||||||
verbose_logger.warning(
|
verbose_logger.warning(
|
||||||
"Got KeyError parsing tool={}. Assuming it's a provider-specific param. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
|
"Got KeyError parsing tool={}. Assuming it's a provider-specific param. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
|
||||||
)
|
)
|
||||||
google_search_tool = tool
|
if tool.get("googleSearchRetrieval", None) is not None:
|
||||||
_tools = Tools(function_declarations=gtool_func_declarations)
|
googleSearchRetrieval = tool["googleSearchRetrieval"]
|
||||||
if google_search_tool is not None:
|
_tools = Tools(
|
||||||
_tools["googleSearchRetrieval"] = google_search_tool
|
function_declarations=gtool_func_declarations,
|
||||||
optional_params["tools"] = [_tools]
|
)
|
||||||
|
if googleSearchRetrieval is not None:
|
||||||
|
_tools["googleSearchRetrieval"] = googleSearchRetrieval
|
||||||
|
optional_params["tools"] = [_tools] + provider_specific_tools
|
||||||
if param == "tool_choice" and (
|
if param == "tool_choice" and (
|
||||||
isinstance(value, str) or isinstance(value, dict)
|
isinstance(value, str) or isinstance(value, dict)
|
||||||
):
|
):
|
||||||
|
|
|
@ -677,12 +677,13 @@ def test_gemini_pro_vision_base64():
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
pytest.fail(f"An exception occurred - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def test_gemini_pro_grounding():
|
@pytest.mark.parametrize("value_in_dict", [{}, {"disable_attribution": False}]) #
|
||||||
|
def test_gemini_pro_grounding(value_in_dict):
|
||||||
try:
|
try:
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
|
||||||
tools = [{"googleSearchRetrieval": {}}]
|
tools = [{"googleSearchRetrieval": value_in_dict}]
|
||||||
|
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
@ -709,6 +710,10 @@ def test_gemini_pro_grounding():
|
||||||
"googleSearchRetrieval"
|
"googleSearchRetrieval"
|
||||||
in mock_call.call_args.kwargs["json"]["tools"][0]
|
in mock_call.call_args.kwargs["json"]["tools"][0]
|
||||||
)
|
)
|
||||||
|
assert (
|
||||||
|
mock_call.call_args.kwargs["json"]["tools"][0]["googleSearchRetrieval"]
|
||||||
|
== value_in_dict
|
||||||
|
)
|
||||||
|
|
||||||
except litellm.InternalServerError:
|
except litellm.InternalServerError:
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue