Merge branch 'main' into add-watsonx-inference-adapter

This commit is contained in:
Sajikumar JS 2025-03-25 09:34:48 +05:30
commit 4b53171139
6 changed files with 79 additions and 4 deletions

View file

@@ -4053,22 +4053,33 @@
"type": "object",
"properties": {
"strategy": {
"$ref": "#/components/schemas/SamplingStrategy"
"$ref": "#/components/schemas/SamplingStrategy",
"description": "The sampling strategy."
},
"max_tokens": {
"type": "integer",
"default": 0
"default": 0,
"description": "The maximum number of tokens that can be generated in the completion. The token count of your prompt plus max_tokens cannot exceed the model's context length."
},
"repetition_penalty": {
"type": "number",
"default": 1.0
"default": 1.0,
"description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics."
},
"stop": {
"type": "array",
"items": {
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence."
}
},
"additionalProperties": false,
"required": [
"strategy"
],
"title": "SamplingParams"
"title": "SamplingParams",
"description": "Sampling parameters."
},
"SamplingStrategy": {
"oneOf": [

View file

@@ -2787,16 +2787,33 @@ components:
properties:
strategy:
$ref: '#/components/schemas/SamplingStrategy'
description: The sampling strategy.
max_tokens:
type: integer
default: 0
description: >-
The maximum number of tokens that can be generated in the completion.
The token count of your prompt plus max_tokens cannot exceed the model's
context length.
repetition_penalty:
type: number
default: 1.0
description: >-
Number between -2.0 and 2.0. Positive values penalize new tokens based
on whether they appear in the text so far, increasing the model's likelihood
to talk about new topics.
stop:
type: array
items:
type: string
description: >-
Up to 4 sequences where the API will stop generating further tokens. The
returned text will not contain the stop sequence.
additionalProperties: false
required:
- strategy
title: SamplingParams
description: Sampling parameters.
SamplingStrategy:
oneOf:
- $ref: '#/components/schemas/GreedySamplingStrategy'
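
For reference, a sampling_params payload that conforms to the schema documented above (in both the JSON and YAML renderings) might look like the sketch below; it is not part of this diff, the greedy variant shape is an assumption about GreedySamplingStrategy, and the values are illustrative only.

# Illustrative request fragment matching the updated SamplingParams schema.
sampling_params = {
    "strategy": {"type": "greedy"},  # assumed shape of one SamplingStrategy oneOf member
    "max_tokens": 50,
    "repetition_penalty": 1.0,
    "stop": ["###", "\n\n"],  # new field: up to 4 stop sequences
}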

View file

@@ -195,11 +195,23 @@ register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type
class SamplingParams(BaseModel):
    """Sampling parameters.

    :param strategy: The sampling strategy.
    :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
        your prompt plus max_tokens cannot exceed the model's context length.
    :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
        based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    :param stop: Up to 4 sequences where the API will stop generating further tokens.
        The returned text will not contain the stop sequence.
    """

    strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
    max_tokens: Optional[int] = 0
    repetition_penalty: Optional[float] = 1.0
    additional_params: Optional[dict] = {}
    stop: Optional[List[str]] = None


class CheckpointQuantizationFormat(Enum):
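
As a quick sanity check of the model change above, a minimal sketch (not part of this diff) of constructing the updated pydantic model; the import path is an assumption and may differ by package layout.

# Sketch only: the new `stop` field is optional and defaults to None,
# so existing callers that omit it are unaffected.
from llama_stack.models.llama.datatypes import SamplingParams  # assumed import path

params = SamplingParams(max_tokens=50, stop=["1963"])
print(params.model_dump_json(exclude_none=True))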

View file

@@ -147,6 +147,9 @@ def get_sampling_options(params: SamplingParams) -> dict:
        if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
            options["repeat_penalty"] = params.repetition_penalty

        if params.stop is not None:
            options["stop"] = params.stop

    return options
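
The net effect of the new branch is simply to pass the list through to the provider options; below is a hedged sketch (not from this diff) of what a caller of get_sampling_options would now observe, with illustrative values.

# Sketch: with `stop` set, the returned options dict now carries a "stop" key
# alongside the existing entries, which adapters such as the remote vLLM
# provider can forward to the backend completion call.
params = SamplingParams(max_tokens=50, stop=["1963"])
options = get_sampling_options(params)
assert options["stop"] == ["1963"]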

View file

@@ -99,6 +99,33 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
    assert len(content_str) > 10


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:completion:stop_sequence",
    ],
)
def test_text_completion_stop_sequence(client_with_models, text_model_id, inference_provider_type, test_case):
    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
    # This is only supported/tested for remote vLLM: https://github.com/meta-llama/llama-stack/issues/1771
    if inference_provider_type != "remote::vllm":
        pytest.xfail(f"{inference_provider_type} doesn't support 'stop' parameter yet")

    tc = TestCase(test_case)

    response = client_with_models.inference.completion(
        content=tc["content"],
        stream=True,
        model_id=text_model_id,
        sampling_params={
            "max_tokens": 50,
            "stop": ["1963"],
        },
    )
    streamed_content = [chunk.delta for chunk in response]
    content_str = "".join(streamed_content).lower().strip()
    assert "1963" not in content_str


@pytest.mark.parametrize(
    "test_case",
    [

View file

@@ -10,6 +10,11 @@
"expected": "1963"
}
},
"stop_sequence": {
"data": {
"content": "Return the exact same sentence and don't add additional words): Michael Jordan was born in the year of 1963"
}
},
"streaming": {
"data": {
"content": "Roses are red,"