mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-06-29 03:14:19 +00:00
add suffix to openai.completions
This commit is contained in:
parent e2e15ebb6c
commit 1cfb5b1205
15 changed files with 101 additions and 3 deletions
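The change threads a `suffix` parameter through the OpenAI-compatible completions path so that fill-in-the-middle (FIM) style requests can be made against Llama Stack. Below is a minimal sketch of what exercising it with the `openai` client might look like; the base URL, port, and model name are assumptions for illustration, not values taken from this commit.

```python
# Minimal sketch (not from this commit): calling an OpenAI-compatible
# completions endpoint with the new `suffix` parameter.
# Base URL, port, and model name are assumptions; adjust to your deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.completions.create(
    model="qwen2.5-coder:1.5b",            # a FIM-capable model (assumed)
    prompt="def add(a, b):\n    return ",  # text before the gap
    suffix="\n\nprint(add(2, 3))",         # text after the gap
    max_tokens=16,
)
print(response.choices[0].text)            # the model's fill for the gap
```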
@@ -22,9 +22,6 @@ def provider_from_model(client_with_models, model_id):

```python
def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
    if isinstance(client_with_models, LlamaStackAsLibraryClient):
        pytest.skip("OpenAI completions are not supported when testing with library client yet.")

    provider = provider_from_model(client_with_models, model_id)
    if provider.provider_type in (
        "inline::meta-reference",
```
@@ -44,6 +41,23 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):

```python
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")


def skip_if_model_doesnt_support_suffix(client_with_models, model_id):
    # To test `fim` (fill in the middle) completion, we need to use a model that supports suffix.
    # Use this to specifically test this API functionality.

    # pytest -sv --stack-config="inference=ollama" \
    #   tests/integration/inference/test_openai_completion.py \
    #   --text-model qwen2.5-coder:1.5b \
    #   -k test_openai_completion_non_streaming_suffix

    if model_id != "qwen2.5-coder:1.5b":
        pytest.skip(f"Suffix is not supported for the model: {model_id}.")

    provider = provider_from_model(client_with_models, model_id)
    if provider.provider_type != "remote::ollama":
        pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.")


def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
    if isinstance(client_with_models, LlamaStackAsLibraryClient):
        pytest.skip("OpenAI chat completions are not supported when testing with library client yet.")
```
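As the comment in `skip_if_model_doesnt_support_suffix` notes, suffix-based completion is fill-in-the-middle: the request carries the text before the gap as the prompt and the text after it as the suffix, and the model generates only the missing middle. A hypothetical illustration of such a prompt/suffix pair (the values are made up, not the repo's actual `inference:completion:suffix` fixture):

```python
# Hypothetical FIM inputs; not the actual "inference:completion:suffix" fixture.
prompt = "def fibonacci(n):\n    if n < 2:\n        return n\n    return "
suffix = "\n\nprint(fibonacci(10))"
# A FIM-capable model such as qwen2.5-coder is expected to fill the gap with
# something like "fibonacci(n - 1) + fibonacci(n - 2)".
```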
@@ -102,6 +116,32 @@ def test_openai_completion_non_streaming(llama_stack_client, client_with_models,

```python
    assert len(choice.text) > 10


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:completion:suffix",
    ],
)
def test_openai_completion_non_streaming_suffix(llama_stack_client, client_with_models, text_model_id, test_case):
    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
    skip_if_model_doesnt_support_suffix(client_with_models, text_model_id)
    tc = TestCase(test_case)

    # ollama needs more verbose prompting for some reason here...
    response = llama_stack_client.completions.create(
        model=text_model_id,
        prompt=tc["content"],
        stream=False,
        suffix=tc["suffix"],
        max_tokens=10,
    )

    assert len(response.choices) > 0
    choice = response.choices[0]
    assert len(choice.text) > 5
    assert "france" in choice.text.lower()


@pytest.mark.parametrize(
    "test_case",
    [
```
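Because the skip helper above restricts this test to the `remote::ollama` provider, the `suffix` argument is presumably mapped onto Ollama's native generate API, which also accepts a `suffix` field. A rough way to sanity-check that a model honors the suffix outside the stack (endpoint and values are illustrative, not part of this commit):

```python
# Sketch: call Ollama directly to confirm the model supports `suffix`.
# Uses Ollama's documented /api/generate endpoint; values are illustrative.
import requests

resp = requests.post(
    "http://localhost:11434/api/generate",
    json={
        "model": "qwen2.5-coder:1.5b",
        "prompt": "def add(a, b):\n    return ",
        "suffix": "\n\nprint(add(2, 3))",
        "stream": False,
    },
    timeout=60,
)
print(resp.json()["response"])  # should contain only the filled-in middle
```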
@@ -197,6 +237,34 @@ def test_openai_chat_completion_non_streaming(compat_client, client_with_models,

```python
    assert expected.lower() in message_content


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:non_streaming_suffix_01",
        "inference:chat_completion:non_streaming_suffix_02",
    ],
)
def test_openai_chat_completion_non_streaming_suffix(compat_client, client_with_models, text_model_id, test_case):
    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
    tc = TestCase(test_case)
    question = tc["question"]
    expected = tc["expected"]

    response = compat_client.chat.completions.create(
        model=text_model_id,
        messages=[
            {
                "role": "user",
                "content": question,
            }
        ],
        stream=False,
    )
    message_content = response.choices[0].message.content.lower().strip()
    assert len(message_content) > 0
    assert expected.lower() in message_content


@pytest.mark.parametrize(
    "test_case",
    [
```