fix(utils.py): only filter additional properties if gemini/vertex ai

Krrish Dholakia 2024-08-23 14:22:36 -07:00
parent 92e5cd113d
commit 3007f0344d
2 changed files with 20 additions and 8 deletions


@@ -501,10 +501,14 @@ def test_vertex_safety_settings(provider):
     assert len(optional_params) == 1


-def test_parse_additional_properties_json_schema():
+@pytest.mark.parametrize(
+    "model, provider, expectedAddProp",
+    [("gemini-1.5-pro", "vertex_ai_beta", False), ("gpt-3.5-turbo", "openai", True)],
+)
+def test_parse_additional_properties_json_schema(model, provider, expectedAddProp):
     optional_params = get_optional_params(
-        model="gemini-1.5-pro",
-        custom_llm_provider="vertex_ai_beta",
+        model=model,
+        custom_llm_provider=provider,
         response_format={
             "type": "json_schema",
             "json_schema": {
@@ -535,4 +539,9 @@ def test_parse_additional_properties_json_schema():
     )
     print(optional_params)
-    assert "additionalProperties" not in optional_params["response_schema"]
+    if provider == "vertex_ai_beta":
+        schema = optional_params["response_schema"]
+    elif provider == "openai":
+        schema = optional_params["response_format"]["json_schema"]["schema"]
+    assert ("additionalProperties" in schema) == expectedAddProp


@@ -2893,10 +2893,13 @@ def get_optional_params(
                response_format=non_default_params["response_format"]
            )
            # # clean out 'additionalProperties = False'. Causes vertexai/gemini OpenAI API Schema errors - https://github.com/langchain-ai/langchainjs/issues/5240
-            if (
-                non_default_params["response_format"].get("json_schema", {}).get("schema")
-                is not None
-            ):
+            if non_default_params["response_format"].get("json_schema", {}).get(
+                "schema"
+            ) is not None and custom_llm_provider in [
+                "gemini",
+                "vertex_ai",
+                "vertex_ai_beta",
+            ]:
                old_schema = copy.deepcopy(
                    non_default_params["response_format"]
                    .get("json_schema", {})