mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
actually test strutured output in completion
This commit is contained in:
parent
3796dbd4a5
commit
9bf1388429
3 changed files with 35 additions and 26 deletions
|
@ -82,6 +82,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||||
model=model,
|
model=model,
|
||||||
content=content,
|
content=content,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
|
response_format=response_format,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
|
|
|
@ -185,33 +185,30 @@ async def test_completions_structured_output(inference_settings):
|
||||||
"Other inference providers don't support structured output in completions yet"
|
"Other inference providers don't support structured output in completions yet"
|
||||||
)
|
)
|
||||||
|
|
||||||
class Animals(BaseModel):
|
class Output(BaseModel):
|
||||||
location: str
|
name: str
|
||||||
activity: str
|
year_born: str
|
||||||
animals_seen: conint(ge=1, le=5) # Constrained integer type
|
year_retired: str
|
||||||
animals: List[str]
|
|
||||||
|
|
||||||
user_input = "I saw a puppy a cat and a raccoon during my bike ride in the park"
|
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
|
||||||
response = await inference_impl.completion(
|
response = await inference_impl.completion(
|
||||||
content=f"convert to JSON: 'f{user_input}'. please use the following schema: {Animals.schema()}",
|
content=f"input: '{user_input}'. the schema for json: {Output.schema()}, the json is: ",
|
||||||
stream=False,
|
stream=False,
|
||||||
model=params["model"],
|
model=params["model"],
|
||||||
sampling_params=SamplingParams(
|
sampling_params=SamplingParams(
|
||||||
max_tokens=50,
|
max_tokens=50,
|
||||||
),
|
),
|
||||||
response_format=JsonResponseFormat(
|
response_format=JsonResponseFormat(
|
||||||
schema=Animals.model_json_schema(),
|
schema=Output.model_json_schema(),
|
||||||
),
|
),
|
||||||
**inference_settings["common_params"],
|
)
|
||||||
)
|
assert isinstance(response, CompletionResponse)
|
||||||
assert isinstance(response, CompletionResponse)
|
assert isinstance(response.content, str)
|
||||||
assert isinstance(response.completion_message.content, str)
|
|
||||||
|
|
||||||
answer = Animals.parse_raw(response.completion_message.content)
|
answer = Output.parse_raw(response.content)
|
||||||
assert answer.activity == "bike ride"
|
assert "Michael Jordan" in answer.name
|
||||||
assert answer.animals == ["puppy", "cat", "raccoon"]
|
assert answer.year_born == "1963"
|
||||||
assert answer.animals_seen == 3
|
assert answer.year_retired == "2003"
|
||||||
assert answer.location == "park"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
@ -64,7 +64,18 @@ def process_completion_response(
|
||||||
response: OpenAICompatCompletionResponse, formatter: ChatFormat
|
response: OpenAICompatCompletionResponse, formatter: ChatFormat
|
||||||
) -> CompletionResponse:
|
) -> CompletionResponse:
|
||||||
choice = response.choices[0]
|
choice = response.choices[0]
|
||||||
|
# drop suffix <eot_id> if present and return stop reason as end of turn
|
||||||
|
if choice.text.endswith("<|eot_id|>"):
|
||||||
|
return CompletionResponse(
|
||||||
|
stop_reason=StopReason.end_of_turn,
|
||||||
|
content=choice.text[: -len("<|eot_id|>")],
|
||||||
|
)
|
||||||
|
# drop suffix <eom_id> if present and return stop reason as end of message
|
||||||
|
if choice.text.endswith("<|eom_id|>"):
|
||||||
|
return CompletionResponse(
|
||||||
|
stop_reason=StopReason.end_of_message,
|
||||||
|
content=choice.text[: -len("<|eom_id|>")],
|
||||||
|
)
|
||||||
return CompletionResponse(
|
return CompletionResponse(
|
||||||
stop_reason=get_stop_reason(choice.finish_reason),
|
stop_reason=get_stop_reason(choice.finish_reason),
|
||||||
content=choice.text,
|
content=choice.text,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue