diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 8a46a89ad..64b06e901 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4053,22 +4053,33 @@
"type": "object",
"properties": {
"strategy": {
- "$ref": "#/components/schemas/SamplingStrategy"
+ "$ref": "#/components/schemas/SamplingStrategy",
+ "description": "The sampling strategy."
},
"max_tokens": {
"type": "integer",
- "default": 0
+ "default": 0,
+ "description": "The maximum number of tokens that can be generated in the completion. The token count of your prompt plus max_tokens cannot exceed the model's context length."
},
"repetition_penalty": {
"type": "number",
- "default": 1.0
+ "default": 1.0,
+ "description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics."
+ },
+ "stop": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ },
+ "description": "Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence."
}
},
"additionalProperties": false,
"required": [
"strategy"
],
- "title": "SamplingParams"
+ "title": "SamplingParams",
+ "description": "Sampling parameters."
},
"SamplingStrategy": {
"oneOf": [
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 0b8f90490..78de9afef 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -2787,16 +2787,33 @@ components:
properties:
strategy:
$ref: '#/components/schemas/SamplingStrategy'
+ description: The sampling strategy.
max_tokens:
type: integer
default: 0
+ description: >-
+ The maximum number of tokens that can be generated in the completion.
+ The token count of your prompt plus max_tokens cannot exceed the model's
+ context length.
repetition_penalty:
type: number
default: 1.0
+ description: >-
+ Number between -2.0 and 2.0. Positive values penalize new tokens based
+ on whether they appear in the text so far, increasing the model's likelihood
+ to talk about new topics.
+ stop:
+ type: array
+ items:
+ type: string
+ description: >-
+ Up to 4 sequences where the API will stop generating further tokens. The
+ returned text will not contain the stop sequence.
additionalProperties: false
required:
- strategy
title: SamplingParams
+ description: Sampling parameters.
SamplingStrategy:
oneOf:
- $ref: '#/components/schemas/GreedySamplingStrategy'
diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py
index f762eb50f..fcbe44b07 100644
--- a/llama_stack/models/llama/datatypes.py
+++ b/llama_stack/models/llama/datatypes.py
@@ -195,10 +195,22 @@ register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type
class SamplingParams(BaseModel):
+ """Sampling parameters.
+
+ :param strategy: The sampling strategy.
+ :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
+ your prompt plus max_tokens cannot exceed the model's context length.
+ :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
+ based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
+ :param stop: Up to 4 sequences where the API will stop generating further tokens.
+ The returned text will not contain the stop sequence.
+ """
+
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
+ stop: Optional[List[str]] = None
class CheckpointQuantizationFormat(Enum):
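
A minimal usage sketch (not part of the diff), assuming the field definitions added above; it constructs SamplingParams with the new stop list, mirroring the integration test further down:

    # Illustrative only: exercises the fields added in this change.
    from llama_stack.models.llama.datatypes import GreedySamplingStrategy, SamplingParams

    params = SamplingParams(
        strategy=GreedySamplingStrategy(),  # explicit here, though it is already the default_factory
        max_tokens=50,
        stop=["1963"],  # generation should halt before this sequence is emitted
    )
    print(params.model_dump())  # assumes pydantic v2
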
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index b264c7312..07976e811 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -147,6 +147,9 @@ def get_sampling_options(params: SamplingParams) -> dict:
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
options["repeat_penalty"] = params.repetition_penalty
+ if params.stop is not None:
+ options["stop"] = params.stop
+
return options
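
A minimal sketch (not part of the diff) of what the new branch is expected to yield, assuming the rest of get_sampling_options is unchanged:

    from llama_stack.models.llama.datatypes import SamplingParams
    from llama_stack.providers.utils.inference.openai_compat import get_sampling_options

    # With a stop list set, the returned options dict now carries it through
    # to OpenAI-compatible providers.
    options = get_sampling_options(SamplingParams(stop=["1963"]))
    assert options["stop"] == ["1963"]
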
diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py
index c9649df60..f558254e5 100644
--- a/tests/integration/inference/test_text_inference.py
+++ b/tests/integration/inference/test_text_inference.py
@@ -99,6 +99,33 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
assert len(content_str) > 10
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ "inference:completion:stop_sequence",
+ ],
+)
+def test_text_completion_stop_sequence(client_with_models, text_model_id, inference_provider_type, test_case):
+ skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
+ # This is only supported/tested for remote vLLM: https://github.com/meta-llama/llama-stack/issues/1771
+ if inference_provider_type != "remote::vllm":
+ pytest.xfail(f"{inference_provider_type} doesn't support 'stop' parameter yet")
+ tc = TestCase(test_case)
+
+ response = client_with_models.inference.completion(
+ content=tc["content"],
+ stream=True,
+ model_id=text_model_id,
+ sampling_params={
+ "max_tokens": 50,
+ "stop": ["1963"],
+ },
+ )
+ streamed_content = [chunk.delta for chunk in response]
+ content_str = "".join(streamed_content).lower().strip()
+ assert "1963" not in content_str
+
+
@pytest.mark.parametrize(
"test_case",
[
diff --git a/tests/integration/test_cases/inference/completion.json b/tests/integration/test_cases/inference/completion.json
index a568ecdc9..06abbdc8b 100644
--- a/tests/integration/test_cases/inference/completion.json
+++ b/tests/integration/test_cases/inference/completion.json
@@ -10,6 +10,11 @@
"expected": "1963"
}
},
+ "stop_sequence": {
+ "data": {
+      "content": "Return the exact same sentence (and don't add additional words): Michael Jordan was born in the year of 1963"
+ }
+ },
"streaming": {
"data": {
"content": "Roses are red,"