diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index d93bb6c45..bac79aa30 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -1052,6 +1052,7 @@ class InferenceProvider(Protocol): prompt_logprobs: int | None = None, # for fill-in-the-middle type completion suffix: str | None = None, + **kwargs: Any, ) -> OpenAICompletion: """Create completion. @@ -1075,6 +1076,7 @@ class InferenceProvider(Protocol): :param top_p: (Optional) The top p to use. :param user: (Optional) The user to use. :param suffix: (Optional) The suffix that should be appended to the completion. + :param kwargs: (Optional) Additional provider-specific parameters to pass through as extra_body. :returns: An OpenAICompletion. """ ... diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 98cacbb49..2afb67d90 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -201,6 +201,7 @@ class InferenceRouter(Inference): guided_choice: list[str] | None = None, prompt_logprobs: int | None = None, suffix: str | None = None, + **kwargs: Any, ) -> OpenAICompletion: logger.debug( f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}", @@ -227,6 +228,7 @@ class InferenceRouter(Inference): guided_choice=guided_choice, prompt_logprobs=prompt_logprobs, suffix=suffix, + **kwargs, ) provider = await self.routing_table.get_provider_impl(model_obj.identifier) if stream: diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 7aa880de3..c5450448e 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -96,6 +96,7 @@ class SentenceTransformersInferenceImpl( prompt_logprobs: int | None = None, # for fill-in-the-middle type completion suffix: str | None = None, + **kwargs: Any, ) -> OpenAICompletion: raise NotImplementedError("OpenAI completion not supported by sentence transformers provider") diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index ee354aaf3..ccd18a3e8 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -158,6 +158,7 @@ class BedrockInferenceAdapter( prompt_logprobs: int | None = None, # for fill-in-the-middle type completion suffix: str | None = None, + **kwargs: Any, ) -> OpenAICompletion: raise NotImplementedError("OpenAI completion not supported by the Bedrock provider") diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 200b36171..c6838271f 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -63,5 +63,6 @@ class DatabricksInferenceAdapter(OpenAIMixin): guided_choice: list[str] | None = None, prompt_logprobs: int | None = None, suffix: str | None = None, + **kwargs: Any, ) -> OpenAICompletion: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py index 165992c16..37010114b 100644 --- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py +++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py @@ -54,6 +54,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin): guided_choice: list[str] | None = None, prompt_logprobs: int | None = None, suffix: str | None = None, + **kwargs: Any, ) -> OpenAICompletion: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index 8d36a4980..372d2d663 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -100,6 +100,7 @@ class PassthroughInferenceAdapter(Inference): guided_choice: list[str] | None = None, prompt_logprobs: int | None = None, suffix: str | None = None, + **kwargs: Any, ) -> OpenAICompletion: client = self._get_client() model_obj = await self.model_store.get_model(model) @@ -124,6 +125,7 @@ class PassthroughInferenceAdapter(Inference): user=user, guided_choice=guided_choice, prompt_logprobs=prompt_logprobs, + **kwargs, ) return await client.inference.openai_completion(**params) diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 68373ada9..22bfd36c4 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -247,6 +247,7 @@ class LiteLLMOpenAIMixin( guided_choice: list[str] | None = None, prompt_logprobs: int | None = None, suffix: str | None = None, + **kwargs: Any, ) -> OpenAICompletion: model_obj = await self.model_store.get_model(model) params = await prepare_openai_completion_params( @@ -271,6 +272,7 @@ class LiteLLMOpenAIMixin( prompt_logprobs=prompt_logprobs, api_key=self.get_api_key(), api_base=self.api_base, + **kwargs, ) return await litellm.atext_completion(**params) diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index eac611c88..7c08e668d 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -247,6 +247,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): guided_choice: list[str] | None = None, prompt_logprobs: int | None = None, suffix: str | None = None, + **kwargs: Any, ) -> OpenAICompletion: """ Direct OpenAI completion API call. @@ -261,6 +262,9 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): if guided_choice: extra_body["guided_choice"] = guided_choice + # Merge any additional kwargs into extra_body + extra_body.update(kwargs) + # TODO: fix openai_completion to return type compatible with OpenAI's API response resp = await self.client.completions.create( **await prepare_openai_completion_params(