diff --git a/docs/openapi_generator/pyopenapi/operations.py b/docs/openapi_generator/pyopenapi/operations.py
index 2970d7e53..e5288906f 100644
--- a/docs/openapi_generator/pyopenapi/operations.py
+++ b/docs/openapi_generator/pyopenapi/operations.py
@@ -278,6 +278,11 @@ def get_endpoint_operations(
             if param_name == "self" and param_type is inspect.Parameter.empty:
                 continue
 
+            # skip **kwargs parameters - they should not appear in OpenAPI spec
+            # these are used for forwarding arbitrary extra parameters to underlying APIs
+            if parameter.kind == inspect.Parameter.VAR_KEYWORD:
+                continue
+
             # check if all parameters have explicit type
             if parameter.annotation is inspect.Parameter.empty:
                 raise ValidationError(
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 62a988ea6..d93bb6c45 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -1106,6 +1106,7 @@ class InferenceProvider(Protocol):
         top_logprobs: int | None = None,
         top_p: float | None = None,
         user: str | None = None,
+        **kwargs: Any,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         """Create chat completions.
 
@@ -1134,6 +1135,7 @@
         :param top_logprobs: (Optional) The top log probabilities to use.
         :param top_p: (Optional) The top p to use.
         :param user: (Optional) The user to use.
+        :param kwargs: (Optional) Additional provider-specific parameters to pass through as extra_body (e.g., chat_template_kwargs for vLLM).
         :returns: An OpenAIChatCompletion.
         """
         ...
diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py
index 847f6a2d2..98cacbb49 100644
--- a/llama_stack/core/routers/inference.py
+++ b/llama_stack/core/routers/inference.py
@@ -277,6 +277,7 @@ class InferenceRouter(Inference):
         top_logprobs: int | None = None,
         top_p: float | None = None,
         user: str | None = None,
+        **kwargs: Any,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         logger.debug(
             f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
@@ -323,6 +324,7 @@
             top_logprobs=top_logprobs,
             top_p=top_p,
             user=user,
+            **kwargs,
         )
         provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         if stream:
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index fd65fa10d..c23794eb4 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -173,5 +173,6 @@ class MetaReferenceInferenceImpl(
         top_logprobs: int | None = None,
         top_p: float | None = None,
         user: str | None = None,
+        **kwargs: Any,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index b984d97bf..7aa880de3 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -124,5 +124,6 @@ class SentenceTransformersInferenceImpl(
         top_logprobs: int | None = None,
         top_p: float | None = None,
         user: str | None = None,
+        **kwargs: Any,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 9c8a74b47..ee354aaf3 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -186,5 +186,6 @@ class BedrockInferenceAdapter(
         top_logprobs: int | None = None,
         top_p: float | None = None,
         user: str | None = None,
+        **kwargs: Any,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py
index 01078760a..8d36a4980 100644
--- a/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -153,6 +153,7 @@ class PassthroughInferenceAdapter(Inference):
         top_logprobs: int | None = None,
         top_p: float | None = None,
         user: str | None = None,
+        **kwargs: Any,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         client = self._get_client()
         model_obj = await self.model_store.get_model(model)
@@ -181,6 +182,7 @@
             top_logprobs=top_logprobs,
             top_p=top_p,
             user=user,
+            **kwargs,
         )
 
         return await client.inference.openai_chat_completion(**params)
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index f752740e5..a00d34f25 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -57,6 +57,7 @@ class RunpodInferenceAdapter(OpenAIMixin):
         top_logprobs: int | None = None,
         top_p: float | None = None,
         user: str | None = None,
+        **kwargs: Any,
     ):
         """Override to add RunPod-specific stream_options requirement."""
         if stream and not stream_options:
@@ -86,4 +87,5 @@
             top_logprobs=top_logprobs,
             top_p=top_p,
             user=user,
+            **kwargs,
         )
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 310eaf7b6..3edf279bd 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -102,6 +102,7 @@ class VLLMInferenceAdapter(OpenAIMixin):
         top_logprobs: int | None = None,
         top_p: float | None = None,
         user: str | None = None,
+        **kwargs: Any,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         max_tokens = max_tokens or self.config.max_tokens
 
@@ -136,4 +137,5 @@
             top_logprobs=top_logprobs,
             top_p=top_p,
             user=user,
+            **kwargs,
         )
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 6bef97dd5..68373ada9 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -299,6 +299,7 @@ class LiteLLMOpenAIMixin(
         top_logprobs: int | None = None,
         top_p: float | None = None,
         user: str | None = None,
+        **kwargs: Any,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         # Add usage tracking for streaming when telemetry is active
         from llama_stack.providers.utils.telemetry.tracing import get_current_span
@@ -335,6 +336,7 @@
             user=user,
             api_key=self.get_api_key(),
             api_base=self.api_base,
+            **kwargs,
         )
         return await litellm.acompletion(**params)
 
diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py
index cba7508a2..eac611c88 100644
--- a/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/llama_stack/providers/utils/inference/openai_mixin.py
@@ -313,6 +313,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         top_logprobs: int | None = None,
         top_p: float | None = None,
         user: str | None = None,
+        **kwargs: Any,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         """
         Direct OpenAI chat completion API call.
@@ -361,7 +362,10 @@
             user=user,
         )
 
-        resp = await self.client.chat.completions.create(**params)
+        # Pass any additional provider-specific parameters as extra_body
+        extra_body = kwargs if kwargs else {}
+
+        resp = await self.client.chat.completions.create(**params, extra_body=extra_body)
 
         return await self._maybe_overwrite_id(resp, stream)  # type: ignore[no-any-return]
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index 6d6bb20d5..2197d25a8 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -186,3 +186,47 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
 
     assert mock_create_client.call_count == 4  # no cheating
     assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
+
+
+async def test_extra_body_forwarding(vllm_inference_adapter):
+    """
+    Test that extra_body parameters (e.g., chat_template_kwargs) are correctly
+    forwarded to the underlying OpenAI client.
+    """
+    mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
+    vllm_inference_adapter.model_store.get_model.return_value = mock_model
+
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(
+            return_value=OpenAIChatCompletion(
+                id="chatcmpl-abc123",
+                created=1,
+                model="mock-model",
+                choices=[
+                    OpenAIChoice(
+                        message=OpenAIAssistantMessageParam(
+                            content="test response",
+                        ),
+                        finish_reason="stop",
+                        index=0,
+                    )
+                ],
+            )
+        )
+        mock_client_property.return_value = mock_client
+
+        # Test with chat_template_kwargs for Granite thinking mode
+        await vllm_inference_adapter.openai_chat_completion(
+            "mock-model",
+            messages=[],
+            stream=False,
+            chat_template_kwargs={"thinking": True},
+        )
+
+        # Verify that the client was called with extra_body containing chat_template_kwargs
+        mock_client.chat.completions.create.assert_called_once()
+        call_kwargs = mock_client.chat.completions.create.call_args.kwargs
+        assert "extra_body" in call_kwargs
+        assert "chat_template_kwargs" in call_kwargs["extra_body"]
+        assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True}
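For reviewers, a minimal usage sketch of the new pass-through path. The adapter instance, model id, and message content below are placeholders and not part of the change; chat_template_kwargs is the same vLLM extra_body field exercised by the test above.

    # Sketch only: assumes `adapter` is an already-configured VLLMInferenceAdapter
    # (or any OpenAIMixin-based provider). Keywords not named in the
    # openai_chat_completion signature are collected by **kwargs and forwarded to
    # the server as extra_body.
    response = await adapter.openai_chat_completion(
        "granite-3.1-8b-instruct",  # placeholder model id
        messages=[{"role": "user", "content": "Hello"}],  # dict shown for brevity
        stream=False,
        chat_template_kwargs={"thinking": True},
    )
    # OpenAIMixin then issues roughly:
    #   client.chat.completions.create(..., extra_body={"chat_template_kwargs": {"thinking": True}})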