diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index d1455acaa..baa766692 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -654,6 +654,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         top_logprobs: int | None = None,
         top_p: float | None = None,
         user: str | None = None,
+        extra_body: dict[str, Any] | None = None,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         self._lazy_initialize_client()
         model_obj = await self._get_model(model)
@@ -681,6 +682,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             top_logprobs=top_logprobs,
             top_p=top_p,
             user=user,
+            extra_body=extra_body,
         )
         return await self.client.chat.completions.create(**params)  # type: ignore
 
diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py
index 05aee5096..17d795e34 100644
--- a/tests/integration/inference/test_openai_completion.py
+++ b/tests/integration/inference/test_openai_completion.py
@@ -5,6 +5,8 @@
 # the root directory of this source tree.
 
+from unittest.mock import patch
+
 import pytest
 from openai import OpenAI
 
 
@@ -209,6 +211,48 @@ def test_openai_completion_guided_choice(llama_stack_client, client_with_models,
 
 
 # Run the chat-completion tests with both the OpenAI client and the LlamaStack client
+def test_openai_chat_completion_extra_body(llama_stack_client, client_with_models, text_model_id):
+    skip_if_provider_isnt_vllm(client_with_models, text_model_id)
+
+    adapter = client_with_models._routing_table.get_provider_impl(text_model_id)
+
+    captured_params = None
+
+    async def mock_create(**params):
+        nonlocal captured_params
+        captured_params = params
+
+        from unittest.mock import AsyncMock
+
+        mock_response = AsyncMock()
+        mock_response.choices = [AsyncMock()]
+        mock_response.choices[0].message = AsyncMock()
+        mock_response.choices[0].message.content = "Test response"
+        return mock_response
+
+    # Explicitly verify that extra_body is passed through to the downstream client.
+    with patch.object(adapter.client.chat.completions, "create", side_effect=mock_create):
+        response = llama_stack_client.chat.completions.create(
+            model=text_model_id,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Hi how are you?",
+                }
+            ],
+            stream=False,
+            extra_body={"chat_template_kwargs": {"thinking": True}},
+        )
+        assert captured_params is not None
+        assert "extra_body" in captured_params
+        assert captured_params["extra_body"]["chat_template_kwargs"]["thinking"] is True
+
+    assert len(response.choices) > 0
+    choice = response.choices[0]
+    assert choice.message.content is not None
+    assert len(choice.message.content) > 0
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
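
For context, here is a minimal usage sketch of what this change enables, not part of the diff itself: the server base URL, API key, and model id below are illustrative assumptions, as is the exact `chat_template_kwargs` behavior (interpreted by the backing vLLM server's chat template, not by Llama Stack).

```python
# Hypothetical end-to-end usage: a Llama Stack server at localhost:8321 with a
# vLLM-backed model registered as "my-vllm-model" (both assumed, not from the
# diff). The OpenAI Python client merges extra_body fields verbatim into the
# request JSON; with this patch the vLLM adapter relays them to the backing
# vLLM server, which can read chat_template_kwargs to toggle template features
# such as "thinking" output.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.chat.completions.create(
    model="my-vllm-model",
    messages=[{"role": "user", "content": "Hi how are you?"}],
    stream=False,
    extra_body={"chat_template_kwargs": {"thinking": True}},
)
print(response.choices[0].message.content)
```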