feat: completing text /chat-completion and /completion tests (#1223)

# What does this PR do?

The goal is to have a fairly complete set of provider and e2e tests for
/chat-completion and /completion. This is the current list (a sketch of the
test-case-driven shape these tests use follows the lists):
```
grep -oE "def test_[a-zA-Z_+]*" llama_stack/providers/tests/inference/test_text_inference.py | cut -d' ' -f2
```
- test_model_list
- test_text_completion_non_streaming
- test_text_completion_streaming
- test_text_completion_logprobs_non_streaming
- test_text_completion_logprobs_streaming
- test_text_completion_structured_output
- test_text_chat_completion_non_streaming
- test_text_chat_completion_structured_output
- test_text_chat_completion_streaming
- test_text_chat_completion_with_tool_calling
- test_text_chat_completion_with_tool_calling_streaming

```
grep -oE "def test_[a-zA-Z_+]*" tests/client-sdk/inference/test_text_inference.py | cut -d' ' -f2
```
- test_text_completion_non_streaming
- test_text_completion_streaming
- test_text_completion_log_probs_non_streaming
- test_text_completion_log_probs_streaming
- test_text_completion_structured_output
- test_text_chat_completion_non_streaming
- test_text_chat_completion_streaming
- test_text_chat_completion_with_tool_calling_and_non_streaming
- test_text_chat_completion_with_tool_calling_and_streaming
- test_text_chat_completion_with_tool_choice_required
- test_text_chat_completion_with_tool_choice_none
- test_text_chat_completion_structured_output
- test_text_chat_completion_tool_calling_tools_not_in_request
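
The client-sdk tests in the diff below are parametrized by a test-case ID and read their prompts and expected values through a `TestCase` helper instead of hardcoding them. A minimal sketch of that shape — the `TestCase` import path is an assumption, the case IDs are copied from the parametrize lists in the diff, and `client_with_models` / `text_model_id` come from the suite's conftest fixtures:
```python
import pytest

# Assumed import path for the shared test-case helper used in the diff below.
from llama_stack.providers.tests.test_cases.test_case import TestCase


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:non_streaming_01",
        "inference:chat_completion:non_streaming_02",
    ],
)
def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
    # The question and expected answer come from the named test case, not the test body.
    tc = TestCase(test_case)
    response = client_with_models.inference.chat_completion(
        model_id=text_model_id,
        messages=[{"role": "user", "content": tc["question"]}],
        stream=False,
    )
    assert tc["expected"].lower() in response.completion_message.content.lower()
```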

## Test plan

== Set up a local Ollama server
```
OLLAMA_HOST=127.0.0.1:8321 with-proxy ollama serve
OLLAMA_HOST=127.0.0.1:8321 ollama run llama3.2:3b-instruct-fp16 --keepalive 60m
```
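
Before running the tests, it can help to confirm the server is reachable on the non-default port. A quick check against Ollama's model-listing endpoint — the `requests` snippet is just one way to do this and is not part of the PR:
```python
# Reachability check for the Ollama server started above (port 8321).
import requests

resp = requests.get("http://127.0.0.1:8321/api/tags", timeout=5)
resp.raise_for_status()
models = [m["name"] for m in resp.json().get("models", [])]
print(models)  # expect llama3.2:3b-instruct-fp16 to be listed once the model is pulled
```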

== Run a provider test
```
conda activate stack
OLLAMA_URL="http://localhost:8321" \
pytest -v -s -k "ollama" --inference-model="llama3.2:3b-instruct-fp16" \
llama_stack/providers/tests/inference/test_text_inference.py::TestInference
```

== Run an e2e test
```
conda activate sherpa
with-proxy pip install llama-stack
export INFERENCE_MODEL=llama3.2:3b-instruct-fp16
export LLAMA_STACK_PORT=8322
with-proxy llama stack build --template ollama
with-proxy llama stack run --env OLLAMA_URL=http://localhost:8321 ollama
```
```
conda activate stack
LLAMA_STACK_PORT=8322 LLAMA_STACK_BASE_URL="http://localhost:8322" \
pytest -v -s --inference-model="llama3.2:3b-instruct-fp16" \
tests/client-sdk/inference/test_text_inference.py
```
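
As a quick spot check that the stack is actually serving inference before pointing pytest at it, a single chat completion can be issued against the same endpoint the e2e tests use. This assumes `llama-stack-client` is installed; the model ID and port follow the env vars above:
```python
# One-off sanity call against the running stack on LLAMA_STACK_PORT=8322.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8322")
response = client.inference.chat_completion(
    model_id="llama3.2:3b-instruct-fp16",
    messages=[{"role": "user", "content": "Which planet do humans live on?"}],
    stream=False,
)
print(response.completion_message.content)  # the e2e tests expect "Earth" here
```
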
8 changed files with 479 additions and 223 deletions


@@ -28,23 +28,17 @@ def provider_tool_format(inference_provider_type):
)
@pytest.fixture
def get_weather_tool_definition():
return {
"tool_name": "get_weather",
"description": "Get the current weather",
"parameters": {
"location": {
"param_type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
},
}
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:sanity",
],
)
def test_text_completion_non_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
def test_text_completion_non_streaming(client_with_models, text_model_id):
response = client_with_models.inference.completion(
content="Complete the sentence using one word: Roses are red, violets are ",
content=tc["content"],
stream=False,
model_id=text_model_id,
sampling_params={
@@ -55,9 +49,17 @@ def test_text_completion_non_streaming(client_with_models, text_model_id):
# assert "blue" in response.content.lower().strip()
def test_text_completion_streaming(client_with_models, text_model_id):
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:sanity",
],
)
def test_text_completion_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
response = client_with_models.inference.completion(
content="Complete the sentence using one word: Roses are red, violets are ",
content=tc["content"],
stream=True,
model_id=text_model_id,
sampling_params={
@@ -70,12 +72,20 @@ def test_text_completion_streaming(client_with_models, text_model_id):
assert len(content_str) > 10
def test_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type):
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:log_probs",
],
)
def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
tc = TestCase(test_case)
response = client_with_models.inference.completion(
content="Complete the sentence: Micheael Jordan is born in ",
content=tc["content"],
stream=False,
model_id=text_model_id,
sampling_params={
@@ -90,12 +100,20 @@ def test_completion_log_probs_non_streaming(client_with_models, text_model_id, i
assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
def test_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type):
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:log_probs",
],
)
def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
tc = TestCase(test_case)
response = client_with_models.inference.completion(
content="Complete the sentence: Micheael Jordan is born in ",
content=tc["content"],
stream=True,
model_id=text_model_id,
sampling_params={
@@ -114,7 +132,12 @@ def test_completion_log_probs_streaming(client_with_models, text_model_id, infer
assert not chunk.logprobs, "Logprobs should be empty"
@pytest.mark.parametrize("test_case", ["completion-01"])
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:structured_output",
],
)
def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
class AnswerFormat(BaseModel):
name: str
@@ -144,16 +167,17 @@ def test_text_completion_structured_output(client_with_models, text_model_id, te
@pytest.mark.parametrize(
"question,expected",
"test_case",
[
("Which planet do humans live on?", "Earth"),
(
"Which planet has rings around it with a name starting with letter S?",
"Saturn",
),
"inference:chat_completion:non_streaming_01",
"inference:chat_completion:non_streaming_02",
],
)
def test_text_chat_completion_non_streaming(client_with_models, text_model_id, question, expected):
def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
@@ -170,13 +194,17 @@ def test_text_chat_completion_non_streaming(client_with_models, text_model_id, q
@pytest.mark.parametrize(
"question,expected",
"test_case",
[
("What's the name of the Sun in latin?", "Sol"),
("What is the name of the US captial?", "Washington"),
"inference:chat_completion:streaming_01",
"inference:chat_completion:streaming_02",
],
)
def test_text_chat_completion_streaming(client_with_models, text_model_id, question, expected):
def test_text_chat_completion_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[{"role": "user", "content": question}],
@@ -187,18 +215,26 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, quest
assert expected.lower() in "".join(streamed_content)
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_calling_and_non_streaming(
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
client_with_models, text_model_id, provider_tool_format, test_case
):
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"},
],
tools=[get_weather_tool_definition],
messages=tc["messages"],
tools=tc["tools"],
tool_choice="auto",
tool_prompt_format=provider_tool_format,
tool_prompt_format=tool_prompt_format,
stream=False,
)
# No content is returned for the system message since we expect the
@@ -207,8 +243,8 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
assert response.completion_message.role == "assistant"
assert len(response.completion_message.tool_calls) == 1
assert response.completion_message.tool_calls[0].tool_name == "get_weather"
assert response.completion_message.tool_calls[0].arguments == {"location": "San Francisco, CA"}
assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"]
assert response.completion_message.tool_calls[0].arguments == tc["expected"]
# Will extract streamed text and separate it from tool invocation content
@@ -224,57 +260,80 @@ def extract_tool_invocation_content(response):
return tool_invocation_content
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_calling_and_streaming(
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
client_with_models, text_model_id, provider_tool_format, test_case
):
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"},
],
tools=[get_weather_tool_definition],
messages=tc["messages"],
tools=tc["tools"],
tool_choice="auto",
tool_prompt_format=provider_tool_format,
tool_prompt_format=tool_prompt_format,
stream=True,
)
tool_invocation_content = extract_tool_invocation_content(response)
assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
expected_tool_name = tc["tools"][0]["tool_name"]
expected_argument = tc["expected"]
assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]"
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_choice_required(
client_with_models,
text_model_id,
get_weather_tool_definition,
provider_tool_format,
test_case,
):
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"},
],
tools=[get_weather_tool_definition],
messages=tc["messages"],
tools=tc["tools"],
tool_config={
"tool_choice": "required",
"tool_prompt_format": provider_tool_format,
"tool_prompt_format": tool_prompt_format,
},
stream=True,
)
tool_invocation_content = extract_tool_invocation_content(response)
assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
expected_tool_name = tc["tools"][0]["tool_name"]
expected_argument = tc["expected"]
assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]"
def test_text_chat_completion_with_tool_choice_none(
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
):
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, provider_tool_format, test_case):
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"},
],
tools=[get_weather_tool_definition],
messages=tc["messages"],
tools=tc["tools"],
tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
stream=True,
)
@@ -282,7 +341,12 @@ def test_text_chat_completion_with_tool_choice_none(
assert tool_invocation_content == ""
@pytest.mark.parametrize("test_case", ["chat_completion-01"])
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:structured_output",
],
)
def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
class AnswerFormat(BaseModel):
first_name: str
@@ -309,64 +373,24 @@ def test_text_chat_completion_structured_output(client_with_models, text_model_i
assert answer.num_seasons_in_nba == expected["num_seasons_in_nba"]
@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize(
"streaming",
"test_case",
[
True,
False,
"inference:chat_completion:tool_calling_tools_absent",
],
)
def test_text_chat_completion_tool_calling_tools_not_in_request(client_with_models, text_model_id, streaming):
def test_text_chat_completion_tool_calling_tools_not_in_request(
client_with_models, text_model_id, test_case, streaming
):
tc = TestCase(test_case)
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
request = {
"model_id": text_model_id,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": "What pods are in the namespace openshift-lightspeed?",
},
{
"role": "assistant",
"content": "",
"stop_reason": "end_of_turn",
"tool_calls": [
{
"call_id": "1",
"tool_name": "get_object_namespace_list",
"arguments": {
"kind": "pod",
"namespace": "openshift-lightspeed",
},
}
],
},
{
"role": "tool",
"call_id": "1",
"tool_name": "get_object_namespace_list",
"content": "the objects are pod1, pod2, pod3",
},
],
"tools": [
{
"tool_name": "get_object_namespace_list",
"description": "Get the list of objects in a namespace",
"parameters": {
"kind": {
"param_type": "string",
"description": "the type of object",
"required": True,
},
"namespace": {
"param_type": "string",
"description": "the name of the namespace",
"required": True,
},
},
}
],
"messages": tc["messages"],
"tools": tc["tools"],
"tool_choice": "auto",
"tool_prompt_format": tool_prompt_format,
"stream": streaming,