Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
feat: Add suffix to openai_completions (#2449)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests / test-matrix (http, 3.10, vector_io) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.10, inference) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.11, inference) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.10, inspect) (push) Failing after 11s
Integration Tests / test-matrix (http, 3.12, post_training) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.12, scoring) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, inspect) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.10, post_training) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.12, inspect) (push) Failing after 14s
Integration Tests / test-matrix (library, 3.10, inference) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.10, agents) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.10, scoring) (push) Failing after 11s
Integration Tests / test-matrix (http, 3.10, tool_runtime) (push) Failing after 14s
Integration Tests / test-matrix (http, 3.12, tool_runtime) (push) Failing after 13s
Integration Tests / test-matrix (http, 3.12, agents) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.10, providers) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.11, post_training) (push) Failing after 13s
Integration Tests / test-matrix (http, 3.11, vector_io) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.11, datasets) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.10, agents) (push) Failing after 13s
Integration Tests / test-matrix (http, 3.11, tool_runtime) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.11, agents) (push) Failing after 15s
Integration Tests / test-matrix (http, 3.12, datasets) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, tool_runtime) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.10, datasets) (push) Failing after 16s
Integration Tests / test-matrix (http, 3.12, providers) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, datasets) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.11, scoring) (push) Failing after 14s
Integration Tests / test-matrix (http, 3.12, vector_io) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.10, vector_io) (push) Failing after 12s
Integration Tests / test-matrix (library, 3.10, inspect) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.11, providers) (push) Failing after 5s
Integration Tests / test-matrix (library, 3.10, providers) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.10, post_training) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.10, scoring) (push) Failing after 13s
Integration Tests / test-matrix (library, 3.11, agents) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.11, inference) (push) Failing after 6s
Integration Tests / test-matrix (library, 3.11, datasets) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.11, inspect) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.11, post_training) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.11, providers) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.11, scoring) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.11, vector_io) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.11, tool_runtime) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, agents) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, datasets) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, inspect) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, inference) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, providers) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, post_training) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, scoring) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.12, tool_runtime) (push) Failing after 9s
Test External Providers / test-external-providers (venv) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, vector_io) (push) Failing after 14s
Unit Tests / unit-tests (3.10) (push) Failing after 19s
Unit Tests / unit-tests (3.11) (push) Failing after 20s
Unit Tests / unit-tests (3.12) (push) Failing after 18s
Unit Tests / unit-tests (3.13) (push) Failing after 16s
Update ReadTheDocs / update-readthedocs (push) Failing after 8s
Pre-commit / pre-commit (push) Successful in 58s
Code completion apps need "fill in the middle" (FIM) capabilities. This adds a `suffix` option to `openai_completion` to enable it, and updates the ollama provider to showcase the same.

### Test Plan

```
pytest -sv --stack-config="inference=ollama" \
  tests/integration/inference/test_openai_completion.py \
  --text-model qwen2.5-coder:1.5b \
  -k test_openai_completion_non_streaming_suffix
```

### OpenAI Sample script

```
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1")

response = client.completions.create(
    model="qwen2.5-coder:1.5b",
    prompt="The capital of ",
    suffix="is Paris.",
    max_tokens=10,
)

print(response.choices[0].text)
```

### Output

```
France is ____. To answer this question, we
```
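The same request also works through the native client, as the integration test added in this PR does. A minimal sketch, assuming a Llama Stack server running locally on the default port 8321 with the ollama provider configured:

```python
# Minimal sketch: the same fill-in-the-middle request via the native client,
# mirroring the integration test added in this PR. The base_url is an
# assumption (local server on the default port 8321).
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

response = client.completions.create(
    model="qwen2.5-coder:1.5b",
    prompt="The capital of ",
    suffix="is Paris.",
    max_tokens=10,
)
print(response.choices[0].text)
```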
Parent: 2e8054bede
Commit: 985d0b156c

16 changed files with 74 additions and 3 deletions
docs/_static/llama-stack-spec.html (vendored, 4 changes)

```diff
@@ -12777,6 +12777,10 @@
             },
             "prompt_logprobs": {
               "type": "integer"
+            },
+            "suffix": {
+              "type": "string",
+              "description": "(Optional) The suffix that should be appended to the completion."
             }
           },
           "additionalProperties": false,
```
docs/_static/llama-stack-spec.yaml (vendored, 4 changes)

```diff
@@ -8903,6 +8903,10 @@ components:
           type: string
         prompt_logprobs:
           type: integer
+        suffix:
+          type: string
+          description: >-
+            (Optional) The suffix that should be appended to the completion.
       additionalProperties: false
       required:
       - model
```
```diff
@@ -1038,6 +1038,8 @@ class InferenceProvider(Protocol):
         # vLLM-specific parameters
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        # for fill-in-the-middle type completion
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         """Generate an OpenAI-compatible completion for the given prompt using the specified model.

@@ -1058,6 +1060,7 @@ class InferenceProvider(Protocol):
         :param temperature: (Optional) The temperature to use.
         :param top_p: (Optional) The top p to use.
         :param user: (Optional) The user to use.
+        :param suffix: (Optional) The suffix that should be appended to the completion.
         :returns: An OpenAICompletion.
         """
         ...
```
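For context on what the new parameter does: in fill-in-the-middle completion the model sees both sides of a gap and generates the middle. How `prompt` and `suffix` are rendered is model and backend specific; the sketch below uses the sentinel-token style common in some code models purely as an illustration (the token names are assumptions, not what ollama necessarily emits):

```python
# Illustration only: one common way FIM-capable code models receive
# prompt + suffix. The sentinel tokens below are an assumed example;
# the actual template is chosen by the backend and varies per model.
prompt = "def add(a, b):\n    "
suffix = "\n    return result"

fim_input = f"<fim_prefix>{prompt}<fim_suffix>{suffix}<fim_middle>"
print(fim_input)  # the model is asked to generate the middle, e.g. "result = a + b"
```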
```diff
@@ -426,6 +426,7 @@ class InferenceRouter(Inference):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         logger.debug(
             f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",

@@ -456,6 +457,7 @@ class InferenceRouter(Inference):
             user=user,
             guided_choice=guided_choice,
             prompt_logprobs=prompt_logprobs,
+            suffix=suffix,
         )

         provider = self.routing_table.get_provider_impl(model_obj.identifier)
```
```diff
@@ -318,6 +318,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
```
```diff
@@ -316,6 +316,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         provider_model_id = await self._get_provider_model_id(model)
```
```diff
@@ -440,6 +440,7 @@ class OllamaInferenceAdapter(
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         if not isinstance(prompt, str):
             raise ValueError("Ollama does not support non-string prompts for completion")

@@ -463,6 +464,7 @@ class OllamaInferenceAdapter(
             temperature=temperature,
             top_p=top_p,
             user=user,
+            suffix=suffix,
         )
         return await self.openai_client.completions.create(**params)  # type: ignore
```
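Several adapters in this change funnel the new keyword through `prepare_openai_completion_params` before calling an OpenAI-compatible client. A self-contained sketch of the None-dropping behaviour that helper is assumed to provide (the real helper in llama_stack may do more than this):

```python
# Self-contained sketch of the assumed behaviour of
# prepare_openai_completion_params: collect keyword arguments and drop the
# ones left as None, so unset extras never reach the backend.
def drop_none_params(**kwargs) -> dict:
    return {k: v for k, v in kwargs.items() if v is not None}

params = drop_none_params(
    model="qwen2.5-coder:1.5b",
    prompt="The capital of ",
    suffix="is Paris.",
    max_tokens=10,
    guided_choice=None,    # vLLM-specific extras stay unset here
    prompt_logprobs=None,
)
assert params == {
    "model": "qwen2.5-coder:1.5b",
    "prompt": "The capital of ",
    "suffix": "is Paris.",
    "max_tokens": 10,
}
```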
```diff
@@ -90,6 +90,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         if guided_choice is not None:
             logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")

@@ -117,6 +118,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
             temperature=temperature,
             top_p=top_p,
             user=user,
+            suffix=suffix,
         )
         return await self._openai_client.completions.create(**params)
```
```diff
@@ -242,6 +242,7 @@ class PassthroughInferenceAdapter(Inference):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         client = self._get_client()
         model_obj = await self.model_store.get_model(model)
```
```diff
@@ -299,6 +299,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
```
```diff
@@ -559,6 +559,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         self._lazy_initialize_client()
         model_obj = await self._get_model(model)
```
```diff
@@ -292,6 +292,7 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
```
```diff
@@ -325,6 +325,7 @@ class LiteLLMOpenAIMixin(
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
```
```diff
@@ -1290,6 +1290,7 @@ class OpenAICompletionToLlamaStackMixin:
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         if stream:
             raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
```
```diff
@@ -22,9 +22,6 @@ def provider_from_model(client_with_models, model_id):


 def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
-    if isinstance(client_with_models, LlamaStackAsLibraryClient):
-        pytest.skip("OpenAI completions are not supported when testing with library client yet.")
-
     provider = provider_from_model(client_with_models, model_id)
     if provider.provider_type in (
         "inline::meta-reference",

@@ -44,6 +41,23 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")


+def skip_if_model_doesnt_support_suffix(client_with_models, model_id):
+    # To test `fim` ( fill in the middle ) completion, we need to use a model that supports suffix.
+    # Use this to specifically test this API functionality.
+
+    # pytest -sv --stack-config="inference=ollama" \
+    #   tests/integration/inference/test_openai_completion.py \
+    #   --text-model qwen2.5-coder:1.5b \
+    #   -k test_openai_completion_non_streaming_suffix
+
+    if model_id != "qwen2.5-coder:1.5b":
+        pytest.skip(f"Suffix is not supported for the model: {model_id}.")
+
+    provider = provider_from_model(client_with_models, model_id)
+    if provider.provider_type != "remote::ollama":
+        pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.")
+
+
 def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
     if isinstance(client_with_models, LlamaStackAsLibraryClient):
         pytest.skip("OpenAI chat completions are not supported when testing with library client yet.")

@@ -102,6 +116,32 @@ def test_openai_completion_non_streaming(llama_stack_client, client_with_models
     assert len(choice.text) > 10


+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:completion:suffix",
+    ],
+)
+def test_openai_completion_non_streaming_suffix(llama_stack_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
+    skip_if_model_doesnt_support_suffix(client_with_models, text_model_id)
+    tc = TestCase(test_case)
+
+    # ollama needs more verbose prompting for some reason here...
+    response = llama_stack_client.completions.create(
+        model=text_model_id,
+        prompt=tc["content"],
+        stream=False,
+        suffix=tc["suffix"],
+        max_tokens=10,
+    )
+
+    assert len(response.choices) > 0
+    choice = response.choices[0]
+    assert len(choice.text) > 5
+    assert "france" in choice.text.lower()
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
```
```diff
@@ -4,6 +4,12 @@
       "content": "Complete the sentence using one word: Roses are red, violets are "
     }
   },
+  "suffix": {
+    "data": {
+      "content": "The capital of ",
+      "suffix": "is Paris."
+    }
+  },
   "non_streaming": {
     "data": {
       "content": "Micheael Jordan is born in ",
```
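The new test references this entry via the id `inference:completion:suffix`. A rough sketch of the lookup shape the `TestCase` helper presumably performs (the real helper lives in the test suite; this only illustrates how the id maps to the JSON entry above):

```python
# Rough sketch of how "inference:completion:suffix" is assumed to resolve to
# the JSON entry above; the real TestCase helper in the test suite may differ.
import json

completion_cases = json.loads(
    """
    {
      "suffix": {
        "data": {"content": "The capital of ", "suffix": "is Paris."}
      }
    }
    """
)

case_name = "inference:completion:suffix".split(":")[-1]   # -> "suffix"
data = completion_cases[case_name]["data"]
assert data["content"] == "The capital of "
assert data["suffix"] == "is Paris."
```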