From 9bf13884296acac4e31b9008152e0ca3ad218917 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Thu, 24 Oct 2024 14:44:31 -0700
Subject: [PATCH] actually test structured output in completion

---
 .../providers/adapters/inference/tgi/tgi.py  |  1 +
 .../tests/inference/test_inference.py        | 47 +++++++++----------
 .../utils/inference/openai_compat.py         | 13 ++++-
 3 files changed, 35 insertions(+), 26 deletions(-)

diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index 25a6759b3..a7fa6ba00 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -82,6 +82,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             model=model,
             content=content,
             sampling_params=sampling_params,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index b98a86c8f..77e4e0fc3 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -185,33 +185,30 @@ async def test_completions_structured_output(inference_settings):
         "Other inference providers don't support structured output in completions yet"
     )
 
-    class Animals(BaseModel):
-        location: str
-        activity: str
-        animals_seen: conint(ge=1, le=5)  # Constrained integer type
-        animals: List[str]
+    class Output(BaseModel):
+        name: str
+        year_born: str
+        year_retired: str
 
-    user_input = "I saw a puppy a cat and a raccoon during my bike ride in the park"
-    response = await inference_impl.completion(
-        content=f"convert to JSON: 'f{user_input}'. please use the following schema: {Animals.schema()}",
-        stream=False,
-        model=params["model"],
-        sampling_params=SamplingParams(
-            max_tokens=50,
-        ),
-        response_format=JsonResponseFormat(
-            schema=Animals.model_json_schema(),
-        ),
-        **inference_settings["common_params"],
-    )
-    assert isinstance(response, CompletionResponse)
-    assert isinstance(response.completion_message.content, str)
+    user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
+    response = await inference_impl.completion(
+        content=f"input: '{user_input}'. the schema for json: {Output.schema()}, the json is: ",
+        stream=False,
+        model=params["model"],
+        sampling_params=SamplingParams(
+            max_tokens=50,
+        ),
+        response_format=JsonResponseFormat(
+            schema=Output.model_json_schema(),
+        ),
+    )
+    assert isinstance(response, CompletionResponse)
+    assert isinstance(response.content, str)
 
-    answer = Animals.parse_raw(response.completion_message.content)
-    assert answer.activity == "bike ride"
-    assert answer.animals == ["puppy", "cat", "raccoon"]
-    assert answer.animals_seen == 3
-    assert answer.location == "park"
+    answer = Output.parse_raw(response.content)
+    assert "Michael Jordan" in answer.name
+    assert answer.year_born == "1963"
+    assert answer.year_retired == "2003"
 
 
 @pytest.mark.asyncio
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 5a5ddbb50..086227c73 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -64,7 +64,18 @@ def process_completion_response(
     response: OpenAICompatCompletionResponse, formatter: ChatFormat
 ) -> CompletionResponse:
     choice = response.choices[0]
-
+    # drop suffix if present and return stop reason as end of turn
+    if choice.text.endswith("<|eot_id|>"):
+        return CompletionResponse(
+            stop_reason=StopReason.end_of_turn,
+            content=choice.text[: -len("<|eot_id|>")],
+        )
+    # drop suffix if present and return stop reason as end of message
+    if choice.text.endswith("<|eom_id|>"):
+        return CompletionResponse(
+            stop_reason=StopReason.end_of_message,
+            content=choice.text[: -len("<|eom_id|>")],
+        )
     return CompletionResponse(
         stop_reason=get_stop_reason(choice.finish_reason),
         content=choice.text,
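
Note on the rewritten test: it leans on pydantic for both halves of the round
trip. Output.model_json_schema() builds the schema handed to
JsonResponseFormat, and Output.parse_raw() validates the raw completion text.
That round trip can be sanity-checked without any inference provider; the raw
string below is a hypothetical model output, not captured from a real run:

    from pydantic import BaseModel

    class Output(BaseModel):
        name: str
        year_born: str
        year_retired: str

    # The schema the test passes as JsonResponseFormat(schema=...).
    schema = Output.model_json_schema()
    assert set(schema["properties"]) == {"name", "year_born", "year_retired"}

    # Parse the kind of JSON a grammar-constrained completion should emit.
    raw = '{"name": "Michael Jordan", "year_born": "1963", "year_retired": "2003"}'
    answer = Output.parse_raw(raw)
    assert answer.year_born == "1963" and answer.year_retired == "2003"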
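
Note on the openai_compat.py change: the suffix handling is easy to exercise
in isolation. A minimal standalone sketch of the same idea follows; StopReason
and CompletionResult here are simplified stand-ins for the llama_stack types,
and the out_of_tokens fallback is an assumption of this sketch (the real code
delegates to get_stop_reason(choice.finish_reason) when no suffix matches):

    from dataclasses import dataclass
    from enum import Enum

    class StopReason(Enum):
        end_of_turn = "end_of_turn"        # model emitted <|eot_id|>
        end_of_message = "end_of_message"  # model emitted <|eom_id|>
        out_of_tokens = "out_of_tokens"    # assumed fallback for this sketch

    @dataclass
    class CompletionResult:
        content: str
        stop_reason: StopReason

    def strip_stop_tokens(text: str) -> CompletionResult:
        # Map each special suffix to the stop reason it implies.
        suffixes = {
            "<|eot_id|>": StopReason.end_of_turn,
            "<|eom_id|>": StopReason.end_of_message,
        }
        for suffix, reason in suffixes.items():
            if text.endswith(suffix):
                return CompletionResult(text[: -len(suffix)], reason)
        return CompletionResult(text, StopReason.out_of_tokens)

    result = strip_stop_tokens('{"name": "Michael Jordan"}<|eot_id|>')
    assert result.content == '{"name": "Michael Jordan"}'
    assert result.stop_reason is StopReason.end_of_turn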