feat: Support "stop" parameter in remote:vLLM (#1715)
# What does this PR do?

This adds support for the "stop" parameter: https://platform.openai.com/docs/api-reference/completions/create#completions-create-stop

## Test Plan

```
tests/integration/inference/test_text_inference.py::test_text_completion_non_streaming[txt=8B-inference:completion:sanity] PASSED [  5%]
tests/integration/inference/test_text_inference.py::test_text_completion_streaming[txt=8B-inference:completion:sanity] PASSED [ 11%]
tests/integration/inference/test_text_inference.py::test_text_completion_stop_sequence[txt=8B-inference:completion:stop_sequence] PASSED [ 16%]
tests/integration/inference/test_text_inference.py::test_text_completion_log_probs_non_streaming[txt=8B-inference:completion:log_probs] PASSED [ 22%]
tests/integration/inference/test_text_inference.py::test_text_completion_log_probs_streaming[txt=8B-inference:completion:log_probs] PASSED [ 27%]
tests/integration/inference/test_text_inference.py::test_text_completion_structured_output[txt=8B-inference:completion:structured_output] PASSED [ 33%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_non_streaming[txt=8B-inference:chat_completion:non_streaming_01] PASSED [ 38%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_non_streaming[txt=8B-inference:chat_completion:non_streaming_02] PASSED [ 44%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_first_token_profiling[txt=8B-inference:chat_completion:ttft] PASSED [ 50%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_streaming[txt=8B-inference:chat_completion:streaming_01] PASSED [ 55%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_streaming[txt=8B-inference:chat_completion:streaming_02] PASSED [ 61%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[txt=8B-inference:chat_completion:tool_calling] PASSED [ 66%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[txt=8B-inference:chat_completion:tool_calling] PASSED [ 72%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_required[txt=8B-inference:chat_completion:tool_calling] PASSED [ 77%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_none[txt=8B-inference:chat_completion:tool_calling] PASSED [ 83%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_structured_output[txt=8B-inference:chat_completion:structured_output] PASSED [ 88%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[txt=8B-inference:chat_completion:tool_calling_tools_absent-True] PASSED [ 94%]
tests/integration/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[txt=8B-inference:chat_completion:tool_calling_tools_absent-False] PASSED [100%]
=============================================================== 18 passed, 3 warnings in 755.79s (0:12:35) ===============================================================
```

---------

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
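As a usage illustration (not part of the diff itself), here is a minimal client-side sketch mirroring the integration test added in this PR. The `LlamaStackClient` setup, base URL, and model ID are assumptions for the example, not values taken from the change:

```python
# Minimal usage sketch: pass "stop" through sampling_params so generation halts
# before the stop sequence would be emitted. Base URL and model ID are placeholders.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

response = client.inference.completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model ID
    content=(
        "Return the exact same sentence and don't add additional words): "
        "Michael Jordan was born in the year of 1963"
    ),
    stream=False,
    sampling_params={
        "max_tokens": 50,
        "stop": ["1963"],  # the returned text should not contain this sequence
    },
)
print(response.content)
```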
Commit 441016bee8 (parent 9ff82036f7)
6 changed files with 79 additions and 4 deletions
docs/_static/llama-stack-spec.html (vendored, 19 lines changed)
```diff
@@ -4053,22 +4053,33 @@
         "type": "object",
         "properties": {
           "strategy": {
-            "$ref": "#/components/schemas/SamplingStrategy"
+            "$ref": "#/components/schemas/SamplingStrategy",
+            "description": "The sampling strategy."
           },
           "max_tokens": {
             "type": "integer",
-            "default": 0
+            "default": 0,
+            "description": "The maximum number of tokens that can be generated in the completion. The token count of your prompt plus max_tokens cannot exceed the model's context length."
           },
           "repetition_penalty": {
             "type": "number",
-            "default": 1.0
+            "default": 1.0,
+            "description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics."
+          },
+          "stop": {
+            "type": "array",
+            "items": {
+              "type": "string"
+            },
+            "description": "Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence."
           }
         },
         "additionalProperties": false,
         "required": [
           "strategy"
         ],
-        "title": "SamplingParams"
+        "title": "SamplingParams",
+        "description": "Sampling parameters."
       },
       "SamplingStrategy": {
         "oneOf": [
```
docs/_static/llama-stack-spec.yaml (vendored, 17 lines changed)
```diff
@@ -2787,16 +2787,33 @@ components:
       properties:
         strategy:
           $ref: '#/components/schemas/SamplingStrategy'
+          description: The sampling strategy.
         max_tokens:
           type: integer
           default: 0
+          description: >-
+            The maximum number of tokens that can be generated in the completion.
+            The token count of your prompt plus max_tokens cannot exceed the model's
+            context length.
         repetition_penalty:
           type: number
           default: 1.0
+          description: >-
+            Number between -2.0 and 2.0. Positive values penalize new tokens based
+            on whether they appear in the text so far, increasing the model's likelihood
+            to talk about new topics.
+        stop:
+          type: array
+          items:
+            type: string
+          description: >-
+            Up to 4 sequences where the API will stop generating further tokens. The
+            returned text will not contain the stop sequence.
       additionalProperties: false
       required:
         - strategy
       title: SamplingParams
+      description: Sampling parameters.
     SamplingStrategy:
       oneOf:
         - $ref: '#/components/schemas/GreedySamplingStrategy'
```
```diff
@@ -195,10 +195,22 @@ register_schema(SamplingStrategy, name="SamplingStrategy")
 
 @json_schema_type
 class SamplingParams(BaseModel):
+    """Sampling parameters.
+
+    :param strategy: The sampling strategy.
+    :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
+        your prompt plus max_tokens cannot exceed the model's context length.
+    :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
+        based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
+    :param stop: Up to 4 sequences where the API will stop generating further tokens.
+        The returned text will not contain the stop sequence.
+    """
+
     strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
 
     max_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
+    stop: Optional[List[str]] = None
 
 
 class CheckpointQuantizationFormat(Enum):
```
```diff
@@ -147,6 +147,9 @@ def get_sampling_options(params: SamplingParams) -> dict:
     if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
         options["repeat_penalty"] = params.repetition_penalty
 
+    if params.stop is not None:
+        options["stop"] = params.stop
+
     return options
 
 
```
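A quick illustration of the mapping above. This is a self-contained sketch that uses a stand-in model instead of importing llama-stack internals, so only the field names and option keys are taken from this PR; `_Params` and `_get_sampling_options` are hypothetical stand-ins:

```python
# Self-contained sketch of the option mapping shown in the hunk above.
# _Params stands in for SamplingParams; only the relevant fields are modeled.
from typing import List, Optional

from pydantic import BaseModel


class _Params(BaseModel):
    max_tokens: Optional[int] = 0
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[List[str]] = None


def _get_sampling_options(params: _Params) -> dict:
    options: dict = {}
    if params.max_tokens:
        options["max_tokens"] = params.max_tokens
    # repetition_penalty is renamed to the provider-facing "repeat_penalty" key
    if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
        options["repeat_penalty"] = params.repetition_penalty
    # the new "stop" field is forwarded under the same key
    if params.stop is not None:
        options["stop"] = params.stop
    return options


print(_get_sampling_options(_Params(max_tokens=50, stop=["1963"])))
# {'max_tokens': 50, 'stop': ['1963']}
```

The point of the sketch is simply that `repetition_penalty` is renamed on the way out while the new `stop` list is passed through unchanged.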
```diff
@@ -99,6 +99,33 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
     assert len(content_str) > 10
 
 
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:completion:stop_sequence",
+    ],
+)
+def test_text_completion_stop_sequence(client_with_models, text_model_id, inference_provider_type, test_case):
+    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
+    # This is only supported/tested for remote vLLM: https://github.com/meta-llama/llama-stack/issues/1771
+    if inference_provider_type != "remote::vllm":
+        pytest.xfail(f"{inference_provider_type} doesn't support 'stop' parameter yet")
+    tc = TestCase(test_case)
+
+    response = client_with_models.inference.completion(
+        content=tc["content"],
+        stream=True,
+        model_id=text_model_id,
+        sampling_params={
+            "max_tokens": 50,
+            "stop": ["1963"],
+        },
+    )
+    streamed_content = [chunk.delta for chunk in response]
+    content_str = "".join(streamed_content).lower().strip()
+    assert "1963" not in content_str
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
```
```diff
@@ -10,6 +10,11 @@
             "expected": "1963"
         }
     },
+    "stop_sequence": {
+        "data": {
+            "content": "Return the exact same sentence and don't add additional words): Michael Jordan was born in the year of 1963"
+        }
+    },
     "streaming": {
         "data": {
             "content": "Roses are red,"
```