fix(vertex_httpx.py): fix json schema call to pass in response_mime_type=="application/json"

This commit is contained in:
Krrish Dholakia 2024-08-21 15:22:22 -07:00
parent a583b95d85
commit 664c40a4c7
2 changed files with 14 additions and 0 deletions

View file

@ -188,9 +188,11 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
elif value["type"] == "text": # type: ignore
optional_params["response_mime_type"] = "text/plain"
if "response_schema" in value: # type: ignore
optional_params["response_mime_type"] = "application/json"
optional_params["response_schema"] = value["response_schema"] # type: ignore
elif value["type"] == "json_schema": # type: ignore
if "json_schema" in value and "schema" in value["json_schema"]: # type: ignore
optional_params["response_mime_type"] = "application/json"
optional_params["response_schema"] = value["json_schema"]["schema"] # type: ignore
if param == "tools" and isinstance(value, list):
gtool_func_declarations = []
@ -400,9 +402,11 @@ class VertexGeminiConfig:
elif value["type"] == "text":
optional_params["response_mime_type"] = "text/plain"
if "response_schema" in value:
optional_params["response_mime_type"] = "application/json"
optional_params["response_schema"] = value["response_schema"]
elif value["type"] == "json_schema": # type: ignore
if "json_schema" in value and "schema" in value["json_schema"]: # type: ignore
optional_params["response_mime_type"] = "application/json"
optional_params["response_schema"] = value["json_schema"]["schema"] # type: ignore
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value

View file

@ -1558,6 +1558,16 @@ async def test_gemini_pro_json_schema_args_sent_httpx_openai_schema(
"response_schema"
in mock_call.call_args.kwargs["json"]["generationConfig"]
)
assert (
"response_mime_type"
in mock_call.call_args.kwargs["json"]["generationConfig"]
)
assert (
mock_call.call_args.kwargs["json"]["generationConfig"][
"response_mime_type"
]
== "application/json"
)
else:
assert (
"response_schema"