From 441016bee8c6b3b7ce89e7809a903d3343b705e2 Mon Sep 17 00:00:00 2001
From: Yuan Tang
Date: Mon, 24 Mar 2025 15:42:55 -0400
Subject: [PATCH] feat: Support "stop" parameter in remote:vLLM (#1715)

# What does this PR do?

This adds support for the "stop" parameter: https://platform.openai.com/docs/api-reference/completions/create#completions-create-stop

## Test Plan

```
tests/integration/inference/test_text_inference.py::test_text_completion_non_streaming[txt=8B-inference:completion:sanity] PASSED [ 5%]
tests/integration/inference/test_text_inference.py::test_text_completion_streaming[txt=8B-inference:completion:sanity] PASSED [ 11%]
tests/integration/inference/test_text_inference.py::test_text_completion_stop_sequence[txt=8B-inference:completion:stop_sequence] PASSED [ 16%]
tests/integration/inference/test_text_inference.py::test_text_completion_log_probs_non_streaming[txt=8B-inference:completion:log_probs] PASSED [ 22%]
tests/integration/inference/test_text_inference.py::test_text_completion_log_probs_streaming[txt=8B-inference:completion:log_probs] PASSED [ 27%]
tests/integration/inference/test_text_inference.py::test_text_completion_structured_output[txt=8B-inference:completion:structured_output] PASSED [ 33%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_non_streaming[txt=8B-inference:chat_completion:non_streaming_01] PASSED [ 38%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_non_streaming[txt=8B-inference:chat_completion:non_streaming_02] PASSED [ 44%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_first_token_profiling[txt=8B-inference:chat_completion:ttft] PASSED [ 50%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_streaming[txt=8B-inference:chat_completion:streaming_01] PASSED [ 55%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_streaming[txt=8B-inference:chat_completion:streaming_02] PASSED [ 61%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[txt=8B-inference:chat_completion:tool_calling] PASSED [ 66%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[txt=8B-inference:chat_completion:tool_calling] PASSED [ 72%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_required[txt=8B-inference:chat_completion:tool_calling] PASSED [ 77%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_none[txt=8B-inference:chat_completion:tool_calling] PASSED [ 83%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_structured_output[txt=8B-inference:chat_completion:structured_output] PASSED [ 88%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[txt=8B-inference:chat_completion:tool_calling_tools_absent-True] PASSED [ 94%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[txt=8B-inference:chat_completion:tool_calling_tools_absent-False] PASSED [100%]
=============================================================== 18 passed, 3 warnings in 755.79s (0:12:35) ===============================================================
```
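## Usage sketch

For illustration, this is roughly how a client could exercise the new parameter against a distribution backed by `remote::vllm`. The client setup, base URL, and model id below are placeholder assumptions rather than part of this change; the `sampling_params` shape mirrors the integration test added in this PR.

```python
# Illustrative sketch only: assumes a running Llama Stack server backed by remote::vllm
# and a registered model id. Adjust base_url and model_id for your own setup.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

response = client.inference.completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    content="Michael Jordan was born in the year of",
    sampling_params={
        "max_tokens": 50,
        "stop": ["1963"],  # generation halts before this sequence would be emitted
    },
)
# The returned completion should not contain the stop sequence.
print(response.content)
```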
---------

Signed-off-by: Yuan Tang
---
 docs/_static/llama-stack-spec.html       | 19 ++++++++++---
 docs/_static/llama-stack-spec.yaml       | 17 ++++++++++++
 llama_stack/models/llama/datatypes.py    | 12 +++++++++
 .../utils/inference/openai_compat.py     |  3 +++
 .../inference/test_text_inference.py     | 27 +++++++++++++++++++
 .../test_cases/inference/completion.json |  5 ++++
 6 files changed, 79 insertions(+), 4 deletions(-)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 8a46a89ad..64b06e901 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4053,22 +4053,33 @@
             "type": "object",
             "properties": {
                 "strategy": {
-                    "$ref": "#/components/schemas/SamplingStrategy"
+                    "$ref": "#/components/schemas/SamplingStrategy",
+                    "description": "The sampling strategy."
                 },
                 "max_tokens": {
                     "type": "integer",
-                    "default": 0
+                    "default": 0,
+                    "description": "The maximum number of tokens that can be generated in the completion. The token count of your prompt plus max_tokens cannot exceed the model's context length."
                 },
                 "repetition_penalty": {
                     "type": "number",
-                    "default": 1.0
+                    "default": 1.0,
+                    "description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics."
+                },
+                "stop": {
+                    "type": "array",
+                    "items": {
+                        "type": "string"
+                    },
+                    "description": "Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence."
                 }
             },
             "additionalProperties": false,
             "required": [
                 "strategy"
             ],
-            "title": "SamplingParams"
+            "title": "SamplingParams",
+            "description": "Sampling parameters."
         },
         "SamplingStrategy": {
             "oneOf": [
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 0b8f90490..78de9afef 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -2787,16 +2787,33 @@ components:
       properties:
         strategy:
           $ref: '#/components/schemas/SamplingStrategy'
+          description: The sampling strategy.
         max_tokens:
           type: integer
           default: 0
+          description: >-
+            The maximum number of tokens that can be generated in the completion.
+            The token count of your prompt plus max_tokens cannot exceed the model's
+            context length.
         repetition_penalty:
           type: number
           default: 1.0
+          description: >-
+            Number between -2.0 and 2.0. Positive values penalize new tokens based
+            on whether they appear in the text so far, increasing the model's likelihood
+            to talk about new topics.
+        stop:
+          type: array
+          items:
+            type: string
+          description: >-
+            Up to 4 sequences where the API will stop generating further tokens. The
+            returned text will not contain the stop sequence.
       additionalProperties: false
       required:
         - strategy
       title: SamplingParams
+      description: Sampling parameters.
     SamplingStrategy:
       oneOf:
         - $ref: '#/components/schemas/GreedySamplingStrategy'
diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py
index f762eb50f..fcbe44b07 100644
--- a/llama_stack/models/llama/datatypes.py
+++ b/llama_stack/models/llama/datatypes.py
@@ -195,10 +195,22 @@ register_schema(SamplingStrategy, name="SamplingStrategy")
 
 @json_schema_type
 class SamplingParams(BaseModel):
+    """Sampling parameters.
+
+    :param strategy: The sampling strategy.
+    :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
+        your prompt plus max_tokens cannot exceed the model's context length.
+    :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
+        based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
+    :param stop: Up to 4 sequences where the API will stop generating further tokens.
+        The returned text will not contain the stop sequence.
+    """
+
     strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
 
     max_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
+    stop: Optional[List[str]] = None
 
 
 class CheckpointQuantizationFormat(Enum):
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index b264c7312..07976e811 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -147,6 +147,9 @@ def get_sampling_options(params: SamplingParams) -> dict:
     if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
         options["repeat_penalty"] = params.repetition_penalty
 
+    if params.stop is not None:
+        options["stop"] = params.stop
+
     return options
 
 
diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py
index c9649df60..f558254e5 100644
--- a/tests/integration/inference/test_text_inference.py
+++ b/tests/integration/inference/test_text_inference.py
@@ -99,6 +99,33 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
     assert len(content_str) > 10
 
 
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:completion:stop_sequence",
+    ],
+)
+def test_text_completion_stop_sequence(client_with_models, text_model_id, inference_provider_type, test_case):
+    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
+    # This is only supported/tested for remote vLLM: https://github.com/meta-llama/llama-stack/issues/1771
+    if inference_provider_type != "remote::vllm":
+        pytest.xfail(f"{inference_provider_type} doesn't support 'stop' parameter yet")
+    tc = TestCase(test_case)
+
+    response = client_with_models.inference.completion(
+        content=tc["content"],
+        stream=True,
+        model_id=text_model_id,
+        sampling_params={
+            "max_tokens": 50,
+            "stop": ["1963"],
+        },
+    )
+    streamed_content = [chunk.delta for chunk in response]
+    content_str = "".join(streamed_content).lower().strip()
+    assert "1963" not in content_str
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
diff --git a/tests/integration/test_cases/inference/completion.json b/tests/integration/test_cases/inference/completion.json
index a568ecdc9..06abbdc8b 100644
--- a/tests/integration/test_cases/inference/completion.json
+++ b/tests/integration/test_cases/inference/completion.json
@@ -10,6 +10,11 @@
             "expected": "1963"
         }
     },
+    "stop_sequence": {
+        "data": {
+            "content": "Return the exact same sentence and don't add additional words): Michael Jordan was born in the year of 1963"
+        }
+    },
     "streaming": {
         "data": {
             "content": "Roses are red,"