From 664c40a4c7ecd66ba4c42248ff02f5e3bb103e9c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 21 Aug 2024 15:22:22 -0700 Subject: [PATCH] fix(vertex_httpx.py): fix json schema call to pass in response_mime_type=="application/json" --- litellm/llms/vertex_httpx.py | 4 ++++ litellm/tests/test_amazing_vertex_completion.py | 10 ++++++++++ 2 files changed, 14 insertions(+) diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index 1b0ef52bcd..8fc67c0c2f 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -188,9 +188,11 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty elif value["type"] == "text": # type: ignore optional_params["response_mime_type"] = "text/plain" if "response_schema" in value: # type: ignore + optional_params["response_mime_type"] = "application/json" optional_params["response_schema"] = value["response_schema"] # type: ignore elif value["type"] == "json_schema": # type: ignore if "json_schema" in value and "schema" in value["json_schema"]: # type: ignore + optional_params["response_mime_type"] = "application/json" optional_params["response_schema"] = value["json_schema"]["schema"] # type: ignore if param == "tools" and isinstance(value, list): gtool_func_declarations = [] @@ -400,9 +402,11 @@ class VertexGeminiConfig: elif value["type"] == "text": optional_params["response_mime_type"] = "text/plain" if "response_schema" in value: + optional_params["response_mime_type"] = "application/json" optional_params["response_schema"] = value["response_schema"] elif value["type"] == "json_schema": # type: ignore if "json_schema" in value and "schema" in value["json_schema"]: # type: ignore + optional_params["response_mime_type"] = "application/json" optional_params["response_schema"] = value["json_schema"]["schema"] # type: ignore if param == "frequency_penalty": optional_params["frequency_penalty"] = value diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index fca4f1ee55..5e61e4f525 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -1558,6 +1558,16 @@ async def test_gemini_pro_json_schema_args_sent_httpx_openai_schema( "response_schema" in mock_call.call_args.kwargs["json"]["generationConfig"] ) + assert ( + "response_mime_type" + in mock_call.call_args.kwargs["json"]["generationConfig"] + ) + assert ( + mock_call.call_args.kwargs["json"]["generationConfig"][ + "response_mime_type" + ] + == "application/json" + ) else: assert ( "response_schema"