diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index e26725907..fddce0c57 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -12777,6 +12777,10 @@
},
"prompt_logprobs": {
"type": "integer"
+ },
+ "suffix": {
+ "type": "string",
+ "description": "(Optional) The suffix that should be appended to the completion."
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index c4f356791..49388939f 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -8903,6 +8903,10 @@ components:
type: string
prompt_logprobs:
type: integer
+ suffix:
+ type: string
+ description: >-
+ (Optional) The suffix that should be appended to the completion.
additionalProperties: false
required:
- model
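The spec change only adds an optional `suffix` string to the OpenAI-compatible completion request body. As a rough illustration of the resulting payload shape (field names come from the schema above; the model id and prompt text are illustrative assumptions, not part of this patch):

```python
# Hypothetical request body for an OpenAI-compatible completion with fill-in-the-middle.
# "model" is required per the schema above; "suffix" is the new optional field.
payload = {
    "model": "qwen2.5-coder:1.5b",        # illustrative model id
    "prompt": "def add(a, b):\n    ",     # text before the insertion point
    "suffix": "\n    return result\n",    # text after the insertion point
    "max_tokens": 32,
}
```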
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 74697dd18..c440794f3 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -1038,6 +1038,8 @@ class InferenceProvider(Protocol):
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
+ # for fill-in-the-middle type completion
+ suffix: str | None = None,
) -> OpenAICompletion:
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
@@ -1058,6 +1060,7 @@ class InferenceProvider(Protocol):
:param temperature: (Optional) The temperature to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
+ :param suffix: (Optional) The suffix that should be appended to the completion.
:returns: An OpenAICompletion.
"""
...
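With the protocol change above, `suffix` surfaces to clients the same way as the other completion parameters. A minimal usage sketch that mirrors the integration test added later in this patch (client construction, base URL, and model id are assumptions for a local stack):

```python
from llama_stack_client import LlamaStackClient  # assumed client package for a local stack

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed default local endpoint

# Fill-in-the-middle: the model completes the text between `prompt` and `suffix`.
response = client.completions.create(
    model="qwen2.5-coder:1.5b",  # a FIM-capable model, as used in the test below
    prompt="The capital of ",
    suffix="is Paris.",
    max_tokens=10,
    stream=False,
)
print(response.choices[0].text)  # expected to mention France
```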
diff --git a/llama_stack/distribution/routers/inference.py b/llama_stack/distribution/routers/inference.py
index 62d04cdc4..4e0a33b59 100644
--- a/llama_stack/distribution/routers/inference.py
+++ b/llama_stack/distribution/routers/inference.py
@@ -426,6 +426,7 @@ class InferenceRouter(Inference):
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
+ suffix: str | None = None,
) -> OpenAICompletion:
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
@@ -456,6 +457,7 @@ class InferenceRouter(Inference):
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
+ suffix=suffix,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 75a9e33e2..79b1b5f08 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -318,6 +318,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
+ suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 4c68322e0..cb6c6e279 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -316,6 +316,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
+ suffix: str | None = None,
) -> OpenAICompletion:
provider_model_id = await self._get_provider_model_id(model)
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index f49348c27..d51072fbf 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -440,6 +440,7 @@ class OllamaInferenceAdapter(
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
+ suffix: str | None = None,
) -> OpenAICompletion:
if not isinstance(prompt, str):
raise ValueError("Ollama does not support non-string prompts for completion")
@@ -463,6 +464,7 @@ class OllamaInferenceAdapter(
temperature=temperature,
top_p=top_p,
user=user,
+ suffix=suffix,
)
return await self.openai_client.completions.create(**params) # type: ignore
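The Ollama adapter threads `suffix` straight through to its OpenAI-compatible client, so the forwarded request is roughly equivalent to the standalone sketch below (this assumes the adapter's `openai_client` is the standard async OpenAI client and that Ollama is serving its OpenAI-compatible API on the default local port):

```python
import asyncio

from openai import AsyncOpenAI  # assumed to be the client class the adapter wraps


async def main() -> None:
    # Ollama's OpenAI-compatible endpoint; base URL and api_key placeholder are assumptions
    # about a default local install.
    client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
    completion = await client.completions.create(
        model="qwen2.5-coder:1.5b",
        prompt="The capital of ",
        suffix="is Paris.",
        max_tokens=10,
    )
    print(completion.choices[0].text)


asyncio.run(main())
```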
diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py
index 6f3a686a8..ed4ec22aa 100644
--- a/llama_stack/providers/remote/inference/openai/openai.py
+++ b/llama_stack/providers/remote/inference/openai/openai.py
@@ -90,6 +90,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
+ suffix: str | None = None,
) -> OpenAICompletion:
if guided_choice is not None:
logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
@@ -117,6 +118,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
temperature=temperature,
top_p=top_p,
user=user,
+ suffix=suffix,
)
return await self._openai_client.completions.create(**params)
diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py
index 6cf4680e2..e9660abb9 100644
--- a/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -242,6 +242,7 @@ class PassthroughInferenceAdapter(Inference):
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
+ suffix: str | None = None,
) -> OpenAICompletion:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 7305a638d..7030a644d 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -299,6 +299,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
+ suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index d0a822f3c..16d133c81 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -559,6 +559,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
+ suffix: str | None = None,
) -> OpenAICompletion:
self._lazy_initialize_client()
model_obj = await self._get_model(model)
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index 59f5f5562..7cdd06a1f 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -292,6 +292,7 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
+ suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 13381f3c9..c21f379c9 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -325,6 +325,7 @@ class LiteLLMOpenAIMixin(
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
+ suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 5f0f7fa58..ff95b12a7 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -1290,6 +1290,7 @@ class OpenAICompletionToLlamaStackMixin:
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
+ suffix: str | None = None,
) -> OpenAICompletion:
if stream:
raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py
index 461527d18..3e43af272 100644
--- a/tests/integration/inference/test_openai_completion.py
+++ b/tests/integration/inference/test_openai_completion.py
@@ -22,9 +22,6 @@ def provider_from_model(client_with_models, model_id):
def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
- if isinstance(client_with_models, LlamaStackAsLibraryClient):
- pytest.skip("OpenAI completions are not supported when testing with library client yet.")
-
provider = provider_from_model(client_with_models, model_id)
if provider.provider_type in (
"inline::meta-reference",
@@ -44,6 +41,23 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
+def skip_if_model_doesnt_support_suffix(client_with_models, model_id):
+    # To test fill-in-the-middle (FIM) completion, we need a model that supports the suffix parameter.
+    # Use this helper to gate tests that specifically exercise this API functionality, e.g.:
+
+ # pytest -sv --stack-config="inference=ollama" \
+ # tests/integration/inference/test_openai_completion.py \
+ # --text-model qwen2.5-coder:1.5b \
+ # -k test_openai_completion_non_streaming_suffix
+
+ if model_id != "qwen2.5-coder:1.5b":
+ pytest.skip(f"Suffix is not supported for the model: {model_id}.")
+
+ provider = provider_from_model(client_with_models, model_id)
+ if provider.provider_type != "remote::ollama":
+ pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.")
+
+
def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
if isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI chat completions are not supported when testing with library client yet.")
@@ -102,6 +116,32 @@ def test_openai_completion_non_streaming(llama_stack_client, client_with_models,
assert len(choice.text) > 10
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ "inference:completion:suffix",
+ ],
+)
+def test_openai_completion_non_streaming_suffix(llama_stack_client, client_with_models, text_model_id, test_case):
+ skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
+ skip_if_model_doesnt_support_suffix(client_with_models, text_model_id)
+ tc = TestCase(test_case)
+
+    # the model is expected to fill in the middle, e.g. "The capital of <France> is Paris."
+ response = llama_stack_client.completions.create(
+ model=text_model_id,
+ prompt=tc["content"],
+ stream=False,
+ suffix=tc["suffix"],
+ max_tokens=10,
+ )
+
+ assert len(response.choices) > 0
+ choice = response.choices[0]
+ assert len(choice.text) > 5
+ assert "france" in choice.text.lower()
+
+
@pytest.mark.parametrize(
"test_case",
[
diff --git a/tests/integration/test_cases/inference/completion.json b/tests/integration/test_cases/inference/completion.json
index 731ceddbc..baaecb375 100644
--- a/tests/integration/test_cases/inference/completion.json
+++ b/tests/integration/test_cases/inference/completion.json
@@ -4,6 +4,12 @@
"content": "Complete the sentence using one word: Roses are red, violets are "
}
},
+ "suffix": {
+ "data": {
+ "content": "The capital of ",
+ "suffix": "is Paris."
+ }
+ },
"non_streaming": {
"data": {
"content": "Micheael Jordan is born in ",