verify recursive nature in structured outputs

This commit is contained in:
Hardik Shah 2025-02-27 17:21:32 -08:00
parent 94e2186bb8
commit 17ef47e909
3 changed files with 48 additions and 3 deletions

View file

@ -111,7 +111,8 @@
"first_name": "Michael", "first_name": "Michael",
"last_name": "Jordan", "last_name": "Jordan",
"year_of_birth": 1963, "year_of_birth": 1963,
"num_seasons_in_nba": 15 "num_seasons_in_nba": 15,
"year_for_draft": 1984
} }
} }
}, },

View file

@ -126,6 +126,40 @@ class LiteLLMOpenAIMixin(
): ):
yield chunk yield chunk
def _add_additional_properties_recursive(self, schema):
"""
Recursively add additionalProperties: False to all object schemas
"""
if isinstance(schema, dict):
# If this is an object schema
if schema.get("type") == "object":
schema["additionalProperties"] = False
# Add required field with all property keys if properties exist
if "properties" in schema and schema["properties"]:
schema["required"] = list(schema["properties"].keys())
# Handle properties within objects
if "properties" in schema:
for prop_schema in schema["properties"].values():
self._add_additional_properties_recursive(prop_schema)
# Handle anyOf/allOf/oneOf/not
for key in ["anyOf", "allOf", "oneOf"]:
if key in schema:
for sub_schema in schema[key]:
self._add_additional_properties_recursive(sub_schema)
if "not" in schema:
self._add_additional_properties_recursive(schema["not"])
# Handle $defs/$ref
if "$defs" in schema:
for def_schema in schema["$defs"].values():
self._add_additional_properties_recursive(def_schema)
return schema
async def _get_params(self, request: ChatCompletionRequest) -> dict: async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {} input_dict = {}
@ -140,6 +174,10 @@ class LiteLLMOpenAIMixin(
name = fmt["title"] name = fmt["title"]
del fmt["title"] del fmt["title"]
fmt["additionalProperties"] = False fmt["additionalProperties"] = False
# Apply additionalProperties: False recursively to all objects
fmt = self._add_additional_properties_recursive(fmt)
input_dict["response_format"] = { input_dict["response_format"] = {
"type": "json_schema", "type": "json_schema",
"json_schema": { "json_schema": {

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
@ -342,11 +343,15 @@ def test_text_chat_completion_with_tool_choice_none(client_with_models, text_mod
], ],
) )
def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case): def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
class NBAStats(BaseModel):
year_for_draft: int
num_seasons_in_nba: int
class AnswerFormat(BaseModel): class AnswerFormat(BaseModel):
first_name: str first_name: str
last_name: str last_name: str
year_of_birth: int year_of_birth: int
num_seasons_in_nba: int nba_stats: NBAStats
tc = TestCase(test_case) tc = TestCase(test_case)
@ -364,7 +369,8 @@ def test_text_chat_completion_structured_output(client_with_models, text_model_i
assert answer.first_name == expected["first_name"] assert answer.first_name == expected["first_name"]
assert answer.last_name == expected["last_name"] assert answer.last_name == expected["last_name"]
assert answer.year_of_birth == expected["year_of_birth"] assert answer.year_of_birth == expected["year_of_birth"]
assert answer.num_seasons_in_nba == expected["num_seasons_in_nba"] assert answer.nba_stats.num_seasons_in_nba == expected["num_seasons_in_nba"]
assert answer.nba_stats.year_for_draft == expected["year_for_draft"]
@pytest.mark.parametrize("streaming", [True, False]) @pytest.mark.parametrize("streaming", [True, False])