diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index fb3e78afc..4570eaa71 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -1058,8 +1058,6 @@ class OpenAICompletionRequest(BaseModel): :param top_p: (Optional) The top p to use. :param user: (Optional) The user to use. :param suffix: (Optional) The suffix that should be appended to the completion. - :param guided_choice: (Optional) vLLM-specific parameter for guided generation with a list of choices. - :param prompt_logprobs: (Optional) vLLM-specific parameter for number of log probabilities to return for prompt tokens. """ model_config = ConfigDict(extra="allow") @@ -1082,12 +1080,6 @@ class OpenAICompletionRequest(BaseModel): temperature: float | None = None top_p: float | None = None user: str | None = None - - # vLLM-specific parameters (documented here but also allowed via extra fields) - guided_choice: list[str] | None = None - prompt_logprobs: int | None = None - - # for fill-in-the-middle type completion suffix: str | None = None diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index eed078a0e..0f94a2bfe 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -230,6 +230,9 @@ class LiteLLMOpenAIMixin( ) -> OpenAICompletion: model_obj = await self.model_store.get_model(params.model) + # Extract extra fields + extra_body = dict(params.__pydantic_extra__ or {}) + request_params = await prepare_openai_completion_params( model=self.get_litellm_model_name(model_obj.provider_resource_id), prompt=params.prompt, @@ -248,11 +251,10 @@ class LiteLLMOpenAIMixin( temperature=params.temperature, top_p=params.top_p, user=params.user, - guided_choice=params.guided_choice, - prompt_logprobs=params.prompt_logprobs, suffix=params.suffix, api_key=self.get_api_key(), api_base=self.api_base, + **extra_body, ) return await litellm.atext_completion(**request_params) @@ -272,6 +274,9 @@ class LiteLLMOpenAIMixin( model_obj = await self.model_store.get_model(params.model) + # Extract extra fields + extra_body = dict(params.__pydantic_extra__ or {}) + request_params = await prepare_openai_completion_params( model=self.get_litellm_model_name(model_obj.provider_resource_id), messages=params.messages, @@ -298,6 +303,7 @@ class LiteLLMOpenAIMixin( user=params.user, api_key=self.get_api_key(), api_base=self.api_base, + **extra_body, ) return await litellm.acompletion(**request_params) diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 502bc207b..db42eb10a 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -228,15 +228,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): """ Direct OpenAI completion API call. 
""" - # Handle parameters that are not supported by OpenAI API, but may be by the provider - # prompt_logprobs is supported by vLLM - # guided_choice is supported by vLLM - # TODO: test coverage - extra_body: dict[str, Any] = {} - if params.prompt_logprobs is not None and params.prompt_logprobs >= 0: - extra_body["prompt_logprobs"] = params.prompt_logprobs - if params.guided_choice: - extra_body["guided_choice"] = params.guided_choice # TODO: fix openai_completion to return type compatible with OpenAI's API response completion_kwargs = await prepare_openai_completion_params( @@ -259,7 +250,9 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): user=params.user, suffix=params.suffix, ) - resp = await self.client.completions.create(**completion_kwargs, extra_body=extra_body) + if extra_body := dict(params.__pydantic_extra__ or {}): + completion_kwargs["extra_body"] = extra_body + resp = await self.client.completions.create(**completion_kwargs) return await self._maybe_overwrite_id(resp, params.stream) # type: ignore[no-any-return] @@ -316,6 +309,8 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): user=params.user, ) + if extra_body := dict(params.__pydantic_extra__ or {}): + request_params["extra_body"] = extra_body resp = await self.client.chat.completions.create(**request_params) return await self._maybe_overwrite_id(resp, params.stream) # type: ignore[no-any-return] diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index 569fb5031..f5f11e901 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -15,6 +15,9 @@ from llama_stack.apis.inference import ( OpenAIChatCompletion, OpenAIChatCompletionRequest, OpenAIChoice, + OpenAICompletion, + OpenAICompletionChoice, + OpenAICompletionRequest, ToolChoice, ) from llama_stack.apis.models import Model @@ -191,3 +194,94 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter): assert mock_create_client.call_count == 4 # no cheating assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max" + + +async def test_extra_body_forwarding(vllm_inference_adapter): + """ + Test that extra_body parameters (e.g., chat_template_kwargs) are correctly + forwarded to the underlying OpenAI client. 
+ """ + mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference") + vllm_inference_adapter.model_store.get_model.return_value = mock_model + + with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property: + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=OpenAIChatCompletion( + id="chatcmpl-abc123", + created=1, + model="mock-model", + choices=[ + OpenAIChoice( + message=OpenAIAssistantMessageParam( + content="test response", + ), + finish_reason="stop", + index=0, + ) + ], + ) + ) + mock_client_property.return_value = mock_client + + # Test with chat_template_kwargs for Granite thinking mode + params = OpenAIChatCompletionRequest( + model="mock-model", + messages=[{"role": "user", "content": "test"}], + stream=False, + chat_template_kwargs={"thinking": True}, + ) + await vllm_inference_adapter.openai_chat_completion(params) + + # Verify that the client was called with extra_body containing chat_template_kwargs + mock_client.chat.completions.create.assert_called_once() + call_kwargs = mock_client.chat.completions.create.call_args.kwargs + assert "extra_body" in call_kwargs + assert "chat_template_kwargs" in call_kwargs["extra_body"] + assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True} + + +async def test_vllm_completion_extra_body(vllm_inference_adapter): + """ + Test that vLLM-specific guided_choice parameter is correctly forwarded + via extra_body to the underlying OpenAI client. + """ + mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference") + vllm_inference_adapter.model_store.get_model.return_value = mock_model + + with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property: + mock_client = MagicMock() + mock_client.completions.create = AsyncMock( + return_value=OpenAICompletion( + id="cmpl-abc123", + created=1, + model="mock-model", + choices=[ + OpenAICompletionChoice( + text="joy", + finish_reason="stop", + index=0, + ) + ], + ) + ) + mock_client_property.return_value = mock_client + + # Test with guided_choice as extra field + params = OpenAICompletionRequest( + model="mock-model", + prompt="I am feeling happy", + stream=False, + guided_choice=["joy", "sadness"], + prompt_logprobs=5, + ) + await vllm_inference_adapter.openai_completion(params) + + # Verify that the client was called with extra_body containing guided_choice + mock_client.completions.create.assert_called_once() + call_kwargs = mock_client.completions.create.call_args.kwargs + assert "extra_body" in call_kwargs + assert "guided_choice" in call_kwargs["extra_body"] + assert call_kwargs["extra_body"]["guided_choice"] == ["joy", "sadness"] + assert "prompt_logprobs" in call_kwargs["extra_body"] + assert call_kwargs["extra_body"]["prompt_logprobs"] == 5