From 985d0b156ccde9a8e344279ffe4d5c6f81d341fc Mon Sep 17 00:00:00 2001
From: Hardik Shah
Date: Fri, 13 Jun 2025 16:06:06 -0700
Subject: [PATCH] feat: Add `suffix` to openai_completions (#2449)

For code completion, apps need "fill in the middle" capabilities. Added a
`suffix` option to `openai_completion` to enable this, and updated the Ollama
provider to showcase it.

### Test Plan
```
pytest -sv --stack-config="inference=ollama" tests/integration/inference/test_openai_completion.py --text-model qwen2.5-coder:1.5b -k test_openai_completion_non_streaming_suffix
```

### OpenAI Sample script
```
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1")

response = client.completions.create(
    model="qwen2.5-coder:1.5b",
    prompt="The capital of ",
    suffix="is Paris.",
    max_tokens=10,
)

print(response.choices[0].text)
```

### Output
```
France is ____. To answer this question, we
```
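
### Code completion (FIM) example
The sample above uses a trivia-style prompt; for the code-completion use case that motivates this change, the same call takes the code before the cursor as `prompt` and the code after the cursor as `suffix`. The snippet below is an illustrative sketch only (the prefix/suffix strings and token limit are made up and are not part of the test suite); it assumes the same local stack endpoint and `qwen2.5-coder:1.5b` model as above.
```
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1")

# Fill-in-the-middle: the model generates the code that belongs between
# `prompt` (text before the cursor) and `suffix` (text after the cursor).
response = client.completions.create(
    model="qwen2.5-coder:1.5b",
    prompt="def add(a, b):\n    ",   # illustrative prefix
    suffix="\n    return result\n",  # illustrative suffix
    max_tokens=32,                   # illustrative limit
)

print(response.choices[0].text)
```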
---
 docs/_static/llama-stack-spec.html            |  4 ++
 docs/_static/llama-stack-spec.yaml            |  4 ++
 llama_stack/apis/inference/inference.py       |  3 ++
 llama_stack/distribution/routers/inference.py |  2 +
 .../remote/inference/fireworks/fireworks.py   |  1 +
 .../remote/inference/nvidia/nvidia.py         |  1 +
 .../remote/inference/ollama/ollama.py         |  2 +
 .../remote/inference/openai/openai.py         |  2 +
 .../inference/passthrough/passthrough.py      |  1 +
 .../remote/inference/together/together.py     |  1 +
 .../providers/remote/inference/vllm/vllm.py   |  1 +
 .../remote/inference/watsonx/watsonx.py       |  1 +
 .../utils/inference/litellm_openai_mixin.py   |  1 +
 .../utils/inference/openai_compat.py          |  1 +
 .../inference/test_openai_completion.py       | 46 +++++++++++++++++--
 .../test_cases/inference/completion.json      |  6 +++
 16 files changed, 74 insertions(+), 3 deletions(-)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index e26725907..fddce0c57 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -12777,6 +12777,10 @@
         },
         "prompt_logprobs": {
           "type": "integer"
+        },
+        "suffix": {
+          "type": "string",
+          "description": "(Optional) The suffix that should be appended to the completion."
         }
       },
       "additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index c4f356791..49388939f 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -8903,6 +8903,10 @@ components:
           type: string
         prompt_logprobs:
           type: integer
+        suffix:
+          type: string
+          description: >-
+            (Optional) The suffix that should be appended to the completion.
       additionalProperties: false
       required:
         - model
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 74697dd18..c440794f3 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -1038,6 +1038,8 @@ class InferenceProvider(Protocol):
         # vLLM-specific parameters
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        # for fill-in-the-middle type completion
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         """Generate an OpenAI-compatible completion for the given prompt using the specified model.
@@ -1058,6 +1060,7 @@ class InferenceProvider(Protocol):
         :param temperature: (Optional) The temperature to use.
         :param top_p: (Optional) The top p to use.
         :param user: (Optional) The user to use.
+        :param suffix: (Optional) The suffix that should be appended to the completion.
         :returns: An OpenAICompletion.
         """
         ...
diff --git a/llama_stack/distribution/routers/inference.py b/llama_stack/distribution/routers/inference.py
index 62d04cdc4..4e0a33b59 100644
--- a/llama_stack/distribution/routers/inference.py
+++ b/llama_stack/distribution/routers/inference.py
@@ -426,6 +426,7 @@ class InferenceRouter(Inference):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         logger.debug(
             f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
         )
@@ -456,6 +457,7 @@ class InferenceRouter(Inference):
             user=user,
             guided_choice=guided_choice,
             prompt_logprobs=prompt_logprobs,
+            suffix=suffix,
         )
         provider = self.routing_table.get_provider_impl(model_obj.identifier)
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 75a9e33e2..79b1b5f08 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -318,6 +318,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 4c68322e0..cb6c6e279 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -316,6 +316,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         provider_model_id = await self._get_provider_model_id(model)
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index f49348c27..d51072fbf 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -440,6 +440,7 @@ class OllamaInferenceAdapter(
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         if not isinstance(prompt, str):
             raise ValueError("Ollama does not support non-string prompts for completion")
@@ -463,6 +464,7 @@ class OllamaInferenceAdapter(
             temperature=temperature,
             top_p=top_p,
             user=user,
+            suffix=suffix,
         )
         return await self.openai_client.completions.create(**params)  # type: ignore
diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py
index 6f3a686a8..ed4ec22aa 100644
--- a/llama_stack/providers/remote/inference/openai/openai.py
+++ b/llama_stack/providers/remote/inference/openai/openai.py
@@ -90,6 +90,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         if guided_choice is not None:
             logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
@@ -117,6 +118,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
             temperature=temperature,
             top_p=top_p,
             user=user,
+            suffix=suffix,
         )
         return await self._openai_client.completions.create(**params)
diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py
index 6cf4680e2..e9660abb9 100644
--- a/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -242,6 +242,7 @@ class PassthroughInferenceAdapter(Inference):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         client = self._get_client()
         model_obj = await self.model_store.get_model(model)
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 7305a638d..7030a644d 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -299,6 +299,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index d0a822f3c..16d133c81 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -559,6 +559,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         self._lazy_initialize_client()
         model_obj = await self._get_model(model)
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index 59f5f5562..7cdd06a1f 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -292,6 +292,7 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 13381f3c9..c21f379c9 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -325,6 +325,7 @@ class LiteLLMOpenAIMixin(
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 5f0f7fa58..ff95b12a7 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -1290,6 +1290,7 @@ class OpenAICompletionToLlamaStackMixin:
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         if stream:
             raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py
index 461527d18..3e43af272 100644
--- a/tests/integration/inference/test_openai_completion.py
+++ b/tests/integration/inference/test_openai_completion.py
@@ -22,9 +22,6 @@ def provider_from_model(client_with_models, model_id):


 def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
-    if isinstance(client_with_models, LlamaStackAsLibraryClient):
-        pytest.skip("OpenAI completions are not supported when testing with library client yet.")
-
     provider = provider_from_model(client_with_models, model_id)
     if provider.provider_type in (
         "inline::meta-reference",
@@ -44,6 +41,23 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")


+def skip_if_model_doesnt_support_suffix(client_with_models, model_id):
+    # To test `fim` ( fill in the middle ) completion, we need to use a model that supports suffix.
+    # Use this to specifically test this API functionality.
+
+    # pytest -sv --stack-config="inference=ollama" \
+    # tests/integration/inference/test_openai_completion.py \
+    # --text-model qwen2.5-coder:1.5b \
+    # -k test_openai_completion_non_streaming_suffix
+
+    if model_id != "qwen2.5-coder:1.5b":
+        pytest.skip(f"Suffix is not supported for the model: {model_id}.")
+
+    provider = provider_from_model(client_with_models, model_id)
+    if provider.provider_type != "remote::ollama":
+        pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.")
+
+
 def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
     if isinstance(client_with_models, LlamaStackAsLibraryClient):
         pytest.skip("OpenAI chat completions are not supported when testing with library client yet.")
@@ -102,6 +116,32 @@ def test_openai_completion_non_streaming(llama_stack_client, client_with_models,
     assert len(choice.text) > 10


+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:completion:suffix",
+    ],
+)
+def test_openai_completion_non_streaming_suffix(llama_stack_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
+    skip_if_model_doesnt_support_suffix(client_with_models, text_model_id)
+    tc = TestCase(test_case)
+
+    # ollama needs more verbose prompting for some reason here...
+    response = llama_stack_client.completions.create(
+        model=text_model_id,
+        prompt=tc["content"],
+        stream=False,
+        suffix=tc["suffix"],
+        max_tokens=10,
+    )
+
+    assert len(response.choices) > 0
+    choice = response.choices[0]
+    assert len(choice.text) > 5
+    assert "france" in choice.text.lower()
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
diff --git a/tests/integration/test_cases/inference/completion.json b/tests/integration/test_cases/inference/completion.json
index 731ceddbc..baaecb375 100644
--- a/tests/integration/test_cases/inference/completion.json
+++ b/tests/integration/test_cases/inference/completion.json
@@ -4,6 +4,12 @@
             "content": "Complete the sentence using one word: Roses are red, violets are "
         }
     },
+    "suffix": {
+        "data": {
+            "content": "The capital of ",
+            "suffix": "is Paris."
+        }
+    },
     "non_streaming": {
         "data": {
             "content": "Micheael Jordan is born in ",