feat: completing text /chat-completion and /completion tests (#1223)
# What does this PR do? The goal is to have a fairly complete set of provider and e2e tests for /chat-completion and /completion. This is the current list, ``` grep -oE "def test_[a-zA-Z_+]*" llama_stack/providers/tests/inference/test_text_inference.py | cut -d' ' -f2 ``` - test_model_list - test_text_completion_non_streaming - test_text_completion_streaming - test_text_completion_logprobs_non_streaming - test_text_completion_logprobs_streaming - test_text_completion_structured_output - test_text_chat_completion_non_streaming - test_text_chat_completion_structured_output - test_text_chat_completion_streaming - test_text_chat_completion_with_tool_calling - test_text_chat_completion_with_tool_calling_streaming ``` grep -oE "def test_[a-zA-Z_+]*" tests/client-sdk/inference/test_text_inference.py | cut -d' ' -f2 ``` - test_text_completion_non_streaming - test_text_completion_streaming - test_text_completion_log_probs_non_streaming - test_text_completion_log_probs_streaming - test_text_completion_structured_output - test_text_chat_completion_non_streaming - test_text_chat_completion_streaming - test_text_chat_completion_with_tool_calling_and_non_streaming - test_text_chat_completion_with_tool_calling_and_streaming - test_text_chat_completion_with_tool_choice_required - test_text_chat_completion_with_tool_choice_none - test_text_chat_completion_structured_output - test_text_chat_completion_tool_calling_tools_not_in_request ## Test plan == Set up Ollama local server ``` OLLAMA_HOST=127.0.0.1:8321 with-proxy ollama serve OLLAMA_HOST=127.0.0.1:8321 ollama run llama3.2:3b-instruct-fp16 --keepalive 60m ``` == Run a provider test ``` conda activate stack OLLAMA_URL="http://localhost:8321" \ pytest -v -s -k "ollama" --inference-model="llama3.2:3b-instruct-fp16" \ llama_stack/providers/tests/inference/test_text_inference.py::TestInference ``` == Run an e2e test ``` conda activate sherpa with-proxy pip install llama-stack export INFERENCE_MODEL=llama3.2:3b-instruct-fp16 export LLAMA_STACK_PORT=8322 with-proxy llama stack build --template ollama with-proxy llama stack run --env OLLAMA_URL=http://localhost:8321 ollama ``` ``` conda activate stack LLAMA_STACK_PORT=8322 LLAMA_STACK_BASE_URL="http://localhost:8322" \ pytest -v -s --inference-model="llama3.2:3b-instruct-fp16" \ tests/client-sdk/inference/test_text_inference.py ```
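A note on how the reworked tests consume the new `test_case` IDs: strings like `"inference:completion:sanity"` are handed to a `TestCase` helper (`tc = TestCase(test_case)`), and the tests then read `tc["content"]`, `tc["messages"]`, `tc["tools"]`, or `tc["expected"]`. The helper and its fixture files are not part of the diff excerpt below, so the following is only a minimal sketch of how such a lookup could be organized; the directory layout, JSON schema, and module name are assumptions for illustration, not the PR's actual implementation.

```python
# Minimal sketch (assumption, not the PR's actual helper): resolve an ID like
# "inference:completion:sanity" to an entry in a JSON fixture file.
import json
from pathlib import Path

# Assumed layout: test_cases/<api>/<group>.json, e.g. test_cases/inference/completion.json
FIXTURE_ROOT = Path(__file__).parent / "test_cases"


class TestCase:
    def __init__(self, case_id: str) -> None:
        # "inference:completion:sanity" -> test_cases/inference/completion.json, key "sanity"
        api, group, name = case_id.split(":")
        cases = json.loads((FIXTURE_ROOT / api / f"{group}.json").read_text())
        self._data = cases[name]

    def __getitem__(self, key: str):
        # Lets tests write tc["content"], tc["messages"], tc["tools"], tc["expected"]
        return self._data[key]
```

Under that assumed layout, the sanity completion case would live in `test_cases/inference/completion.json` as an entry like `{"sanity": {"content": "Complete the sentence using one word: Roses are red, violets are "}}`, and a parametrized test only needs the ID, e.g. `TestCase("inference:completion:sanity")["content"]`.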
parent 9b130f96a7
commit 3a31611486
8 changed files with 479 additions and 223 deletions
@@ -28,23 +28,17 @@ def provider_tool_format(inference_provider_type):
     )


-@pytest.fixture
-def get_weather_tool_definition():
-    return {
-        "tool_name": "get_weather",
-        "description": "Get the current weather",
-        "parameters": {
-            "location": {
-                "param_type": "string",
-                "description": "The city and state, e.g. San Francisco, CA",
-            },
-        },
-    }
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:completion:sanity",
+    ],
+)
+def test_text_completion_non_streaming(client_with_models, text_model_id, test_case):
+    tc = TestCase(test_case)

-
-def test_text_completion_non_streaming(client_with_models, text_model_id):
     response = client_with_models.inference.completion(
-        content="Complete the sentence using one word: Roses are red, violets are ",
+        content=tc["content"],
         stream=False,
         model_id=text_model_id,
         sampling_params={
@@ -55,9 +49,17 @@ def test_text_completion_non_streaming(client_with_models, text_model_id):
     # assert "blue" in response.content.lower().strip()


-def test_text_completion_streaming(client_with_models, text_model_id):
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:completion:sanity",
+    ],
+)
+def test_text_completion_streaming(client_with_models, text_model_id, test_case):
+    tc = TestCase(test_case)
+
     response = client_with_models.inference.completion(
-        content="Complete the sentence using one word: Roses are red, violets are ",
+        content=tc["content"],
         stream=True,
         model_id=text_model_id,
         sampling_params={
@@ -70,12 +72,20 @@ def test_text_completion_streaming(client_with_models, text_model_id):
     assert len(content_str) > 10


-def test_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type):
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:completion:log_probs",
+    ],
+)
+def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")

+    tc = TestCase(test_case)
+
     response = client_with_models.inference.completion(
-        content="Complete the sentence: Micheael Jordan is born in ",
+        content=tc["content"],
         stream=False,
         model_id=text_model_id,
         sampling_params={
@@ -90,12 +100,20 @@ def test_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type):
     assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)


-def test_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type):
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:completion:log_probs",
+    ],
+)
+def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")

+    tc = TestCase(test_case)
+
     response = client_with_models.inference.completion(
-        content="Complete the sentence: Micheael Jordan is born in ",
+        content=tc["content"],
         stream=True,
         model_id=text_model_id,
         sampling_params={
@@ -114,7 +132,12 @@ def test_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type):
             assert not chunk.logprobs, "Logprobs should be empty"


-@pytest.mark.parametrize("test_case", ["completion-01"])
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:completion:structured_output",
+    ],
+)
 def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
     class AnswerFormat(BaseModel):
         name: str
@@ -144,16 +167,17 @@ def test_text_completion_structured_output(client_with_models, text_model_id, test_case):


 @pytest.mark.parametrize(
-    "question,expected",
+    "test_case",
     [
-        ("Which planet do humans live on?", "Earth"),
-        (
-            "Which planet has rings around it with a name starting with letter S?",
-            "Saturn",
-        ),
+        "inference:chat_completion:non_streaming_01",
+        "inference:chat_completion:non_streaming_02",
     ],
 )
-def test_text_chat_completion_non_streaming(client_with_models, text_model_id, question, expected):
+def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
+    tc = TestCase(test_case)
+    question = tc["question"]
+    expected = tc["expected"]
+
     response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
@@ -170,13 +194,17 @@ def test_text_chat_completion_non_streaming(client_with_models, text_model_id, question, expected):


 @pytest.mark.parametrize(
-    "question,expected",
+    "test_case",
     [
-        ("What's the name of the Sun in latin?", "Sol"),
-        ("What is the name of the US captial?", "Washington"),
+        "inference:chat_completion:streaming_01",
+        "inference:chat_completion:streaming_02",
     ],
 )
-def test_text_chat_completion_streaming(client_with_models, text_model_id, question, expected):
+def test_text_chat_completion_streaming(client_with_models, text_model_id, test_case):
+    tc = TestCase(test_case)
+    question = tc["question"]
+    expected = tc["expected"]
+
     response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[{"role": "user", "content": question}],
@@ -187,18 +215,26 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, question, expected):
     assert expected.lower() in "".join(streamed_content)


+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:chat_completion:tool_calling",
+    ],
+)
 def test_text_chat_completion_with_tool_calling_and_non_streaming(
-    client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
+    client_with_models, text_model_id, provider_tool_format, test_case
 ):
+    # TODO: more dynamic lookup on tool_prompt_format for model family
+    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+
+    tc = TestCase(test_case)
+
     response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
-        messages=[
-            {"role": "system", "content": "You are a helpful assistant."},
-            {"role": "user", "content": "What's the weather like in San Francisco?"},
-        ],
-        tools=[get_weather_tool_definition],
+        messages=tc["messages"],
+        tools=tc["tools"],
         tool_choice="auto",
-        tool_prompt_format=provider_tool_format,
+        tool_prompt_format=tool_prompt_format,
         stream=False,
     )
     # No content is returned for the system message since we expect the
@@ -207,8 +243,8 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
     assert response.completion_message.role == "assistant"

     assert len(response.completion_message.tool_calls) == 1
-    assert response.completion_message.tool_calls[0].tool_name == "get_weather"
-    assert response.completion_message.tool_calls[0].arguments == {"location": "San Francisco, CA"}
+    assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"]
+    assert response.completion_message.tool_calls[0].arguments == tc["expected"]


 # Will extract streamed text and separate it from tool invocation content
@@ -224,57 +260,80 @@ def extract_tool_invocation_content(response):
     return tool_invocation_content


+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:chat_completion:tool_calling",
+    ],
+)
 def test_text_chat_completion_with_tool_calling_and_streaming(
-    client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
+    client_with_models, text_model_id, provider_tool_format, test_case
 ):
+    # TODO: more dynamic lookup on tool_prompt_format for model family
+    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+
+    tc = TestCase(test_case)
+
     response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
-        messages=[
-            {"role": "system", "content": "You are a helpful assistant."},
-            {"role": "user", "content": "What's the weather like in San Francisco?"},
-        ],
-        tools=[get_weather_tool_definition],
+        messages=tc["messages"],
+        tools=tc["tools"],
         tool_choice="auto",
-        tool_prompt_format=provider_tool_format,
+        tool_prompt_format=tool_prompt_format,
         stream=True,
     )
     tool_invocation_content = extract_tool_invocation_content(response)
-    assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
+    expected_tool_name = tc["tools"][0]["tool_name"]
+    expected_argument = tc["expected"]
+    assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]"


+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:chat_completion:tool_calling",
+    ],
+)
 def test_text_chat_completion_with_tool_choice_required(
     client_with_models,
     text_model_id,
-    get_weather_tool_definition,
     provider_tool_format,
+    test_case,
 ):
+    # TODO: more dynamic lookup on tool_prompt_format for model family
+    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+
+    tc = TestCase(test_case)
+
     response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
-        messages=[
-            {"role": "system", "content": "You are a helpful assistant."},
-            {"role": "user", "content": "What's the weather like in San Francisco?"},
-        ],
-        tools=[get_weather_tool_definition],
+        messages=tc["messages"],
+        tools=tc["tools"],
         tool_config={
             "tool_choice": "required",
-            "tool_prompt_format": provider_tool_format,
+            "tool_prompt_format": tool_prompt_format,
         },
         stream=True,
     )
     tool_invocation_content = extract_tool_invocation_content(response)
-    assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
+    expected_tool_name = tc["tools"][0]["tool_name"]
+    expected_argument = tc["expected"]
+    assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]"


-def test_text_chat_completion_with_tool_choice_none(
-    client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
-):
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:chat_completion:tool_calling",
+    ],
+)
+def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, provider_tool_format, test_case):
+    tc = TestCase(test_case)
+
     response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
-        messages=[
-            {"role": "system", "content": "You are a helpful assistant."},
-            {"role": "user", "content": "What's the weather like in San Francisco?"},
-        ],
-        tools=[get_weather_tool_definition],
+        messages=tc["messages"],
+        tools=tc["tools"],
         tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
         stream=True,
     )
@@ -282,7 +341,12 @@ def test_text_chat_completion_with_tool_choice_none(
     assert tool_invocation_content == ""


-@pytest.mark.parametrize("test_case", ["chat_completion-01"])
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:chat_completion:structured_output",
+    ],
+)
 def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
     class AnswerFormat(BaseModel):
         first_name: str
@@ -309,64 +373,24 @@ def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
     assert answer.num_seasons_in_nba == expected["num_seasons_in_nba"]


+@pytest.mark.parametrize("streaming", [True, False])
 @pytest.mark.parametrize(
-    "streaming",
+    "test_case",
     [
-        True,
-        False,
+        "inference:chat_completion:tool_calling_tools_absent",
     ],
 )
-def test_text_chat_completion_tool_calling_tools_not_in_request(client_with_models, text_model_id, streaming):
+def test_text_chat_completion_tool_calling_tools_not_in_request(
+    client_with_models, text_model_id, test_case, streaming
+):
+    tc = TestCase(test_case)
+
     # TODO: more dynamic lookup on tool_prompt_format for model family
     tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
     request = {
         "model_id": text_model_id,
-        "messages": [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {
-                "role": "user",
-                "content": "What pods are in the namespace openshift-lightspeed?",
-            },
-            {
-                "role": "assistant",
-                "content": "",
-                "stop_reason": "end_of_turn",
-                "tool_calls": [
-                    {
-                        "call_id": "1",
-                        "tool_name": "get_object_namespace_list",
-                        "arguments": {
-                            "kind": "pod",
-                            "namespace": "openshift-lightspeed",
-                        },
-                    }
-                ],
-            },
-            {
-                "role": "tool",
-                "call_id": "1",
-                "tool_name": "get_object_namespace_list",
-                "content": "the objects are pod1, pod2, pod3",
-            },
-        ],
-        "tools": [
-            {
-                "tool_name": "get_object_namespace_list",
-                "description": "Get the list of objects in a namespace",
-                "parameters": {
-                    "kind": {
-                        "param_type": "string",
-                        "description": "the type of object",
-                        "required": True,
-                    },
-                    "namespace": {
-                        "param_type": "string",
-                        "description": "the name of the namespace",
-                        "required": True,
-                    },
-                },
-            }
-        ],
+        "messages": tc["messages"],
+        "tools": tc["tools"],
         "tool_choice": "auto",
         "tool_prompt_format": tool_prompt_format,
         "stream": streaming,