diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 02594891b..3574768b5 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -345,6 +345,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
 
+        if request.logprobs and request.logprobs.top_k:
+            input_dict["logprobs"] = request.logprobs.top_k
+
         return {
             "model": request.model,
             **input_dict,
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 99f968cbc..6a7259123 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -175,7 +175,7 @@ class TestInference:
                 1 <= len(chunks) <= 6
             )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
             for chunk in chunks:
-                if chunk.delta.type == "text" and chunk.delta.text:  # if there's a token, we expect logprobs
+                if chunk.delta:  # if there's a token, we expect logprobs
                     assert chunk.logprobs, "Logprobs should not be empty"
                     assert all(len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs)
                 else:  # no token, no logprobs
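
For context, here is a minimal, self-contained sketch of what the first hunk does: when the caller asked for top-k log-probabilities, the adapter forwards that value as the integer "logprobs" field of the payload sent to vLLM's OpenAI-compatible endpoint, and omits the key otherwise. LogProbConfig, FakeRequest, and build_params below are simplified stand-ins for illustration only, not llama_stack's actual types or methods; only the attribute names (logprobs, top_k) and the branch itself come from the diff.

from dataclasses import dataclass
from typing import Optional


@dataclass
class LogProbConfig:
    # Hypothetical stand-in: number of top log-probabilities to return per token.
    top_k: Optional[int] = None


@dataclass
class FakeRequest:
    # Hypothetical stand-in for the inference request seen by the adapter.
    model: str
    logprobs: Optional[LogProbConfig] = None


def build_params(request: FakeRequest) -> dict:
    input_dict: dict = {}
    # Mirrors the added branch: only set "logprobs" when a truthy top_k was
    # requested, so requests without logprobs are sent unchanged.
    if request.logprobs and request.logprobs.top_k:
        input_dict["logprobs"] = request.logprobs.top_k
    return {"model": request.model, **input_dict}


assert build_params(FakeRequest(model="m")) == {"model": "m"}
assert build_params(FakeRequest(model="m", logprobs=LogProbConfig(top_k=3))) == {"model": "m", "logprobs": 3}

The second hunk adjusts the streaming test accordingly: for every chunk that carries a delta, it expects non-empty logprobs with exactly three entries per token, matching the top_k=3 the test requests.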