From a1da09e1663f40dcfe34cecf0ecb7fe6af593831 Mon Sep 17 00:00:00 2001
From: Yuan Tang
Date: Wed, 19 Mar 2025 22:41:34 -0400
Subject: [PATCH] feat: Support "stop" parameter in remote:vLLM

Signed-off-by: Yuan Tang
---
 llama_stack/models/llama/datatypes.py    |  1 +
 .../utils/inference/openai_compat.py     |  3 +++
 .../inference/test_text_inference.py     | 24 +++++++++++++++++++
 .../test_cases/inference/completion.json |  6 +++++
 4 files changed, 34 insertions(+)

diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py
index 9842d7980..4f5a6e9ef 100644
--- a/llama_stack/models/llama/datatypes.py
+++ b/llama_stack/models/llama/datatypes.py
@@ -201,6 +201,7 @@ class SamplingParams(BaseModel):
 
     max_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
+    stop: Optional[List[str]] = None
 
 
 class CheckpointQuantizationFormat(Enum):
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index b264c7312..07976e811 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -147,6 +147,9 @@ def get_sampling_options(params: SamplingParams) -> dict:
         if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
             options["repeat_penalty"] = params.repetition_penalty
 
+        if params.stop is not None:
+            options["stop"] = params.stop
+
     return options
 
 
diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py
index c9649df60..de26076ff 100644
--- a/tests/integration/inference/test_text_inference.py
+++ b/tests/integration/inference/test_text_inference.py
@@ -99,6 +99,30 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
     assert len(content_str) > 10
 
 
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:completion:stop_sequence",
+    ],
+)
+def test_text_completion_stop_sequence(client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
+    tc = TestCase(test_case)
+
+    response = client_with_models.inference.completion(
+        content=tc["content"],
+        stream=True,
+        model_id=text_model_id,
+        sampling_params={
+            "max_tokens": 50,
+            "stop": ["1963"],
+        },
+    )
+    streamed_content = [chunk.delta for chunk in response]
+    content_str = "".join(streamed_content).lower().strip()
+    assert "1963" not in content_str
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
diff --git a/tests/integration/test_cases/inference/completion.json b/tests/integration/test_cases/inference/completion.json
index a568ecdc9..12b5f32de 100644
--- a/tests/integration/test_cases/inference/completion.json
+++ b/tests/integration/test_cases/inference/completion.json
@@ -10,6 +10,12 @@
       "expected": "1963"
     }
   },
+  "stop_sequence": {
+    "data": {
+      "content": "Return the exact same sentence: Michael Jordan was born in 1963",
+      "expected": "Michael Jordan was born in"
+    }
+  },
   "streaming": {
     "data": {
       "content": "Roses are red,"
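
Notes: a minimal sketch of how a client could exercise the new "stop" parameter
once this lands, assuming a running Llama Stack distribution backed by the
remote vLLM provider. The base URL and model ID below are placeholders, not
values taken from this patch; the call shape mirrors the integration test above.

from llama_stack_client import LlamaStackClient

# Placeholder endpoint and model ID; substitute your distribution's values.
client = LlamaStackClient(base_url="http://localhost:8321")

response = client.inference.completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    content="Return the exact same sentence: Michael Jordan was born in 1963",
    sampling_params={
        "max_tokens": 50,
        # Generation halts before any of these sequences is emitted, so the
        # output should end at "Michael Jordan was born in".
        "stop": ["1963"],
    },
)
print(response.content)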