Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-06-30 03:44:20 +00:00
add suffix to openai.completions
This commit is contained in:
parent e2e15ebb6c
commit 1cfb5b1205
15 changed files with 101 additions and 3 deletions
docs/_static/llama-stack-spec.html (vendored): 4 changes

@@ -12404,6 +12404,10 @@
                 },
                 "prompt_logprobs": {
                     "type": "integer"
+                },
+                "suffix": {
+                    "type": "string",
+                    "description": "(Optional) The suffix that should be appended to the completion."
                 }
             },
             "additionalProperties": false,
docs/_static/llama-stack-spec.yaml (vendored): 4 changes

@@ -8673,6 +8673,10 @@ components:
           type: string
         prompt_logprobs:
           type: integer
+        suffix:
+          type: string
+          description: >-
+            (Optional) The suffix that should be appended to the completion.
       additionalProperties: false
       required:
         - model
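The spec additions above surface `suffix` on the OpenAI-compatible completions request. A minimal client-side sketch of exercising the new field follows; the base URL prefix, port, auth placeholder, and model ID are illustrative assumptions, not values taken from this commit.

# Sketch only: adjust base_url, api_key, and model to your deployment.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",  # assumed OpenAI-compat prefix of a local stack server
    api_key="none",                                 # placeholder; adjust to your deployment's auth
)

response = client.completions.create(
    model="qwen2.5-coder:1.5b",   # any registered model that supports fill-in-the-middle
    prompt="The capital of ",     # text before the gap
    suffix="is Paris.",           # text after the gap (the new field)
    max_tokens=10,
)
print(response.choices[0].text)   # the generated middle, expected to mention France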
@@ -1038,6 +1038,8 @@ class InferenceProvider(Protocol):
         # vLLM-specific parameters
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        # for fill-in-the-middle type completion
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         """Generate an OpenAI-compatible completion for the given prompt using the specified model.
 

@@ -1058,6 +1060,7 @@ class InferenceProvider(Protocol):
         :param temperature: (Optional) The temperature to use.
         :param top_p: (Optional) The top p to use.
         :param user: (Optional) The user to use.
+        :param suffix: (Optional) The suffix that should be appended to the completion.
         :returns: An OpenAICompletion.
         """
         ...
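In fill-in-the-middle (FIM) completion, `prompt` is the text before the gap and the new `suffix` parameter is the text after it; the model generates only the missing middle. A tiny illustration of that contract (not code from this commit):

def stitch_fim(prompt: str, completion: str, suffix: str) -> str:
    # Reassemble a document from a fill-in-the-middle request:
    # the model produced only `completion`, the caller supplied the rest.
    return f"{prompt}{completion}{suffix}"

print(stitch_fim("The capital of ", "France ", "is Paris."))
# -> The capital of France is Paris.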
@@ -426,6 +426,7 @@ class InferenceRouter(Inference):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         logger.debug(
             f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",

@@ -456,6 +457,7 @@ class InferenceRouter(Inference):
             user=user,
             guided_choice=guided_choice,
             prompt_logprobs=prompt_logprobs,
+            suffix=suffix,
         )
 
         provider = self.routing_table.get_provider_impl(model_obj.identifier)
@@ -318,6 +318,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
 
@@ -316,6 +316,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         provider_model_id = await self._get_provider_model_id(model)
 
@@ -409,6 +409,7 @@ class OllamaInferenceAdapter(
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         if not isinstance(prompt, str):
             raise ValueError("Ollama does not support non-string prompts for completion")

@@ -432,6 +433,7 @@ class OllamaInferenceAdapter(
             temperature=temperature,
             top_p=top_p,
             user=user,
+            suffix=suffix,
         )
         return await self.openai_client.completions.create(**params)  # type: ignore
 
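Across the adapters, the new keyword simply rides along with the other OpenAI-style parameters into `prepare_openai_completion_params(...)` and from there into the provider's OpenAI-compatible client, as in the Ollama hunk above. A rough, illustrative stand-in for that kind of parameter assembly (not the repo's actual helper) just drops unset values so an omitted suffix never appears in the outgoing request:

import asyncio
from typing import Any


async def prepare_params_sketch(**kwargs: Any) -> dict[str, Any]:
    # Illustrative stand-in: keep only the parameters the caller actually set,
    # so suffix=None is left out of the request body entirely.
    return {k: v for k, v in kwargs.items() if v is not None}


params = asyncio.run(prepare_params_sketch(prompt="The capital of ", suffix="is Paris.", user=None))
print(params)  # {'prompt': 'The capital of ', 'suffix': 'is Paris.'}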
@@ -90,6 +90,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         if guided_choice is not None:
             logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")

@@ -117,6 +118,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
             temperature=temperature,
             top_p=top_p,
             user=user,
+            suffix=suffix,
         )
         return await self._openai_client.completions.create(**params)
 
@@ -242,6 +242,7 @@ class PassthroughInferenceAdapter(Inference):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         client = self._get_client()
         model_obj = await self.model_store.get_model(model)
@@ -299,6 +299,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
@@ -559,6 +559,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         self._lazy_initialize_client()
         model_obj = await self._get_model(model)
@@ -292,6 +292,7 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
@@ -336,6 +336,7 @@ class LiteLLMOpenAIMixin(
         user: str | None = None,
         guided_choice: list[str] | None = None,
         prompt_logprobs: int | None = None,
+        suffix: str | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
@@ -22,9 +22,6 @@ def provider_from_model(client_with_models, model_id):
 
 
 def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
-    if isinstance(client_with_models, LlamaStackAsLibraryClient):
-        pytest.skip("OpenAI completions are not supported when testing with library client yet.")
-
     provider = provider_from_model(client_with_models, model_id)
     if provider.provider_type in (
         "inline::meta-reference",
@@ -44,6 +41,23 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
 
 
+def skip_if_model_doesnt_support_suffix(client_with_models, model_id):
+    # To test `fim` ( fill in the middle ) completion, we need to use a model that supports suffix.
+    # Use this to specifically test this API functionality.
+
+    # pytest -sv --stack-config="inference=ollama" \
+    #   tests/integration/inference/test_openai_completion.py \
+    #   --text-model qwen2.5-coder:1.5b \
+    #   -k test_openai_completion_non_streaming_suffix
+
+    if model_id != "qwen2.5-coder:1.5b":
+        pytest.skip(f"Suffix is not supported for the model: {model_id}.")
+
+    provider = provider_from_model(client_with_models, model_id)
+    if provider.provider_type != "remote::ollama":
+        pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.")
+
+
 def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
     if isinstance(client_with_models, LlamaStackAsLibraryClient):
         pytest.skip("OpenAI chat completions are not supported when testing with library client yet.")
@@ -102,6 +116,32 @@ def test_openai_completion_non_streaming(llama_stack_client, client_with_models,
     assert len(choice.text) > 10
 
 
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:completion:suffix",
+    ],
+)
+def test_openai_completion_non_streaming_suffix(llama_stack_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
+    skip_if_model_doesnt_support_suffix(client_with_models, text_model_id)
+    tc = TestCase(test_case)
+
+    # ollama needs more verbose prompting for some reason here...
+    response = llama_stack_client.completions.create(
+        model=text_model_id,
+        prompt=tc["content"],
+        stream=False,
+        suffix=tc["suffix"],
+        max_tokens=10,
+    )
+
+    assert len(response.choices) > 0
+    choice = response.choices[0]
+    assert len(choice.text) > 5
+    assert "france" in choice.text.lower()
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
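The integration test above targets qwen2.5-coder:1.5b served by Ollama, the only provider and model combination the new skip helper allows. A quick standalone sanity check against Ollama's own OpenAI-compatible endpoint can confirm the model honors `suffix` before running the stack test; the port, API key placeholder, and the assumption that the model is already pulled are all illustrative:

from openai import OpenAI

# Assumes a local Ollama daemon with its OpenAI-compatible API at the default port
# and the qwen2.5-coder:1.5b model already pulled.
ollama = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")

resp = ollama.completions.create(
    model="qwen2.5-coder:1.5b",
    prompt="The capital of ",
    suffix="is Paris.",
    max_tokens=10,
)
print(resp.choices[0].text)  # should name France, mirroring the integration test's assertion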
@@ -197,6 +237,34 @@ def test_openai_chat_completion_non_streaming(compat_client, client_with_models,
     assert expected.lower() in message_content
 
 
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:chat_completion:non_streaming_suffix_01",
+        "inference:chat_completion:non_streaming_suffix_02",
+    ],
+)
+def test_openai_chat_completion_non_streaming_suffix(compat_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
+    tc = TestCase(test_case)
+    question = tc["question"]
+    expected = tc["expected"]
+
+    response = compat_client.chat.completions.create(
+        model=text_model_id,
+        messages=[
+            {
+                "role": "user",
+                "content": question,
+            }
+        ],
+        stream=False,
+    )
+    message_content = response.choices[0].message.content.lower().strip()
+    assert len(message_content) > 0
+    assert expected.lower() in message_content
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
@@ -4,6 +4,12 @@
             "content": "Complete the sentence using one word: Roses are red, violets are "
         }
     },
+    "suffix": {
+        "data": {
+            "content": "The capital of ",
+            "suffix": "is Paris."
+        }
+    },
     "non_streaming": {
         "data": {
             "content": "Micheael Jordan is born in ",
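The new fixture entry supplies the two values the suffix test reads, prompt=tc["content"] and suffix=tc["suffix"], under its "data" key and is selected via TestCase("inference:completion:suffix"). A small sketch of that lookup with the fixture data inlined rather than loaded from the repo's test-case file (how TestCase itself resolves the id is not shown in this commit, so this only illustrates the shape of the data):

import json

# Inlined copy of the new fixture entry shown in the hunk above.
fixture = json.loads("""
{
    "suffix": {
        "data": {
            "content": "The capital of ",
            "suffix": "is Paris."
        }
    }
}
""")

data = fixture["suffix"]["data"]
prompt, suffix = data["content"], data["suffix"]
print(prompt, "<completion goes here>", suffix)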