From 43fb18928bca40e2a0766570c1363649c5f48abe Mon Sep 17 00:00:00 2001
From: Luis Tomas Bolivar
Date: Fri, 3 Oct 2025 17:37:04 +0200
Subject: [PATCH] Fix BadRequestError due to invalid max_tokens

This patch ensures that if max_tokens is not defined it is set to None.
This avoids failures in some providers that have no protection against
it being set to 0.

Issue: #3666
---
 docs/static/deprecated-llama-stack-spec.html   | 1 -
 docs/static/deprecated-llama-stack-spec.yaml   | 1 -
 docs/static/experimental-llama-stack-spec.html | 1 -
 docs/static/experimental-llama-stack-spec.yaml | 1 -
 docs/static/stainless-llama-stack-spec.html    | 1 -
 docs/static/stainless-llama-stack-spec.yaml    | 1 -
 llama_stack/apis/inference/inference.py        | 2 +-
 tests/integration/eval/test_eval.py            | 2 ++
 8 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/docs/static/deprecated-llama-stack-spec.html b/docs/static/deprecated-llama-stack-spec.html
index 7edfe3f5d..74a94a64d 100644
--- a/docs/static/deprecated-llama-stack-spec.html
+++ b/docs/static/deprecated-llama-stack-spec.html
@@ -4218,7 +4218,6 @@
         },
         "max_tokens": {
           "type": "integer",
-          "default": 0,
           "description": "The maximum number of tokens that can be generated in the completion. The token count of your prompt plus max_tokens cannot exceed the model's context length."
         },
         "repetition_penalty": {
diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml
index ca832d46b..af5b46941 100644
--- a/docs/static/deprecated-llama-stack-spec.yaml
+++ b/docs/static/deprecated-llama-stack-spec.yaml
@@ -3068,7 +3068,6 @@ components:
           description: The sampling strategy.
         max_tokens:
           type: integer
-          default: 0
           description: >-
             The maximum number of tokens that can be generated in the completion.
             The token count of your prompt plus max_tokens cannot exceed the model's
diff --git a/docs/static/experimental-llama-stack-spec.html b/docs/static/experimental-llama-stack-spec.html
index a84226c05..3ca170f61 100644
--- a/docs/static/experimental-llama-stack-spec.html
+++ b/docs/static/experimental-llama-stack-spec.html
@@ -2713,7 +2713,6 @@
         },
         "max_tokens": {
           "type": "integer",
-          "default": 0,
           "description": "The maximum number of tokens that can be generated in the completion. The token count of your prompt plus max_tokens cannot exceed the model's context length."
         },
         "repetition_penalty": {
diff --git a/docs/static/experimental-llama-stack-spec.yaml b/docs/static/experimental-llama-stack-spec.yaml
index a08c0cc87..b44af4bca 100644
--- a/docs/static/experimental-llama-stack-spec.yaml
+++ b/docs/static/experimental-llama-stack-spec.yaml
@@ -1927,7 +1927,6 @@ components:
           description: The sampling strategy.
         max_tokens:
           type: integer
-          default: 0
           description: >-
             The maximum number of tokens that can be generated in the completion.
             The token count of your prompt plus max_tokens cannot exceed the model's
diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html
index 1ae477e7e..0f3317981 100644
--- a/docs/static/stainless-llama-stack-spec.html
+++ b/docs/static/stainless-llama-stack-spec.html
@@ -14753,7 +14753,6 @@
         },
         "max_tokens": {
           "type": "integer",
-          "default": 0,
           "description": "The maximum number of tokens that can be generated in the completion. The token count of your prompt plus max_tokens cannot exceed the model's context length."
         },
         "repetition_penalty": {
diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml
index cb2584d8a..cf89a334f 100644
--- a/docs/static/stainless-llama-stack-spec.yaml
+++ b/docs/static/stainless-llama-stack-spec.yaml
@@ -10909,7 +10909,6 @@ components:
           description: The sampling strategy.
         max_tokens:
           type: integer
-          default: 0
           description: >-
             The maximum number of tokens that can be generated in the completion.
             The token count of your prompt plus max_tokens cannot exceed the model's
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 829a94a6a..59a30a5db 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -96,7 +96,7 @@ class SamplingParams(BaseModel):

     strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)

-    max_tokens: int | None = 0
+    max_tokens: int | None = None
     repetition_penalty: float | None = 1.0
     stop: list[str] | None = None

diff --git a/tests/integration/eval/test_eval.py b/tests/integration/eval/test_eval.py
index 01581e829..d735bc2a2 100644
--- a/tests/integration/eval/test_eval.py
+++ b/tests/integration/eval/test_eval.py
@@ -55,6 +55,7 @@ def test_evaluate_rows(llama_stack_client, text_model_id, scoring_fn_id):
                 "model": text_model_id,
                 "sampling_params": {
                     "temperature": 0.0,
+                    "max_tokens": 512,
                 },
             },
         },
@@ -88,6 +89,7 @@ def test_evaluate_benchmark(llama_stack_client, text_model_id, scoring_fn_id):
                 "model": text_model_id,
                 "sampling_params": {
                     "temperature": 0.0,
+                    "max_tokens": 512,
                 },
             },
         },