actually test structured output in completion

This commit is contained in:
Dinesh Yeduguru 2024-10-24 14:44:31 -07:00
parent 3796dbd4a5
commit 9bf1388429
3 changed files with 35 additions and 26 deletions

View file

@ -82,6 +82,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
model=model, model=model,
content=content, content=content,
sampling_params=sampling_params, sampling_params=sampling_params,
response_format=response_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
) )

View file

@ -185,33 +185,30 @@ async def test_completions_structured_output(inference_settings):
"Other inference providers don't support structured output in completions yet" "Other inference providers don't support structured output in completions yet"
) )
class Animals(BaseModel): class Output(BaseModel):
location: str name: str
activity: str year_born: str
animals_seen: conint(ge=1, le=5) # Constrained integer type year_retired: str
animals: List[str]
user_input = "I saw a puppy a cat and a raccoon during my bike ride in the park" user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
response = await inference_impl.completion( response = await inference_impl.completion(
content=f"convert to JSON: 'f{user_input}'. please use the following schema: {Animals.schema()}", content=f"input: '{user_input}'. the schema for json: {Output.schema()}, the json is: ",
stream=False, stream=False,
model=params["model"], model=params["model"],
sampling_params=SamplingParams( sampling_params=SamplingParams(
max_tokens=50, max_tokens=50,
), ),
response_format=JsonResponseFormat( response_format=JsonResponseFormat(
schema=Animals.model_json_schema(), schema=Output.model_json_schema(),
), ),
**inference_settings["common_params"], )
) assert isinstance(response, CompletionResponse)
assert isinstance(response, CompletionResponse) assert isinstance(response.content, str)
assert isinstance(response.completion_message.content, str)
answer = Animals.parse_raw(response.completion_message.content) answer = Output.parse_raw(response.content)
assert answer.activity == "bike ride" assert "Michael Jordan" in answer.name
assert answer.animals == ["puppy", "cat", "raccoon"] assert answer.year_born == "1963"
assert answer.animals_seen == 3 assert answer.year_retired == "2003"
assert answer.location == "park"
@pytest.mark.asyncio @pytest.mark.asyncio

View file

@ -64,7 +64,18 @@ def process_completion_response(
response: OpenAICompatCompletionResponse, formatter: ChatFormat response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> CompletionResponse: ) -> CompletionResponse:
choice = response.choices[0] choice = response.choices[0]
# drop suffix <eot_id> if present and return stop reason as end of turn
if choice.text.endswith("<|eot_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_turn,
content=choice.text[: -len("<|eot_id|>")],
)
# drop suffix <eom_id> if present and return stop reason as end of message
if choice.text.endswith("<|eom_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_message,
content=choice.text[: -len("<|eom_id|>")],
)
return CompletionResponse( return CompletionResponse(
stop_reason=get_stop_reason(choice.finish_reason), stop_reason=get_stop_reason(choice.finish_reason),
content=choice.text, content=choice.text,