diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index cac310613..79f92adce 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -446,7 +446,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_obj = await self._get_model(model)
 
         extra_body: Dict[str, Any] = {}
-        if prompt_logprobs:
+        if prompt_logprobs is not None and prompt_logprobs >= 0:
             extra_body["prompt_logprobs"] = prompt_logprobs
         if guided_choice:
             extra_body["guided_choice"] = guided_choice
diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py
index 48c828260..d94390b8f 100644
--- a/tests/integration/inference/test_openai_completion.py
+++ b/tests/integration/inference/test_openai_completion.py
@@ -97,7 +97,14 @@ def test_openai_completion_streaming(openai_client, text_model_id, test_case):
     assert len(content_str) > 10
 
 
-def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id):
+@pytest.mark.parametrize(
+    "prompt_logprobs",
+    [
+        1,
+        0,
+    ],
+)
+def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id, prompt_logprobs):
     skip_if_provider_isnt_vllm(client_with_models, text_model_id)
 
     prompt = "Hello, world!"
@@ -106,7 +113,7 @@ def test_openai_completion_prompt_logprobs(openai_client, client_with_models, te
         prompt=prompt,
         stream=False,
         extra_body={
-            "prompt_logprobs": 1,
+            "prompt_logprobs": prompt_logprobs,
        },
     )
     assert len(response.choices) > 0
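
The adapter change matters because `prompt_logprobs=0` is a value vLLM accepts, but the old truthiness check (`if prompt_logprobs:`) treated 0 like None and silently dropped it; the parametrized test now covers both 1 and 0. Below is a minimal client-side sketch (not part of the patch) of exercising the fixed path, assuming a locally running vLLM-backed Llama Stack server; the base_url, api_key, and model name are placeholders.

# Sketch: pass prompt_logprobs=0 through the OpenAI-compatible completions API.
# base_url, api_key, and model are assumptions for illustration only.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    prompt="Hello, world!",
    stream=False,
    # With the fix, 0 is forwarded to vLLM via extra_body instead of being ignored.
    extra_body={"prompt_logprobs": 0},
)
print(response.choices[0].text)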