From 2f7683bc5fc33192fe34533d47d47328ff522fee Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Thu, 27 Feb 2025 17:31:53 -0800 Subject: [PATCH] fix: Structured outputs for recursive models (#1311) Handle recursive nature in the structured response_formats. Update test to include 1 nested model. ``` LLAMA_STACK_CONFIG=dev pytest -s -v tests/client-sdk/inference/test_text_inference.py --inference-model "openai/gpt-4o-mini" -k test_text_chat_completion_structured_output ``` --------- Co-authored-by: Ashwin Bharambe --- .../test_cases/inference/chat_completion.json | 3 +- .../utils/inference/litellm_openai_mixin.py | 35 +++++++++++++++++++ .../inference/test_text_inference.py | 10 ++++-- 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/llama_stack/providers/tests/test_cases/inference/chat_completion.json b/llama_stack/providers/tests/test_cases/inference/chat_completion.json index 50f6b1c15..dcc767e4e 100644 --- a/llama_stack/providers/tests/test_cases/inference/chat_completion.json +++ b/llama_stack/providers/tests/test_cases/inference/chat_completion.json @@ -111,7 +111,8 @@ "first_name": "Michael", "last_name": "Jordan", "year_of_birth": 1963, - "num_seasons_in_nba": 15 + "num_seasons_in_nba": 15, + "year_for_draft": 1984 } } }, diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index ecb6961da..ddf7f193f 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -126,6 +126,37 @@ class LiteLLMOpenAIMixin( ): yield chunk + def _add_additional_properties_recursive(self, schema): + """ + Recursively add additionalProperties: False to all object schemas + """ + if isinstance(schema, dict): + if schema.get("type") == "object": + schema["additionalProperties"] = False + + # Add required field with all property keys if properties exist + if "properties" in schema and schema["properties"]: + schema["required"] = list(schema["properties"].keys()) + + if "properties" in schema: + for prop_schema in schema["properties"].values(): + self._add_additional_properties_recursive(prop_schema) + + for key in ["anyOf", "allOf", "oneOf"]: + if key in schema: + for sub_schema in schema[key]: + self._add_additional_properties_recursive(sub_schema) + + if "not" in schema: + self._add_additional_properties_recursive(schema["not"]) + + # Handle $defs/$ref + if "$defs" in schema: + for def_schema in schema["$defs"].values(): + self._add_additional_properties_recursive(def_schema) + + return schema + async def _get_params(self, request: ChatCompletionRequest) -> dict: input_dict = {} @@ -140,6 +171,10 @@ class LiteLLMOpenAIMixin( name = fmt["title"] del fmt["title"] fmt["additionalProperties"] = False + + # Apply additionalProperties: False recursively to all objects + fmt = self._add_additional_properties_recursive(fmt) + input_dict["response_format"] = { "type": "json_schema", "json_schema": { diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py index 577d995ad..7850d2d57 100644 --- a/tests/client-sdk/inference/test_text_inference.py +++ b/tests/client-sdk/inference/test_text_inference.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + import pytest from pydantic import BaseModel @@ -342,11 +343,15 @@ def test_text_chat_completion_with_tool_choice_none(client_with_models, text_mod ], ) def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case): + class NBAStats(BaseModel): + year_for_draft: int + num_seasons_in_nba: int + class AnswerFormat(BaseModel): first_name: str last_name: str year_of_birth: int - num_seasons_in_nba: int + nba_stats: NBAStats tc = TestCase(test_case) @@ -364,7 +369,8 @@ def test_text_chat_completion_structured_output(client_with_models, text_model_i assert answer.first_name == expected["first_name"] assert answer.last_name == expected["last_name"] assert answer.year_of_birth == expected["year_of_birth"] - assert answer.num_seasons_in_nba == expected["num_seasons_in_nba"] + assert answer.nba_stats.num_seasons_in_nba == expected["num_seasons_in_nba"] + assert answer.nba_stats.year_for_draft == expected["year_for_draft"] @pytest.mark.parametrize("streaming", [True, False])