diff --git a/docs/static/deprecated-llama-stack-spec.html b/docs/static/deprecated-llama-stack-spec.html index 04a3dca9b..9f957e867 100644 --- a/docs/static/deprecated-llama-stack-spec.html +++ b/docs/static/deprecated-llama-stack-spec.html @@ -7709,12 +7709,36 @@ "user": { "type": "string", "description": "(Optional) The user to use." + }, + "kwargs": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ], + "description": "(Optional) Additional provider-specific parameters to pass through as extra_body (e.g., chat_template_kwargs for vLLM)." } }, "additionalProperties": false, "required": [ "model", - "messages" + "messages", + "kwargs" ], "title": "OpenaiChatCompletionRequest" }, diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index 1a215b877..f56e5303e 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -5666,10 +5666,22 @@ components: user: type: string description: (Optional) The user to use. + kwargs: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + (Optional) Additional provider-specific parameters to pass through as + extra_body (e.g., chat_template_kwargs for vLLM). additionalProperties: false required: - model - messages + - kwargs title: OpenaiChatCompletionRequest OpenAIChatCompletion: type: object diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html index 9cd526176..9e0dcec59 100644 --- a/docs/static/llama-stack-spec.html +++ b/docs/static/llama-stack-spec.html @@ -5205,12 +5205,36 @@ "user": { "type": "string", "description": "(Optional) The user to use." + }, + "kwargs": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ], + "description": "(Optional) Additional provider-specific parameters to pass through as extra_body (e.g., chat_template_kwargs for vLLM)." } }, "additionalProperties": false, "required": [ "model", - "messages" + "messages", + "kwargs" ], "title": "OpenaiChatCompletionRequest" }, diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index 66ce8e38a..2635aa099 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -3915,10 +3915,22 @@ components: user: type: string description: (Optional) The user to use. + kwargs: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + (Optional) Additional provider-specific parameters to pass through as + extra_body (e.g., chat_template_kwargs for vLLM). additionalProperties: false required: - model - messages + - kwargs title: OpenaiChatCompletionRequest OpenAIChatCompletion: type: object diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html index 3478d3338..218c1aa1c 100644 --- a/docs/static/stainless-llama-stack-spec.html +++ b/docs/static/stainless-llama-stack-spec.html @@ -7214,12 +7214,36 @@ "user": { "type": "string", "description": "(Optional) The user to use." 
+ }, + "kwargs": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ], + "description": "(Optional) Additional provider-specific parameters to pass through as extra_body (e.g., chat_template_kwargs for vLLM)." } }, "additionalProperties": false, "required": [ "model", - "messages" + "messages", + "kwargs" ], "title": "OpenaiChatCompletionRequest" }, diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 6c04542bf..f45535f0a 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -5360,10 +5360,22 @@ components: user: type: string description: (Optional) The user to use. + kwargs: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + (Optional) Additional provider-specific parameters to pass through as + extra_body (e.g., chat_template_kwargs for vLLM). additionalProperties: false required: - model - messages + - kwargs title: OpenaiChatCompletionRequest OpenAIChatCompletion: type: object diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 62a988ea6..d93bb6c45 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -1106,6 +1106,7 @@ class InferenceProvider(Protocol): top_logprobs: int | None = None, top_p: float | None = None, user: str | None = None, + **kwargs: Any, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: """Create chat completions. @@ -1134,6 +1135,7 @@ class InferenceProvider(Protocol): :param top_logprobs: (Optional) The top log probabilities to use. :param top_p: (Optional) The top p to use. :param user: (Optional) The user to use. + :param kwargs: (Optional) Additional provider-specific parameters to pass through as extra_body (e.g., chat_template_kwargs for vLLM). :returns: An OpenAIChatCompletion. """ ... 
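
The snippet below is a hedged sketch of how a caller would exercise the pass-through documented in the InferenceProvider docstring above. The wrapper function, the inference object, and the model id are assumptions for illustration only; what this patch adds is the pass-through itself: any keyword argument outside the declared signature is collected by **kwargs and forwarded to the provider as extra_body.

    from llama_stack.apis.inference import InferenceProvider, OpenAIUserMessageParam

    async def ask_with_thinking(inference: InferenceProvider) -> None:
        # `inference` is assumed to come from a configured Llama Stack distribution;
        # the model id is a placeholder.
        response = await inference.openai_chat_completion(
            model="vllm/ibm-granite/granite-3.2-8b-instruct",
            messages=[OpenAIUserMessageParam(role="user", content="Why is the sky blue?")],
            stream=False,
            # Not a declared parameter: collected by **kwargs and passed through to
            # the provider as extra_body (chat_template_kwargs is vLLM-specific).
            chat_template_kwargs={"thinking": True},
        )
        print(response.choices[0].message.content)
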
diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 847f6a2d2..98cacbb49 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -277,6 +277,7 @@ class InferenceRouter(Inference): top_logprobs: int | None = None, top_p: float | None = None, user: str | None = None, + **kwargs: Any, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: logger.debug( f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}", @@ -323,6 +324,7 @@ class InferenceRouter(Inference): top_logprobs=top_logprobs, top_p=top_p, user=user, + **kwargs, ) provider = await self.routing_table.get_provider_impl(model_obj.identifier) if stream: diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index 01078760a..8d36a4980 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -153,6 +153,7 @@ class PassthroughInferenceAdapter(Inference): top_logprobs: int | None = None, top_p: float | None = None, user: str | None = None, + **kwargs: Any, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: client = self._get_client() model_obj = await self.model_store.get_model(model) @@ -181,6 +182,7 @@ class PassthroughInferenceAdapter(Inference): top_logprobs=top_logprobs, top_p=top_p, user=user, + **kwargs, ) return await client.inference.openai_chat_completion(**params) diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index f752740e5..a00d34f25 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -57,6 +57,7 @@ class RunpodInferenceAdapter(OpenAIMixin): top_logprobs: int | None = None, top_p: float | None = None, user: str | None = None, + **kwargs: Any, ): """Override to add RunPod-specific stream_options requirement.""" if stream and not stream_options: @@ -86,4 +87,5 @@ class RunpodInferenceAdapter(OpenAIMixin): top_logprobs=top_logprobs, top_p=top_p, user=user, + **kwargs, ) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 310eaf7b6..3edf279bd 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -102,6 +102,7 @@ class VLLMInferenceAdapter(OpenAIMixin): top_logprobs: int | None = None, top_p: float | None = None, user: str | None = None, + **kwargs: Any, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: max_tokens = max_tokens or self.config.max_tokens @@ -136,4 +137,5 @@ class VLLMInferenceAdapter(OpenAIMixin): top_logprobs=top_logprobs, top_p=top_p, user=user, + **kwargs, ) diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index cba7508a2..eac611c88 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -313,6 +313,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): top_logprobs: int | None = None, top_p: float | None = None, user: str | None = None, + **kwargs: Any, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: """ Direct OpenAI chat completion API call. 
@@ -361,7 +362,10 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): user=user, ) - resp = await self.client.chat.completions.create(**params) + # Pass any additional provider-specific parameters as extra_body + extra_body = kwargs if kwargs else {} + + resp = await self.client.chat.completions.create(**params, extra_body=extra_body) return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return] diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index 6d6bb20d5..2197d25a8 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -186,3 +186,47 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter): assert mock_create_client.call_count == 4 # no cheating assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max" + + +async def test_extra_body_forwarding(vllm_inference_adapter): + """ + Test that extra_body parameters (e.g., chat_template_kwargs) are correctly + forwarded to the underlying OpenAI client. + """ + mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference") + vllm_inference_adapter.model_store.get_model.return_value = mock_model + + with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property: + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=OpenAIChatCompletion( + id="chatcmpl-abc123", + created=1, + model="mock-model", + choices=[ + OpenAIChoice( + message=OpenAIAssistantMessageParam( + content="test response", + ), + finish_reason="stop", + index=0, + ) + ], + ) + ) + mock_client_property.return_value = mock_client + + # Test with chat_template_kwargs for Granite thinking mode + await vllm_inference_adapter.openai_chat_completion( + "mock-model", + messages=[], + stream=False, + chat_template_kwargs={"thinking": True}, + ) + + # Verify that the client was called with extra_body containing chat_template_kwargs + mock_client.chat.completions.create.assert_called_once() + call_kwargs = mock_client.chat.completions.create.call_args.kwargs + assert "extra_body" in call_kwargs + assert "chat_template_kwargs" in call_kwargs["extra_body"] + assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True}
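
For readers unfamiliar with the mechanism the OpenAIMixin change relies on: the OpenAI Python SDK accepts an extra_body mapping on chat.completions.create() and merges its keys into the outgoing JSON payload, which is how chat_template_kwargs ultimately reaches vLLM. The following is a minimal, self-contained sketch of that pattern; the function name, endpoint, and model id are illustrative and not part of this patch.

    from typing import Any

    from openai import AsyncOpenAI


    async def chat_with_passthrough(model: str, messages: list[dict[str, Any]], **kwargs: Any):
        # Endpoint and api_key are placeholders for a local vLLM server.
        client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
        # Keyword arguments that are not first-class OpenAI parameters are sent as
        # extra_body; the SDK merges them into the request JSON, so vLLM sees
        # chat_template_kwargs at the top level of the payload.
        return await client.chat.completions.create(
            model=model,
            messages=messages,
            extra_body=kwargs if kwargs else {},
        )

    # Example call (assumes a vLLM server hosting a Granite model whose chat
    # template supports a "thinking" flag):
    #
    #     await chat_with_passthrough(
    #         "ibm-granite/granite-3.2-8b-instruct",
    #         [{"role": "user", "content": "Why is the sky blue?"}],
    #         chat_template_kwargs={"thinking": True},
    #     )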