diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 8a1aadd33..b98a86c8f 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -171,6 +171,48 @@ async def test_completion(inference_settings):
         assert last.stop_reason == StopReason.out_of_tokens
 
 
+@pytest.mark.asyncio
+async def test_completions_structured_output(inference_settings):
+    inference_impl = inference_settings["impl"]
+    params = inference_settings["common_params"]
+
+    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    if provider.__provider_spec__.provider_type not in (
+        "meta-reference",
+        "remote::tgi",
+    ):
+        pytest.skip(
+            "Other inference providers don't support structured output in completions yet"
+        )
+
+    class Animals(BaseModel):
+        location: str
+        activity: str
+        animals_seen: conint(ge=1, le=5)  # Constrained integer type
+        animals: List[str]
+
+    user_input = "I saw a puppy a cat and a raccoon during my bike ride in the park"
+    response = await inference_impl.completion(
+        content=f"convert to JSON: '{user_input}'. please use the following schema: {Animals.schema()}",
+        stream=False,
+        model=params["model"],
+        sampling_params=SamplingParams(
+            max_tokens=50,
+        ),
+        response_format=JsonResponseFormat(
+            schema=Animals.model_json_schema(),
+        ),
+    )
+    assert isinstance(response, CompletionResponse)
+    assert isinstance(response.content, str)
+
+    answer = Animals.parse_raw(response.content)
+    assert answer.activity == "bike ride"
+    assert answer.animals == ["puppy", "cat", "raccoon"]
+    assert answer.animals_seen == 3
+    assert answer.location == "park"
+
+
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]