feat: completing text /chat-completion and /completion tests (#1223)
# What does this PR do?

The goal is to have a fairly complete set of provider and e2e tests for /chat-completion and /completion. This is the current list:

```
grep -oE "def test_[a-zA-Z_+]*" llama_stack/providers/tests/inference/test_text_inference.py | cut -d' ' -f2
```

- test_model_list
- test_text_completion_non_streaming
- test_text_completion_streaming
- test_text_completion_logprobs_non_streaming
- test_text_completion_logprobs_streaming
- test_text_completion_structured_output
- test_text_chat_completion_non_streaming
- test_text_chat_completion_structured_output
- test_text_chat_completion_streaming
- test_text_chat_completion_with_tool_calling
- test_text_chat_completion_with_tool_calling_streaming

```
grep -oE "def test_[a-zA-Z_+]*" tests/client-sdk/inference/test_text_inference.py | cut -d' ' -f2
```

- test_text_completion_non_streaming
- test_text_completion_streaming
- test_text_completion_log_probs_non_streaming
- test_text_completion_log_probs_streaming
- test_text_completion_structured_output
- test_text_chat_completion_non_streaming
- test_text_chat_completion_streaming
- test_text_chat_completion_with_tool_calling_and_non_streaming
- test_text_chat_completion_with_tool_calling_and_streaming
- test_text_chat_completion_with_tool_choice_required
- test_text_chat_completion_with_tool_choice_none
- test_text_chat_completion_structured_output
- test_text_chat_completion_tool_calling_tools_not_in_request

## Test plan

== Set up Ollama local server

```
OLLAMA_HOST=127.0.0.1:8321 with-proxy ollama serve
OLLAMA_HOST=127.0.0.1:8321 ollama run llama3.2:3b-instruct-fp16 --keepalive 60m
```

== Run a provider test

```
conda activate stack
OLLAMA_URL="http://localhost:8321" \
pytest -v -s -k "ollama" --inference-model="llama3.2:3b-instruct-fp16" \
llama_stack/providers/tests/inference/test_text_inference.py::TestInference
```

== Run an e2e test

```
conda activate sherpa
with-proxy pip install llama-stack
export INFERENCE_MODEL=llama3.2:3b-instruct-fp16
export LLAMA_STACK_PORT=8322
with-proxy llama stack build --template ollama
with-proxy llama stack run --env OLLAMA_URL=http://localhost:8321 ollama
```

```
conda activate stack
LLAMA_STACK_PORT=8322 LLAMA_STACK_BASE_URL="http://localhost:8322" \
pytest -v -s --inference-model="llama3.2:3b-instruct-fp16" \
tests/client-sdk/inference/test_text_inference.py
```
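For reference, the client-sdk suite exercises the stack started in the test plan through the Python client. Below is a minimal sketch of the kind of call it makes, assuming the `llama_stack_client` package and the port/model used above; exact client signatures may differ from this illustration.

```
# Rough sketch of the call shape the client-sdk tests exercise (assumed API).
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8322")

response = client.inference.chat_completion(
    model_id="llama3.2:3b-instruct-fp16",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
)
print(response.completion_message.content)
```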
parent 9b130f96a7
commit 3a31611486
8 changed files with 479 additions and 223 deletions
`llama_stack/providers/tests/inference/test_text_inference.py`:

```
@@ -27,8 +27,6 @@ from llama_stack.models.llama.datatypes import (
     SamplingParams,
     StopReason,
     ToolCall,
-    ToolDefinition,
-    ToolParamDefinition,
     ToolPromptFormat,
 )
 from llama_stack.providers.tests.test_cases.test_case import TestCase
```
```
@@ -58,28 +56,6 @@ def common_params(inference_model):
     }


-@pytest.fixture
-def sample_messages():
-    return [
-        SystemMessage(content="You are a helpful assistant."),
-        UserMessage(content="What's the weather like today?"),
-    ]
-
-
-@pytest.fixture
-def sample_tool_definition():
-    return ToolDefinition(
-        tool_name="get_weather",
-        description="Get the current weather",
-        parameters={
-            "location": ToolParamDefinition(
-                param_type="string",
-                description="The city and state, e.g. San Francisco, CA",
-            ),
-        },
-    )
-
-
 class TestInference:
     # Session scope for asyncio because the tests in this class all
     # share the same provider instance.
```
```
@@ -100,12 +76,20 @@ class TestInference:

         assert model_def is not None

+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:completion:non_streaming",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_completion(self, inference_model, inference_stack):
+    async def test_text_completion_non_streaming(self, inference_model, inference_stack, test_case):
         inference_impl, _ = inference_stack

+        tc = TestCase(test_case)
+
         response = await inference_impl.completion(
-            content="Micheael Jordan is born in ",
+            content=tc["content"],
             stream=False,
             model_id=inference_model,
             sampling_params=SamplingParams(
```
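The hunk above swaps hard-coded prompts and expected substrings for lookups on a `TestCase` keyed by IDs such as `inference:completion:non_streaming`. Here is a minimal sketch of that indirection with a hypothetical in-memory registry; the real helper imported from `llama_stack.providers.tests.test_cases.test_case` may load cases from files and use a different schema.

```
# Hypothetical stand-in for the TestCase lookup used in the diff; the real
# helper may store its cases differently.
_CASES = {
    "inference:completion:non_streaming": {
        "content": "Micheael Jordan is born in ",  # prompt (taken from the old hard-coded value)
        "expected": "1963",                        # substring the response must contain
    },
}


class TestCase:
    def __init__(self, case_id: str):
        self._data = _CASES[case_id]

    def __getitem__(self, key: str):
        return self._data[key]


tc = TestCase("inference:completion:non_streaming")
# Same assertion shape the test uses against a model response.
assert tc["expected"] in "Michael Jordan was born in 1963."
```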
```
@@ -114,12 +98,24 @@ class TestInference:
         )

         assert isinstance(response, CompletionResponse)
-        assert "1963" in response.content
+        assert tc["expected"] in response.content

+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:completion:streaming",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
+    async def test_text_completion_streaming(self, inference_model, inference_stack, test_case):
         inference_impl, _ = inference_stack

+        tc = TestCase(test_case)
+
         chunks = [
             r
             async for r in await inference_impl.completion(
-                content="Roses are red,",
+                content=tc["content"],
                 stream=True,
                 model_id=inference_model,
                 sampling_params=SamplingParams(
```
```
@@ -133,12 +129,20 @@ class TestInference:
         last = chunks[-1]
         assert last.stop_reason == StopReason.out_of_tokens

+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:completion:logprobs_non_streaming",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_completion_logprobs(self, inference_model, inference_stack):
+    async def test_text_completion_logprobs_non_streaming(self, inference_model, inference_stack, test_case):
         inference_impl, _ = inference_stack

+        tc = TestCase(test_case)
+
         response = await inference_impl.completion(
-            content="Micheael Jordan is born in ",
+            content=tc["content"],
             stream=False,
             model_id=inference_model,
             sampling_params=SamplingParams(
```
```
@@ -154,10 +158,22 @@ class TestInference:
         assert response.logprobs, "Logprobs should not be empty"
         assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)

+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:completion:logprobs_streaming",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
+    async def test_text_completion_logprobs_streaming(self, inference_model, inference_stack, test_case):
         inference_impl, _ = inference_stack

+        tc = TestCase(test_case)
+
         chunks = [
             r
             async for r in await inference_impl.completion(
-                content="Roses are red,",
+                content=tc["content"],
                 stream=True,
                 model_id=inference_model,
                 sampling_params=SamplingParams(
```
```
@@ -180,9 +196,14 @@ class TestInference:
             else:  # no token, no logprobs
                 assert not chunk.logprobs, "Logprobs should be empty"

-    @pytest.mark.parametrize("test_case", ["completion-01"])
+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:completion:structured_output",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_completion_structured_output(self, inference_model, inference_stack, test_case):
+    async def test_text_completion_structured_output(self, inference_model, inference_stack, test_case):
         inference_impl, _ = inference_stack

         class Output(BaseModel):
```
```
@@ -213,14 +234,20 @@ class TestInference:
         assert answer.year_born == expected["year_born"]
         assert answer.year_retired == expected["year_retired"]

+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:chat_completion:sample_messages",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_chat_completion_non_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages
-    ):
+    async def test_text_chat_completion_non_streaming(self, inference_model, inference_stack, common_params, test_case):
         inference_impl, _ = inference_stack
+        tc = TestCase(test_case)
+        messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
         response = await inference_impl.chat_completion(
             model_id=inference_model,
-            messages=sample_messages,
+            messages=messages,
             stream=False,
             **common_params,
         )
```
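The chat-completion tests above now build typed messages from plain dicts via pydantic's `TypeAdapter`, which lets test cases store messages as JSON-like data. A small self-contained sketch of the same technique, using stand-in message models rather than the llama-stack `Message` union:

```
# Sketch of TypeAdapter-based validation; SystemMessage/UserMessage here are
# stand-ins, not the llama-stack types.
from typing import Literal, Union

from pydantic import BaseModel, TypeAdapter


class SystemMessage(BaseModel):
    role: Literal["system"] = "system"
    content: str


class UserMessage(BaseModel):
    role: Literal["user"] = "user"
    content: str


Message = Union[SystemMessage, UserMessage]

raw = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the weather like today?"},
]
# Same pattern as the diff: validate each dict into the matching Message variant.
messages = [TypeAdapter(Message).validate_python(m) for m in raw]
assert isinstance(messages[1], UserMessage)
```

Keeping the messages as plain dicts in the test-case data keeps them serializable, while the tests still receive validated, typed objects.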
```
@@ -230,9 +257,16 @@ class TestInference:
         assert isinstance(response.completion_message.content, str)
         assert len(response.completion_message.content) > 0

-    @pytest.mark.parametrize("test_case", ["chat_completion-01"])
+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:chat_completion:structured_output",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_structured_output(self, inference_model, inference_stack, common_params, test_case):
+    async def test_text_chat_completion_structured_output(
+        self, inference_model, inference_stack, common_params, test_case
+    ):
         inference_impl, _ = inference_stack

         class AnswerFormat(BaseModel):
```
```
@@ -281,14 +315,22 @@ class TestInference:
         with pytest.raises(ValidationError):
             AnswerFormat.model_validate_json(response.completion_message.content)

+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:chat_completion:sample_messages",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_chat_completion_streaming(self, inference_model, inference_stack, common_params, sample_messages):
+    async def test_text_chat_completion_streaming(self, inference_model, inference_stack, common_params, test_case):
         inference_impl, _ = inference_stack
+        tc = TestCase(test_case)
+        messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
         response = [
             r
             async for r in await inference_impl.chat_completion(
                 model_id=inference_model,
-                messages=sample_messages,
+                messages=messages,
                 stream=True,
                 **common_params,
             )
```
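The streaming tests drain the response with an async comprehension over `await inference_impl.chat_completion(..., stream=True)`: the `await` resolves to an async iterator, and the comprehension collects its chunks. A minimal sketch of that pattern with a made-up `fake_chat_completion`, not the llama-stack API:

```
# Minimal illustration of the "[r async for r in await coro()]" pattern used
# by the streaming tests; fake_chat_completion is a hypothetical stand-in.
import asyncio
from typing import AsyncIterator


async def _stream() -> AsyncIterator[str]:
    for chunk in ["He", "llo", "!"]:
        yield chunk


async def fake_chat_completion(stream: bool = True):
    # Like the provider API: awaiting the call yields an async iterator
    # when stream=True.
    return _stream()


async def main() -> None:
    chunks = [c async for c in await fake_chat_completion(stream=True)]
    assert chunks == ["He", "llo", "!"]


asyncio.run(main())
```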
```
@@ -304,26 +346,28 @@ class TestInference:
         end = grouped[ChatCompletionResponseEventType.complete][0]
         assert end.event.stop_reason == StopReason.end_of_turn

+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:chat_completion:sample_messages_tool_calling",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_chat_completion_with_tool_calling(
+    async def test_text_chat_completion_with_tool_calling(
         self,
         inference_model,
         inference_stack,
         common_params,
-        sample_messages,
-        sample_tool_definition,
+        test_case,
     ):
         inference_impl, _ = inference_stack
-        messages = sample_messages + [
-            UserMessage(
-                content="What's the weather like in San Francisco?",
-            )
-        ]
+        tc = TestCase(test_case)
+        messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]

         response = await inference_impl.chat_completion(
             model_id=inference_model,
             messages=messages,
-            tools=[sample_tool_definition],
+            tools=tc["tools"],
             stream=False,
             **common_params,
         )
```
```
@@ -339,32 +383,35 @@ class TestInference:
         assert len(message.tool_calls) > 0

         call = message.tool_calls[0]
-        assert call.tool_name == "get_weather"
-        assert "location" in call.arguments
-        assert "San Francisco" in call.arguments["location"]
+        assert call.tool_name == tc["tools"][0]["tool_name"]
+        for name, value in tc["expected"].items():
+            assert name in call.arguments
+            assert value in call.arguments[name]

+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:chat_completion:sample_messages_tool_calling",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_chat_completion_with_tool_calling_streaming(
+    async def test_text_chat_completion_with_tool_calling_streaming(
         self,
         inference_model,
         inference_stack,
         common_params,
-        sample_messages,
-        sample_tool_definition,
+        test_case,
     ):
         inference_impl, _ = inference_stack
-        messages = sample_messages + [
-            UserMessage(
-                content="What's the weather like in San Francisco?",
-            )
-        ]
+        tc = TestCase(test_case)
+        messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]

         response = [
             r
             async for r in await inference_impl.chat_completion(
                 model_id=inference_model,
                 messages=messages,
-                tools=[sample_tool_definition],
+                tools=tc["tools"],
                 stream=True,
                 **common_params,
             )
```
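The tool-calling tests above take both the request tools (`tc["tools"]`) and the assertions (`tc["expected"]`) from the test case. Below is a sketch of what such a case might contain and how the assertion loop consumes it; the field layout is inferred from how the test indexes it, not taken from the actual registered data.

```
# Hypothetical tool-calling test case shaped the way the diff indexes it.
tc = {
    "tools": [
        {
            "tool_name": "get_weather",
            "description": "Get the current weather",
            "parameters": {
                "location": {"param_type": "string", "description": "City and state"},
            },
        }
    ],
    # argument name -> substring expected inside the model's tool-call argument
    "expected": {"location": "San Francisco"},
}


# Pretend the model returned this tool call.
class Call:
    tool_name = "get_weather"
    arguments = {"location": "San Francisco, CA"}


call = Call()
assert call.tool_name == tc["tools"][0]["tool_name"]
for name, value in tc["expected"].items():
    assert name in call.arguments
    assert value in call.arguments[name]
```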
```
@@ -397,6 +444,7 @@ class TestInference:
         assert isinstance(last.event.delta.tool_call, ToolCall)

         call = last.event.delta.tool_call
-        assert call.tool_name == "get_weather"
-        assert "location" in call.arguments
-        assert "San Francisco" in call.arguments["location"]
+        assert call.tool_name == tc["tools"][0]["tool_name"]
+        for name, value in tc["expected"].items():
+            assert name in call.arguments
+            assert value in call.arguments[name]
```