add test for structured output

This commit is contained in:
Dinesh Yeduguru 2024-10-23 20:44:49 -07:00
parent 4a073fcee5
commit 3796dbd4a5

View file

@ -171,6 +171,49 @@ async def test_completion(inference_settings):
assert last.stop_reason == StopReason.out_of_tokens assert last.stop_reason == StopReason.out_of_tokens
@pytest.mark.asyncio
async def test_completions_structured_output(inference_settings):
inference_impl = inference_settings["impl"]
params = inference_settings["common_params"]
provider = inference_impl.routing_table.get_provider_impl(params["model"])
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::tgi",
):
pytest.skip(
"Other inference providers don't support structured output in completions yet"
)
class Animals(BaseModel):
location: str
activity: str
animals_seen: conint(ge=1, le=5) # Constrained integer type
animals: List[str]
user_input = "I saw a puppy a cat and a raccoon during my bike ride in the park"
response = await inference_impl.completion(
content=f"convert to JSON: 'f{user_input}'. please use the following schema: {Animals.schema()}",
stream=False,
model=params["model"],
sampling_params=SamplingParams(
max_tokens=50,
),
response_format=JsonResponseFormat(
schema=Animals.model_json_schema(),
),
**inference_settings["common_params"],
)
assert isinstance(response, CompletionResponse)
assert isinstance(response.completion_message.content, str)
answer = Animals.parse_raw(response.completion_message.content)
assert answer.activity == "bike ride"
assert answer.animals == ["puppy", "cat", "raccoon"]
assert answer.animals_seen == 3
assert answer.location == "park"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages): async def test_chat_completion_non_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"] inference_impl = inference_settings["impl"]