diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index a6dcd3daa..940016ecb 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -356,6 +356,7 @@ class VertexGeminiConfig:
         model: str,
         non_default_params: dict,
         optional_params: dict,
+        drop_params: bool,
     ):
         for param, value in non_default_params.items():
             if param == "temperature":
@@ -375,8 +376,13 @@ class VertexGeminiConfig:
                 optional_params["stop_sequences"] = value
             if param == "max_tokens":
                 optional_params["max_output_tokens"] = value
-            if param == "response_format" and value["type"] == "json_object":  # type: ignore
-                optional_params["response_mime_type"] = "application/json"
+            if param == "response_format" and isinstance(value, dict):  # type: ignore
+                if value["type"] == "json_object":
+                    optional_params["response_mime_type"] = "application/json"
+                elif value["type"] == "text":
+                    optional_params["response_mime_type"] = "text/plain"
+                if "response_schema" in value:
+                    optional_params["response_schema"] = value["response_schema"]
             if param == "frequency_penalty":
                 optional_params["frequency_penalty"] = value
             if param == "presence_penalty":
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 6de3e11b8..e6f2634f4 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -880,6 +880,51 @@ Using this JSON schema:
     mock_call.assert_called_once()
 
 
+@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+@pytest.mark.asyncio
+async def test_gemini_pro_json_schema_httpx(provider):
+    load_vertex_ai_credentials()
+    litellm.set_verbose = True
+    messages = [{"role": "user", "content": "List 5 cookie recipes"}]
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    response_schema = {
+        "type": "array",
+        "items": {
+            "type": "object",
+            "properties": {
+                "recipe_name": {
+                    "type": "string",
+                },
+            },
+            "required": ["recipe_name"],
+        },
+    }
+
+    client = HTTPHandler()
+
+    with patch.object(client, "post", new=MagicMock()) as mock_call:
+        try:
+            response = completion(
+                model="vertex_ai_beta/gemini-1.5-pro-001",
+                messages=messages,
+                response_format={
+                    "type": "json_object",
+                    "response_schema": response_schema,
+                },
+                client=client,
+            )
+        except Exception as e:
+            pass
+
+        mock_call.assert_called_once()
+        print(mock_call.call_args.kwargs)
+        print(mock_call.call_args.kwargs["json"]["generationConfig"])
+        assert (
+            "response_schema" in mock_call.call_args.kwargs["json"]["generationConfig"]
+        )
+
+
 @pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
 @pytest.mark.asyncio
 async def test_gemini_pro_httpx_custom_api_base(provider):
diff --git a/litellm/utils.py b/litellm/utils.py
index fef4976ca..dc2bcb25a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2756,6 +2756,11 @@ def get_optional_params(
             non_default_params=non_default_params,
             optional_params=optional_params,
             model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif (
         custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
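
For reference, here is a minimal caller-side sketch of what this diff enables: passing an OpenAI-style `response_format` that carries a `response_schema`, which the new `map_openai_params` branch forwards into Gemini's `generationConfig`. The model name, prompt, and schema are lifted from the new test above; treat this as an illustrative usage sketch, not part of the patch itself.

```python
import litellm

# Illustrative schema, mirroring the one used in test_gemini_pro_json_schema_httpx.
response_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {"recipe_name": {"type": "string"}},
        "required": ["recipe_name"],
    },
}

response = litellm.completion(
    model="vertex_ai_beta/gemini-1.5-pro-001",
    messages=[{"role": "user", "content": "List 5 cookie recipes"}],
    response_format={
        # "json_object" maps to response_mime_type="application/json";
        # "text" would map to response_mime_type="text/plain".
        "type": "json_object",
        # Forwarded to generationConfig.response_schema by the new code path.
        "response_schema": response_schema,
    },
)
print(response.choices[0].message.content)
```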