fix(vertex_httpx.py): support passing response_schema to gemini

Krrish Dholakia 2024-06-29 11:33:19 -07:00
parent e1f84b1bd9
commit e73e9e12bc
3 changed files with 58 additions and 2 deletions


@@ -356,6 +356,7 @@ class VertexGeminiConfig:
         model: str,
         non_default_params: dict,
         optional_params: dict,
+        drop_params: bool,
     ):
         for param, value in non_default_params.items():
             if param == "temperature":
@@ -375,8 +376,13 @@ class VertexGeminiConfig:
                 optional_params["stop_sequences"] = value
             if param == "max_tokens":
                 optional_params["max_output_tokens"] = value
-            if param == "response_format" and value["type"] == "json_object":  # type: ignore
-                optional_params["response_mime_type"] = "application/json"
+            if param == "response_format" and isinstance(value, dict):  # type: ignore
+                if value["type"] == "json_object":
+                    optional_params["response_mime_type"] = "application/json"
+                elif value["type"] == "text":
+                    optional_params["response_mime_type"] = "text/plain"
+                if "response_schema" in value:
+                    optional_params["response_schema"] = value["response_schema"]
             if param == "frequency_penalty":
                 optional_params["frequency_penalty"] = value
             if param == "presence_penalty":


@@ -880,6 +880,51 @@ Using this JSON schema:
         mock_call.assert_called_once()
 
 
+@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+@pytest.mark.asyncio
+async def test_gemini_pro_json_schema_httpx(provider):
+    load_vertex_ai_credentials()
+    litellm.set_verbose = True
+    messages = [{"role": "user", "content": "List 5 cookie recipes"}]
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    response_schema = {
+        "type": "array",
+        "items": {
+            "type": "object",
+            "properties": {
+                "recipe_name": {
+                    "type": "string",
+                },
+            },
+            "required": ["recipe_name"],
+        },
+    }
+
+    client = HTTPHandler()
+
+    with patch.object(client, "post", new=MagicMock()) as mock_call:
+        try:
+            response = completion(
+                model="vertex_ai_beta/gemini-1.5-pro-001",
+                messages=messages,
+                response_format={
+                    "type": "json_object",
+                    "response_schema": response_schema,
+                },
+                client=client,
+            )
+        except Exception as e:
+            pass
+
+        mock_call.assert_called_once()
+        print(mock_call.call_args.kwargs)
+        print(mock_call.call_args.kwargs["json"]["generationConfig"])
+        assert (
+            "response_schema" in mock_call.call_args.kwargs["json"]["generationConfig"]
+        )
+
+
 @pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
 @pytest.mark.asyncio
 async def test_gemini_pro_httpx_custom_api_base(provider):
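
Based on the assertion in this test, the generationConfig section of the outgoing request should end up looking roughly like the following (a sketch of the expected shape derived from the mapping above, not a captured payload):

# Assumed shape of generationConfig in the request JSON:
generation_config = {
    "response_mime_type": "application/json",
    "response_schema": response_schema,
}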


@@ -2756,6 +2756,11 @@ def get_optional_params(
             non_default_params=non_default_params,
             optional_params=optional_params,
             model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif (
         custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models