fix: add missing extra_body to client.chat.completions.create() call

- test requires vLLM as provider, current is skipped in GH Action
- test:
>export VLLM_URL="http://localhost:8000"
>pytest tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_extra_body -v --stack-config="inference=remote::vllm"

Signed-off-by: Wen Zhou <wenzhou@redhat.com>
This commit is contained in:
Wen Zhou 2025-07-11 13:02:11 +02:00
parent d880c2df0e
commit ea964a13ec
2 changed files with 46 additions and 0 deletions

View file

@ -654,6 +654,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
top_logprobs: int | None = None, top_logprobs: int | None = None,
top_p: float | None = None, top_p: float | None = None,
user: str | None = None, user: str | None = None,
extra_body: dict[str, Any] | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
self._lazy_initialize_client() self._lazy_initialize_client()
model_obj = await self._get_model(model) model_obj = await self._get_model(model)
@ -681,6 +682,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
top_logprobs=top_logprobs, top_logprobs=top_logprobs,
top_p=top_p, top_p=top_p,
user=user, user=user,
extra_body=extra_body,
) )
return await self.client.chat.completions.create(**params) # type: ignore return await self.client.chat.completions.create(**params) # type: ignore

View file

@ -5,6 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
from unittest.mock import patch
import pytest import pytest
from openai import OpenAI from openai import OpenAI
@ -209,6 +211,48 @@ def test_openai_completion_guided_choice(llama_stack_client, client_with_models,
# Run the chat-completion tests with both the OpenAI client and the LlamaStack client # Run the chat-completion tests with both the OpenAI client and the LlamaStack client
def test_openai_chat_completion_extra_body(llama_stack_client, client_with_models, text_model_id):
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
adapter = client_with_models._routing_table.get_provider_impl(text_model_id)
captured_params = None
async def mock_create(**params):
nonlocal captured_params
captured_params = params
from unittest.mock import AsyncMock
mock_response = AsyncMock()
mock_response.choices = [AsyncMock()]
mock_response.choices[0].message = AsyncMock()
mock_response.choices[0].message.content = "Test response"
return mock_response
# Need explicit test extra_body is passed
with patch.object(adapter.client.chat.completions, "create", side_effect=mock_create):
response = llama_stack_client.chat.completions.create(
model=text_model_id,
messages=[
{
"role": "user",
"content": "Hi how are you?",
}
],
stream=False,
extra_body={"chat_template_kwargs": {"thinking": True}},
)
assert captured_params is not None
assert "extra_body" in captured_params
assert captured_params["extra_body"]["chat_template_kwargs"]["thinking"] is True
assert len(response.choices) > 0
choice = response.choices[0]
assert choice.message.content is not None
assert len(choice.message.content) > 0
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test_case", "test_case",
[ [