diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index 1901e178c..023744b66 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -10,7 +10,7 @@ import TabItem from '@theme/TabItem'; ## 🆕 `vertex_ai_beta/` route -New `vertex_ai_beta/` route. Adds support for system messages, tool_choice params, etc. by moving to httpx client (instead of vertex sdk). +New `vertex_ai_beta/` route. Adds support for system messages, tool_choice params, etc. by moving to httpx client (instead of vertex sdk). This implementation uses [VertexAI's REST API](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#syntax). ```python from litellm import completion @@ -377,6 +377,54 @@ curl http://0.0.0.0:4000/v1/chat/completions \ +#### **Moving from Vertex AI SDK to LiteLLM (GROUNDING)** + + +If this was your initial VertexAI Grounding code, + +```python +import vertexai + +vertexai.init(project=project_id, location="us-central1") + +model = GenerativeModel("gemini-1.5-flash-001") + +# Use Google Search for grounding +tool = Tool.from_google_search_retrieval(grounding.GoogleSearchRetrieval(disable_attribution=False)) + +prompt = "When is the next total solar eclipse in US?" 
+response = model.generate_content( + prompt, + tools=[tool], + generation_config=GenerationConfig( + temperature=0.0, + ), +) + +print(response) +``` + +then, this is what it looks like now + +```python +from litellm import completion + + +# !gcloud auth application-default login - run this to add vertex credentials to your env + +tools = [{"googleSearchRetrieval": {"disable_attribution": False}}] # 👈 ADD GOOGLE SEARCH + +resp = litellm.completion( + model="vertex_ai_beta/gemini-1.0-pro-001", + messages=[{"role": "user", "content": "Who won the world cup?"}], + tools=tools, + vertex_project="project-id" + ) + +print(resp) +``` + + ## Pre-requisites * `pip install google-cloud-aiplatform` (pre-installed on proxy docker image) * Authentication: diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index d73c318e6..9f72a9296 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -396,7 +396,8 @@ class VertexGeminiConfig: optional_params["presence_penalty"] = value if param == "tools" and isinstance(value, list): gtool_func_declarations = [] - google_search_tool: Optional[dict] = None + googleSearchRetrieval: Optional[dict] = None + provider_specific_tools: List[dict] = [] for tool in value: # check if grounding try: @@ -411,11 +412,14 @@ class VertexGeminiConfig: verbose_logger.warning( "Got KeyError parsing tool={}. Assuming it's a provider-specific param. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request." 
) - google_search_tool = tool - _tools = Tools(function_declarations=gtool_func_declarations) - if google_search_tool is not None: - _tools["googleSearchRetrieval"] = google_search_tool - optional_params["tools"] = [_tools] + if tool.get("googleSearchRetrieval", None) is not None: + googleSearchRetrieval = tool["googleSearchRetrieval"] + _tools = Tools( + function_declarations=gtool_func_declarations, + ) + if googleSearchRetrieval is not None: + _tools["googleSearchRetrieval"] = googleSearchRetrieval + optional_params["tools"] = [_tools] + provider_specific_tools if param == "tool_choice" and ( isinstance(value, str) or isinstance(value, dict) ): diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 211f093d9..d95f152fd 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -677,12 +677,13 @@ def test_gemini_pro_vision_base64(): pytest.fail(f"An exception occurred - {str(e)}") -def test_gemini_pro_grounding(): +@pytest.mark.parametrize("value_in_dict", [{}, {"disable_attribution": False}]) # +def test_gemini_pro_grounding(value_in_dict): try: load_vertex_ai_credentials() litellm.set_verbose = True - tools = [{"googleSearchRetrieval": {}}] + tools = [{"googleSearchRetrieval": value_in_dict}] litellm.set_verbose = True @@ -709,6 +710,10 @@ def test_gemini_pro_grounding(): "googleSearchRetrieval" in mock_call.call_args.kwargs["json"]["tools"][0] ) + assert ( + mock_call.call_args.kwargs["json"]["tools"][0]["googleSearchRetrieval"] + == value_in_dict + ) except litellm.InternalServerError: pass