From 58f16bec8b333e25263c3e04efdd3ab5e1889872 Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Thu, 13 Feb 2025 10:37:25 -0500
Subject: [PATCH] Fix logprobs support in remote-vllm provider

The remote-vllm provider was not passing logprobs options from
CompletionRequest or ChatCompletionRequest through to the OpenAI client
parameters. I verified this manually and also observed this provider
failing `TestInference::test_completion_logprobs`.

This fixes that by passing the `logprobs.top_k` value through to the
parameters we pass into the OpenAI client.

Additionally, this fixes a bug in `test_text_inference.py` where it
mistakenly assumed chunk.delta was of type `ContentDelta` for completion
requests. The deltas are of type `ContentDelta` for chat completion
requests, but for basic completion requests the deltas are plain
strings. Because of this latter issue, the test was likely failing even
for providers that properly support logprobs; it was hit while fixing
the remote-vllm issue above.

Fixes #1073

Signed-off-by: Ben Browning
---
 llama_stack/providers/remote/inference/vllm/vllm.py           | 3 +++
 llama_stack/providers/tests/inference/test_text_inference.py  | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 02594891b..3574768b5 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -345,6 +345,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
 
+        if request.logprobs and request.logprobs.top_k:
+            input_dict["logprobs"] = request.logprobs.top_k
+
         return {
             "model": request.model,
             **input_dict,
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 99f968cbc..6a7259123 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -175,7 +175,7 @@ class TestInference:
             1 <= len(chunks) <= 6
         )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
         for chunk in chunks:
-            if chunk.delta.type == "text" and chunk.delta.text:  # if there's a token, we expect logprobs
+            if chunk.delta:  # if there's a token, we expect logprobs
                 assert chunk.logprobs, "Logprobs should not be empty"
                 assert all(len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs)
             else:  # no token, no logprobs
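
For illustration, here is a minimal, runnable sketch of the behavior the
vllm.py change adds: the logprobs option is forwarded to the
OpenAI-compatible client only when `top_k` is set. The `LogProbConfig`
dataclass and `build_params` helper below are stand-ins for this sketch,
not the actual llama-stack request types or provider code.

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class LogProbConfig:
        # Stand-in for the logprobs option carried by the request objects in
        # the patch above; only top_k matters for this sketch.
        top_k: Optional[int] = None


    def build_params(model: str, prompt: str, logprobs: Optional[LogProbConfig]) -> dict:
        # Mirror the patched behavior: forward logprobs only when top_k is set.
        params: dict = {"model": model, "prompt": prompt}
        if logprobs and logprobs.top_k:
            # The OpenAI completions API accepts logprobs as an integer count
            # of top log probabilities to return per generated token.
            params["logprobs"] = logprobs.top_k
        return params


    print(build_params("my-model", "Hello", LogProbConfig(top_k=3)))
    # {'model': 'my-model', 'prompt': 'Hello', 'logprobs': 3}
    print(build_params("my-model", "Hello", None))
    # {'model': 'my-model', 'prompt': 'Hello'}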
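
The test change follows the same reasoning: a bare truthiness check works
whether the chunk delta is a `ContentDelta`-style object (chat completion
streams) or a plain string (basic completion streams), while attribute
access does not. A toy sketch with a stand-in `TextDelta` class:

    from dataclasses import dataclass
    from typing import Union


    @dataclass
    class TextDelta:
        # Stand-in for the ContentDelta variant produced by chat completion streams.
        type: str
        text: str


    def has_token(delta: Union[TextDelta, str]) -> bool:
        # Truthiness works for both shapes; delta.type would raise
        # AttributeError when delta is the plain string produced by a basic
        # completion stream.
        return bool(delta)


    print(has_token(TextDelta(type="text", text="Hello")))  # True  (chat completion delta)
    print(has_token("Hello"))                               # True  (basic completion delta)
    print(has_token(""))                                    # False (empty chunk, no logprobs expected)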