mirror of https://github.com/meta-llama/llama-stack.git
vllm prompt_logprobs can also be 0
This adjusts the vllm openai_completion endpoint to also pass a value of 0 for prompt_logprobs to the backend, instead of only passing values greater than zero. The existing test_openai_completion_prompt_logprobs was parameterized to cover this case as well.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
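For context, a minimal sketch of how a client could exercise this path against a vLLM-backed, OpenAI-compatible endpoint; it is not part of the patch, and the base URL, API key, and model id are placeholders. The OpenAI Python client's extra_body argument is used to forward the vLLM-specific prompt_logprobs parameter, including the newly supported value of 0.

    # Sketch only: base_url, api_key, and model id are placeholders.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="placeholder")

    response = client.completions.create(
        model="my-model",  # placeholder model id
        prompt="Hello, world!",
        stream=False,
        # extra_body forwards provider-specific parameters; with this change,
        # a prompt_logprobs value of 0 is also passed through to the backend.
        extra_body={"prompt_logprobs": 0},
    )
    print(response.choices[0].text)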
parent 8d10556ce3
commit 8f5cd49159
2 changed files with 10 additions and 3 deletions
@@ -446,7 +446,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_obj = await self._get_model(model)

         extra_body: Dict[str, Any] = {}
-        if prompt_logprobs:
+        if prompt_logprobs is not None and prompt_logprobs >= 0:
             extra_body["prompt_logprobs"] = prompt_logprobs
         if guided_choice:
             extra_body["guided_choice"] = guided_choice
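The key detail in this hunk: in Python, 0 is falsy, so the previous "if prompt_logprobs:" check silently dropped an explicit request for 0, while the new check forwards any non-negative value. A standalone illustration of the difference (the helper names below are hypothetical, not from the patch):

    # Hypothetical helpers illustrating the old vs. new check; not from the patch.
    from typing import Any, Dict, Optional

    def build_extra_body_old(prompt_logprobs: Optional[int]) -> Dict[str, Any]:
        extra_body: Dict[str, Any] = {}
        if prompt_logprobs:  # 0 is falsy, so a requested 0 is silently dropped
            extra_body["prompt_logprobs"] = prompt_logprobs
        return extra_body

    def build_extra_body_new(prompt_logprobs: Optional[int]) -> Dict[str, Any]:
        extra_body: Dict[str, Any] = {}
        if prompt_logprobs is not None and prompt_logprobs >= 0:  # 0 passes through
            extra_body["prompt_logprobs"] = prompt_logprobs
        return extra_body

    assert build_extra_body_old(0) == {}                      # old behavior: request lost
    assert build_extra_body_new(0) == {"prompt_logprobs": 0}  # new behavior: forwarded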
@@ -97,7 +97,14 @@ def test_openai_completion_streaming(openai_client, text_model_id, test_case):
     assert len(content_str) > 10


-def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id):
+@pytest.mark.parametrize(
+    "prompt_logprobs",
+    [
+        1,
+        0,
+    ],
+)
+def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id, prompt_logprobs):
     skip_if_provider_isnt_vllm(client_with_models, text_model_id)

     prompt = "Hello, world!"

@@ -106,7 +113,7 @@ def test_openai_completion_prompt_logprobs(openai_client, client_with_models, te
         prompt=prompt,
         stream=False,
         extra_body={
-            "prompt_logprobs": 1,
+            "prompt_logprobs": prompt_logprobs,
         },
     )
     assert len(response.choices) > 0
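With the parametrization in place, pytest collects the test twice, once per value, with IDs like test_openai_completion_prompt_logprobs[1] and test_openai_completion_prompt_logprobs[0]. A minimal sketch of the mechanism, with the fixture wiring and the real API call omitted:

    # Sketch of the parametrize mechanism only; the real test issues a
    # completions request with extra_body={"prompt_logprobs": prompt_logprobs}.
    import pytest

    @pytest.mark.parametrize("prompt_logprobs", [1, 0])
    def test_prompt_logprobs_values(prompt_logprobs):
        assert prompt_logprobs >= 0  # placeholder assertion standing in for the API call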