feat: completing text /chat-completion and /completion tests (#1223)

# What does this PR do?

The goal is to have a fairly complete set of provider and e2e tests for
/chat-completion and /completion. This is the current list,
```
grep -oE "def test_[a-zA-Z_+]*" llama_stack/providers/tests/inference/test_text_inference.py | cut -d' ' -f2
```
- test_model_list
- test_text_completion_non_streaming
- test_text_completion_streaming
- test_text_completion_logprobs_non_streaming
- test_text_completion_logprobs_streaming
- test_text_completion_structured_output
- test_text_chat_completion_non_streaming
- test_text_chat_completion_structured_output
- test_text_chat_completion_streaming
- test_text_chat_completion_with_tool_calling
- test_text_chat_completion_with_tool_calling_streaming

```
grep -oE "def test_[a-zA-Z_+]*" tests/client-sdk/inference/test_text_inference.py | cut -d' ' -f2
```
- test_text_completion_non_streaming
- test_text_completion_streaming
- test_text_completion_log_probs_non_streaming
- test_text_completion_log_probs_streaming
- test_text_completion_structured_output
- test_text_chat_completion_non_streaming
- test_text_chat_completion_streaming
- test_text_chat_completion_with_tool_calling_and_non_streaming
- test_text_chat_completion_with_tool_calling_and_streaming
- test_text_chat_completion_with_tool_choice_required
- test_text_chat_completion_with_tool_choice_none
- test_text_chat_completion_structured_output
- test_text_chat_completion_tool_calling_tools_not_in_request

## Test plan

== Set up Ollama local server
```
OLLAMA_HOST=127.0.0.1:8321 with-proxy ollama serve
OLLAMA_HOST=127.0.0.1:8321 ollama run llama3.2:3b-instruct-fp16 --keepalive 60m
```

==  Run a provider test
```
conda activate stack
OLLAMA_URL="http://localhost:8321" \
pytest -v -s -k "ollama" --inference-model="llama3.2:3b-instruct-fp16" \
llama_stack/providers/tests/inference/test_text_inference.py::TestInference
```

== Run an e2e test
```
conda activate sherpa
with-proxy pip install llama-stack
export INFERENCE_MODEL=llama3.2:3b-instruct-fp16
export LLAMA_STACK_PORT=8322
with-proxy llama stack build --template ollama
with-proxy llama stack run --env OLLAMA_URL=http://localhost:8321 ollama
```
```
conda activate stack
LLAMA_STACK_PORT=8322 LLAMA_STACK_BASE_URL="http://localhost:8322" \
pytest -v -s --inference-model="llama3.2:3b-instruct-fp16" \
tests/client-sdk/inference/test_text_inference.py
```
This commit is contained in:
LESSuseLESS 2025-02-25 11:37:04 -08:00 committed by GitHub
parent 9b130f96a7
commit 3a31611486
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 479 additions and 223 deletions

View file

@ -27,8 +27,6 @@ from llama_stack.models.llama.datatypes import (
SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
)
from llama_stack.providers.tests.test_cases.test_case import TestCase
@ -58,28 +56,6 @@ def common_params(inference_model):
}
@pytest.fixture
def sample_messages():
return [
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="What's the weather like today?"),
]
@pytest.fixture
def sample_tool_definition():
return ToolDefinition(
tool_name="get_weather",
description="Get the current weather",
parameters={
"location": ToolParamDefinition(
param_type="string",
description="The city and state, e.g. San Francisco, CA",
),
},
)
class TestInference:
# Session scope for asyncio because the tests in this class all
# share the same provider instance.
@ -100,12 +76,20 @@ class TestInference:
assert model_def is not None
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:non_streaming",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_completion(self, inference_model, inference_stack):
async def test_text_completion_non_streaming(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
response = await inference_impl.completion(
content="Micheael Jordan is born in ",
content=tc["content"],
stream=False,
model_id=inference_model,
sampling_params=SamplingParams(
@ -114,12 +98,24 @@ class TestInference:
)
assert isinstance(response, CompletionResponse)
assert "1963" in response.content
assert tc["expected"] in response.content
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:streaming",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_text_completion_streaming(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
chunks = [
r
async for r in await inference_impl.completion(
content="Roses are red,",
content=tc["content"],
stream=True,
model_id=inference_model,
sampling_params=SamplingParams(
@ -133,12 +129,20 @@ class TestInference:
last = chunks[-1]
assert last.stop_reason == StopReason.out_of_tokens
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:logprobs_non_streaming",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_completion_logprobs(self, inference_model, inference_stack):
async def test_text_completion_logprobs_non_streaming(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
response = await inference_impl.completion(
content="Micheael Jordan is born in ",
content=tc["content"],
stream=False,
model_id=inference_model,
sampling_params=SamplingParams(
@ -154,10 +158,22 @@ class TestInference:
assert response.logprobs, "Logprobs should not be empty"
assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:logprobs_streaming",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_text_completion_logprobs_streaming(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
chunks = [
r
async for r in await inference_impl.completion(
content="Roses are red,",
content=tc["content"],
stream=True,
model_id=inference_model,
sampling_params=SamplingParams(
@ -180,9 +196,14 @@ class TestInference:
else: # no token, no logprobs
assert not chunk.logprobs, "Logprobs should be empty"
@pytest.mark.parametrize("test_case", ["completion-01"])
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:structured_output",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_completion_structured_output(self, inference_model, inference_stack, test_case):
async def test_text_completion_structured_output(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
class Output(BaseModel):
@ -213,14 +234,20 @@ class TestInference:
assert answer.year_born == expected["year_born"]
assert answer.year_retired == expected["year_retired"]
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:sample_messages",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_non_streaming(
self, inference_model, inference_stack, common_params, sample_messages
):
async def test_text_chat_completion_non_streaming(self, inference_model, inference_stack, common_params, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
response = await inference_impl.chat_completion(
model_id=inference_model,
messages=sample_messages,
messages=messages,
stream=False,
**common_params,
)
@ -230,9 +257,16 @@ class TestInference:
assert isinstance(response.completion_message.content, str)
assert len(response.completion_message.content) > 0
@pytest.mark.parametrize("test_case", ["chat_completion-01"])
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:structured_output",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_structured_output(self, inference_model, inference_stack, common_params, test_case):
async def test_text_chat_completion_structured_output(
self, inference_model, inference_stack, common_params, test_case
):
inference_impl, _ = inference_stack
class AnswerFormat(BaseModel):
@ -281,14 +315,22 @@ class TestInference:
with pytest.raises(ValidationError):
AnswerFormat.model_validate_json(response.completion_message.content)
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:sample_messages",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_streaming(self, inference_model, inference_stack, common_params, sample_messages):
async def test_text_chat_completion_streaming(self, inference_model, inference_stack, common_params, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
response = [
r
async for r in await inference_impl.chat_completion(
model_id=inference_model,
messages=sample_messages,
messages=messages,
stream=True,
**common_params,
)
@ -304,26 +346,28 @@ class TestInference:
end = grouped[ChatCompletionResponseEventType.complete][0]
assert end.event.stop_reason == StopReason.end_of_turn
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:sample_messages_tool_calling",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_with_tool_calling(
async def test_text_chat_completion_with_tool_calling(
self,
inference_model,
inference_stack,
common_params,
sample_messages,
sample_tool_definition,
test_case,
):
inference_impl, _ = inference_stack
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
)
]
tc = TestCase(test_case)
messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
response = await inference_impl.chat_completion(
model_id=inference_model,
messages=messages,
tools=[sample_tool_definition],
tools=tc["tools"],
stream=False,
**common_params,
)
@ -339,32 +383,35 @@ class TestInference:
assert len(message.tool_calls) > 0
call = message.tool_calls[0]
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]
assert call.tool_name == tc["tools"][0]["tool_name"]
for name, value in tc["expected"].items():
assert name in call.arguments
assert value in call.arguments[name]
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:sample_messages_tool_calling",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_with_tool_calling_streaming(
async def test_text_chat_completion_with_tool_calling_streaming(
self,
inference_model,
inference_stack,
common_params,
sample_messages,
sample_tool_definition,
test_case,
):
inference_impl, _ = inference_stack
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
)
]
tc = TestCase(test_case)
messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
response = [
r
async for r in await inference_impl.chat_completion(
model_id=inference_model,
messages=messages,
tools=[sample_tool_definition],
tools=tc["tools"],
stream=True,
**common_params,
)
@ -397,6 +444,7 @@ class TestInference:
assert isinstance(last.event.delta.tool_call, ToolCall)
call = last.event.delta.tool_call
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]
assert call.tool_name == tc["tools"][0]["tool_name"]
for name, value in tc["expected"].items():
assert name in call.arguments
assert value in call.arguments[name]