mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
fix: Structured outputs for recursive models (#1311)
Handle recursive nature in the structured response_formats. Update test to include 1 nested model. ``` LLAMA_STACK_CONFIG=dev pytest -s -v tests/client-sdk/inference/test_text_inference.py --inference-model "openai/gpt-4o-mini" -k test_text_chat_completion_structured_output ``` --------- Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
94e2186bb8
commit
2f7683bc5f
3 changed files with 45 additions and 3 deletions
|
@ -111,7 +111,8 @@
|
||||||
"first_name": "Michael",
|
"first_name": "Michael",
|
||||||
"last_name": "Jordan",
|
"last_name": "Jordan",
|
||||||
"year_of_birth": 1963,
|
"year_of_birth": 1963,
|
||||||
"num_seasons_in_nba": 15
|
"num_seasons_in_nba": 15,
|
||||||
|
"year_for_draft": 1984
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
|
@ -126,6 +126,37 @@ class LiteLLMOpenAIMixin(
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
def _add_additional_properties_recursive(self, schema):
|
||||||
|
"""
|
||||||
|
Recursively add additionalProperties: False to all object schemas
|
||||||
|
"""
|
||||||
|
if isinstance(schema, dict):
|
||||||
|
if schema.get("type") == "object":
|
||||||
|
schema["additionalProperties"] = False
|
||||||
|
|
||||||
|
# Add required field with all property keys if properties exist
|
||||||
|
if "properties" in schema and schema["properties"]:
|
||||||
|
schema["required"] = list(schema["properties"].keys())
|
||||||
|
|
||||||
|
if "properties" in schema:
|
||||||
|
for prop_schema in schema["properties"].values():
|
||||||
|
self._add_additional_properties_recursive(prop_schema)
|
||||||
|
|
||||||
|
for key in ["anyOf", "allOf", "oneOf"]:
|
||||||
|
if key in schema:
|
||||||
|
for sub_schema in schema[key]:
|
||||||
|
self._add_additional_properties_recursive(sub_schema)
|
||||||
|
|
||||||
|
if "not" in schema:
|
||||||
|
self._add_additional_properties_recursive(schema["not"])
|
||||||
|
|
||||||
|
# Handle $defs/$ref
|
||||||
|
if "$defs" in schema:
|
||||||
|
for def_schema in schema["$defs"].values():
|
||||||
|
self._add_additional_properties_recursive(def_schema)
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||||
input_dict = {}
|
input_dict = {}
|
||||||
|
|
||||||
|
@ -140,6 +171,10 @@ class LiteLLMOpenAIMixin(
|
||||||
name = fmt["title"]
|
name = fmt["title"]
|
||||||
del fmt["title"]
|
del fmt["title"]
|
||||||
fmt["additionalProperties"] = False
|
fmt["additionalProperties"] = False
|
||||||
|
|
||||||
|
# Apply additionalProperties: False recursively to all objects
|
||||||
|
fmt = self._add_additional_properties_recursive(fmt)
|
||||||
|
|
||||||
input_dict["response_format"] = {
|
input_dict["response_format"] = {
|
||||||
"type": "json_schema",
|
"type": "json_schema",
|
||||||
"json_schema": {
|
"json_schema": {
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -342,11 +343,15 @@ def test_text_chat_completion_with_tool_choice_none(client_with_models, text_mod
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
|
def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
|
||||||
|
class NBAStats(BaseModel):
|
||||||
|
year_for_draft: int
|
||||||
|
num_seasons_in_nba: int
|
||||||
|
|
||||||
class AnswerFormat(BaseModel):
|
class AnswerFormat(BaseModel):
|
||||||
first_name: str
|
first_name: str
|
||||||
last_name: str
|
last_name: str
|
||||||
year_of_birth: int
|
year_of_birth: int
|
||||||
num_seasons_in_nba: int
|
nba_stats: NBAStats
|
||||||
|
|
||||||
tc = TestCase(test_case)
|
tc = TestCase(test_case)
|
||||||
|
|
||||||
|
@ -364,7 +369,8 @@ def test_text_chat_completion_structured_output(client_with_models, text_model_i
|
||||||
assert answer.first_name == expected["first_name"]
|
assert answer.first_name == expected["first_name"]
|
||||||
assert answer.last_name == expected["last_name"]
|
assert answer.last_name == expected["last_name"]
|
||||||
assert answer.year_of_birth == expected["year_of_birth"]
|
assert answer.year_of_birth == expected["year_of_birth"]
|
||||||
assert answer.num_seasons_in_nba == expected["num_seasons_in_nba"]
|
assert answer.nba_stats.num_seasons_in_nba == expected["num_seasons_in_nba"]
|
||||||
|
assert answer.nba_stats.year_for_draft == expected["year_for_draft"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("streaming", [True, False])
|
@pytest.mark.parametrize("streaming", [True, False])
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue