feat: Add suffix to openai_completions (#2449)

Code completion apps need "fill in the middle" (FIM) capabilities. Added a `suffix` option to `openai_completion` to enable this, and updated the Ollama provider to showcase it.
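
For context, "fill in the middle" means the model receives the text before and after a gap and generates what belongs in between. A minimal sketch of the intended code-completion use through the OpenAI-compatible endpoint (the server URL mirrors the sample script below; the model name and dummy API key are assumptions):

```
from openai import OpenAI

# Assumes a local Llama Stack server; the dummy api_key just satisfies the
# OpenAI client and is assumed not to be checked by the local server.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.completions.create(
    model="qwen2.5-coder:1.5b",
    prompt="def add(a: int, b: int) -> int:\n    ",  # text before the gap
    suffix="\n    return result\n",                  # text after the gap
    max_tokens=32,
)

# The model's output is the code that belongs between prompt and suffix.
print(response.choices[0].text)
```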

### Test Plan 
```
pytest -sv --stack-config="inference=ollama" tests/integration/inference/test_openai_completion.py --text-model qwen2.5-coder:1.5b -k test_openai_completion_non_streaming_suffix
```

### OpenAI sample script
```
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1")

response = client.completions.create(
    model="qwen2.5-coder:1.5b",
    prompt="The capital of ",
    suffix="is Paris.",
    max_tokens=10,
)

print(response.choices[0].text)
``` 
### Output
```
France is ____.

To answer this question, we 
```
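
The continuation reads oddly because the model treats the prompt and suffix as the two sides of a gap rather than one linear prompt. FIM-capable models such as qwen2.5-coder advertise special FIM tokens for this; a rough sketch of the typical templating follows (the token names are Qwen2.5-Coder's documented FIM tokens, but whether Ollama applies exactly this template internally is an assumption):

```
# Sketch of typical fill-in-the-middle prompt templating (assumption: this is
# roughly what a backend does with prompt + suffix for a FIM-capable model).
def render_fim_prompt(prefix: str, suffix: str) -> str:
    return f"<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>"

print(render_fim_prompt("The capital of ", "is Paris."))
# <|fim_prefix|>The capital of <|fim_suffix|>is Paris.<|fim_middle|>
# The model then generates the tokens that belong between prefix and suffix.
```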
Commit 985d0b156c (parent 2e8054bede), authored by Hardik Shah on 2025-06-13 16:06:06 -07:00 and committed by GitHub. 16 changed files with 74 additions and 3 deletions.

```
@@ -12777,6 +12777,10 @@
         },
         "prompt_logprobs": {
           "type": "integer"
+        },
+        "suffix": {
+          "type": "string",
+          "description": "(Optional) The suffix that should be appended to the completion."
         }
       },
       "additionalProperties": false,
```

```
@@ -8903,6 +8903,10 @@ components:
           type: string
         prompt_logprobs:
           type: integer
+        suffix:
+          type: string
+          description: >-
+            (Optional) The suffix that should be appended to the completion.
       additionalProperties: false
       required:
         - model
```

```
@@ -1038,6 +1038,8 @@ class InferenceProvider(Protocol):
         # vLLM-specific parameters
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        # for fill-in-the-middle type completion
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         """Generate an OpenAI-compatible completion for the given prompt using the specified model.
@@ -1058,6 +1060,7 @@ class InferenceProvider(Protocol):
         :param temperature: (Optional) The temperature to use.
         :param top_p: (Optional) The top p to use.
         :param user: (Optional) The user to use.
+        :param suffix: (Optional) The suffix that should be appended to the completion.
         :returns: An OpenAICompletion.
         """
         ...
```
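
The docstring pins down the contract: the generated text should read naturally when `suffix` is appended after it. A toy illustration of reassembling the full document on the caller side (the helper name is hypothetical, not part of the API):

```
# Hypothetical helper: the fill-in-the-middle contract means the final text
# is prompt + generated completion + suffix.
def assemble(prompt: str, completion_text: str, suffix: str) -> str:
    return prompt + completion_text + suffix

full_source = assemble(
    "def area(r):\n    ",          # prompt (prefix)
    "return 3.14159 * r * r",      # model-generated middle
    "\n\nprint(area(2))\n",        # suffix
)
```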

```
@@ -426,6 +426,7 @@ class InferenceRouter(Inference):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         logger.debug(
             f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
@@ -456,6 +457,7 @@ class InferenceRouter(Inference):
             user=user,
             guided_choice=guided_choice,
             prompt_logprobs=prompt_logprobs,
+            suffix=suffix,
         )
         provider = self.routing_table.get_provider_impl(model_obj.identifier)
```

```
@@ -318,6 +318,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
```

```
@@ -316,6 +316,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         provider_model_id = await self._get_provider_model_id(model)
```

```
@@ -440,6 +440,7 @@ class OllamaInferenceAdapter(
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         if not isinstance(prompt, str):
             raise ValueError("Ollama does not support non-string prompts for completion")
@@ -463,6 +464,7 @@ class OllamaInferenceAdapter(
             temperature=temperature,
             top_p=top_p,
             user=user,
+            suffix=suffix,
         )
         return await self.openai_client.completions.create(**params)  # type: ignore
```
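
The Ollama adapter simply threads `suffix` through `prepare_openai_completion_params` into the OpenAI-compatible client. A minimal sketch of what a helper like that plausibly does (an assumption; the real implementation is not shown in this diff):

```
# Sketch (assumption): keep only the parameters the caller actually set, so
# optional arguments left as None never reach the backend request.
async def prepare_openai_completion_params(**params):
    return {k: v for k, v in params.items() if v is not None}
```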

```
@@ -90,6 +90,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         if guided_choice is not None:
             logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
@@ -117,6 +118,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
             temperature=temperature,
             top_p=top_p,
             user=user,
+            suffix=suffix,
         )
         return await self._openai_client.completions.create(**params)
```

```
@@ -242,6 +242,7 @@ class PassthroughInferenceAdapter(Inference):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         client = self._get_client()
         model_obj = await self.model_store.get_model(model)
```

```
@@ -299,6 +299,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
```

```
@@ -559,6 +559,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         self._lazy_initialize_client()
         model_obj = await self._get_model(model)
```

```
@@ -292,6 +292,7 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
```

```
@@ -325,6 +325,7 @@ class LiteLLMOpenAIMixin(
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
```

```
@@ -1290,6 +1290,7 @@ class OpenAICompletionToLlamaStackMixin:
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         if stream:
             raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
```

```
@@ -22,9 +22,6 @@ def provider_from_model(client_with_models, model_id):
 def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
-    if isinstance(client_with_models, LlamaStackAsLibraryClient):
-        pytest.skip("OpenAI completions are not supported when testing with library client yet.")
     provider = provider_from_model(client_with_models, model_id)
     if provider.provider_type in (
         "inline::meta-reference",
@@ -44,6 +41,23 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
     pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
 
 
+def skip_if_model_doesnt_support_suffix(client_with_models, model_id):
+    # To test `fim` ( fill in the middle ) completion, we need to use a model that supports suffix.
+    # Use this to specifically test this API functionality.
+    # pytest -sv --stack-config="inference=ollama" \
+    #   tests/integration/inference/test_openai_completion.py \
+    #   --text-model qwen2.5-coder:1.5b \
+    #   -k test_openai_completion_non_streaming_suffix
+    if model_id != "qwen2.5-coder:1.5b":
+        pytest.skip(f"Suffix is not supported for the model: {model_id}.")
+
+    provider = provider_from_model(client_with_models, model_id)
+    if provider.provider_type != "remote::ollama":
+        pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.")
+
+
 def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
     if isinstance(client_with_models, LlamaStackAsLibraryClient):
         pytest.skip("OpenAI chat completions are not supported when testing with library client yet.")
@@ -102,6 +116,32 @@ def test_openai_completion_non_streaming(llama_stack_client, client_with_models,
     assert len(choice.text) > 10
 
 
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:completion:suffix",
+    ],
+)
+def test_openai_completion_non_streaming_suffix(llama_stack_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
+    skip_if_model_doesnt_support_suffix(client_with_models, text_model_id)
+    tc = TestCase(test_case)
+
+    # ollama needs more verbose prompting for some reason here...
+    response = llama_stack_client.completions.create(
+        model=text_model_id,
+        prompt=tc["content"],
+        stream=False,
+        suffix=tc["suffix"],
+        max_tokens=10,
+    )
+
+    assert len(response.choices) > 0
+    choice = response.choices[0]
+    assert len(choice.text) > 5
+    assert "france" in choice.text.lower()
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
```
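
The new test resolves `"inference:completion:suffix"` through `TestCase` to the JSON entry added below. A rough sketch of how such a lookup could work (the real `TestCase` implementation is not part of this diff, so the path layout here is an assumption):

```
import json
import pathlib

# Sketch (assumption): "inference:completion:suffix" maps to a JSON file of
# test cases, keyed by case name, each exposing its "data" dict by key.
class TestCase:
    def __init__(self, test_id: str):
        api, group, case = test_id.split(":")
        path = pathlib.Path("test_cases") / api / f"{group}.json"
        self.data = json.loads(path.read_text())[case]["data"]

    def __getitem__(self, key: str):
        return self.data[key]
```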

```
@@ -4,6 +4,12 @@
       "content": "Complete the sentence using one word: Roses are red, violets are "
     }
   },
+  "suffix": {
+    "data": {
+      "content": "The capital of ",
+      "suffix": "is Paris."
+    }
+  },
   "non_streaming": {
     "data": {
       "content": "Micheael Jordan is born in ",
```