feat: completing text /chat-completion and /completion tests (#1223)

# What does this PR do?

The goal is to have a fairly complete set of provider and e2e tests for
/chat-completion and /completion. This is the current list:
```
grep -oE "def test_[a-zA-Z_+]*" llama_stack/providers/tests/inference/test_text_inference.py | cut -d' ' -f2
```
- test_model_list
- test_text_completion_non_streaming
- test_text_completion_streaming
- test_text_completion_logprobs_non_streaming
- test_text_completion_logprobs_streaming
- test_text_completion_structured_output
- test_text_chat_completion_non_streaming
- test_text_chat_completion_structured_output
- test_text_chat_completion_streaming
- test_text_chat_completion_with_tool_calling
- test_text_chat_completion_with_tool_calling_streaming

```
grep -oE "def test_[a-zA-Z_+]*" tests/client-sdk/inference/test_text_inference.py | cut -d' ' -f2
```
- test_text_completion_non_streaming
- test_text_completion_streaming
- test_text_completion_log_probs_non_streaming
- test_text_completion_log_probs_streaming
- test_text_completion_structured_output
- test_text_chat_completion_non_streaming
- test_text_chat_completion_streaming
- test_text_chat_completion_with_tool_calling_and_non_streaming
- test_text_chat_completion_with_tool_calling_and_streaming
- test_text_chat_completion_with_tool_choice_required
- test_text_chat_completion_with_tool_choice_none
- test_text_chat_completion_structured_output
- test_text_chat_completion_tool_calling_tools_not_in_request

## Test plan

== Set up a local Ollama server
```
OLLAMA_HOST=127.0.0.1:8321 with-proxy ollama serve
OLLAMA_HOST=127.0.0.1:8321 ollama run llama3.2:3b-instruct-fp16 --keepalive 60m
```
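
Optionally, to sanity-check that the server is up and the model has been pulled before running tests (this uses the standard Ollama REST API and is not part of this PR):
```
# expect llama3.2:3b-instruct-fp16 in the returned model list
curl http://localhost:8321/api/tags
```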

== Run a provider test
```
conda activate stack
OLLAMA_URL="http://localhost:8321" \
pytest -v -s -k "ollama" --inference-model="llama3.2:3b-instruct-fp16" \
llama_stack/providers/tests/inference/test_text_inference.py::TestInference
```
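
The `-k` expression can be narrowed to iterate on a single provider test, e.g. (a variant of the command above; the test name comes from the list in this PR):
```
OLLAMA_URL="http://localhost:8321" \
pytest -v -s -k "ollama and test_text_completion_structured_output" \
  --inference-model="llama3.2:3b-instruct-fp16" \
  llama_stack/providers/tests/inference/test_text_inference.py::TestInference
```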

== Run an e2e test
```
conda activate sherpa
with-proxy pip install llama-stack
export INFERENCE_MODEL=llama3.2:3b-instruct-fp16
export LLAMA_STACK_PORT=8322
with-proxy llama stack build --template ollama
with-proxy llama stack run --env OLLAMA_URL=http://localhost:8321 ollama
```
```
conda activate stack
LLAMA_STACK_PORT=8322 LLAMA_STACK_BASE_URL="http://localhost:8322" \
pytest -v -s --inference-model="llama3.2:3b-instruct-fp16" \
tests/client-sdk/inference/test_text_inference.py
```
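
A single parametrized case can also be selected by its pytest node ID (a hedged example; the bracketed ID assumes pytest's default parametrize naming for the test-case strings introduced in this PR):
```
LLAMA_STACK_BASE_URL="http://localhost:8322" \
pytest -v -s --inference-model="llama3.2:3b-instruct-fp16" \
"tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[inference:chat_completion:non_streaming_01]"
```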
Commit 3a31611486 (parent 9b130f96a7), authored by LESSuseLESS on 2025-02-25 11:37:04 -08:00 and committed via GitHub.
8 changed files with 479 additions and 223 deletions


@ -3,4 +3,4 @@ include distributions/dependencies.json
include llama_stack/distribution/*.sh
include llama_stack/cli/scripts/*.sh
include llama_stack/templates/*/*.yaml
include llama_stack/providers/tests/test_cases/*.json
include llama_stack/providers/tests/test_cases/inference/*.json


@ -27,8 +27,6 @@ from llama_stack.models.llama.datatypes import (
SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
)
from llama_stack.providers.tests.test_cases.test_case import TestCase
@ -58,28 +56,6 @@ def common_params(inference_model):
}
@pytest.fixture
def sample_messages():
return [
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="What's the weather like today?"),
]
@pytest.fixture
def sample_tool_definition():
return ToolDefinition(
tool_name="get_weather",
description="Get the current weather",
parameters={
"location": ToolParamDefinition(
param_type="string",
description="The city and state, e.g. San Francisco, CA",
),
},
)
class TestInference:
# Session scope for asyncio because the tests in this class all
# share the same provider instance.
@ -100,12 +76,20 @@ class TestInference:
assert model_def is not None
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:non_streaming",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_completion(self, inference_model, inference_stack):
async def test_text_completion_non_streaming(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
response = await inference_impl.completion(
content="Micheael Jordan is born in ",
content=tc["content"],
stream=False,
model_id=inference_model,
sampling_params=SamplingParams(
@ -114,12 +98,24 @@ class TestInference:
)
assert isinstance(response, CompletionResponse)
assert "1963" in response.content
assert tc["expected"] in response.content
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:streaming",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_text_completion_streaming(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
chunks = [
r
async for r in await inference_impl.completion(
content="Roses are red,",
content=tc["content"],
stream=True,
model_id=inference_model,
sampling_params=SamplingParams(
@ -133,12 +129,20 @@ class TestInference:
last = chunks[-1]
assert last.stop_reason == StopReason.out_of_tokens
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:logprobs_non_streaming",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_completion_logprobs(self, inference_model, inference_stack):
async def test_text_completion_logprobs_non_streaming(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
response = await inference_impl.completion(
content="Micheael Jordan is born in ",
content=tc["content"],
stream=False,
model_id=inference_model,
sampling_params=SamplingParams(
@ -154,10 +158,22 @@ class TestInference:
assert response.logprobs, "Logprobs should not be empty"
assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:logprobs_streaming",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_text_completion_logprobs_streaming(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
chunks = [
r
async for r in await inference_impl.completion(
content="Roses are red,",
content=tc["content"],
stream=True,
model_id=inference_model,
sampling_params=SamplingParams(
@ -180,9 +196,14 @@ class TestInference:
else: # no token, no logprobs
assert not chunk.logprobs, "Logprobs should be empty"
@pytest.mark.parametrize("test_case", ["completion-01"])
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:structured_output",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_completion_structured_output(self, inference_model, inference_stack, test_case):
async def test_text_completion_structured_output(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
class Output(BaseModel):
@ -213,14 +234,20 @@ class TestInference:
assert answer.year_born == expected["year_born"]
assert answer.year_retired == expected["year_retired"]
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:sample_messages",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_non_streaming(
self, inference_model, inference_stack, common_params, sample_messages
):
async def test_text_chat_completion_non_streaming(self, inference_model, inference_stack, common_params, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
response = await inference_impl.chat_completion(
model_id=inference_model,
messages=sample_messages,
messages=messages,
stream=False,
**common_params,
)
@ -230,9 +257,16 @@ class TestInference:
assert isinstance(response.completion_message.content, str)
assert len(response.completion_message.content) > 0
@pytest.mark.parametrize("test_case", ["chat_completion-01"])
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:structured_output",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_structured_output(self, inference_model, inference_stack, common_params, test_case):
async def test_text_chat_completion_structured_output(
self, inference_model, inference_stack, common_params, test_case
):
inference_impl, _ = inference_stack
class AnswerFormat(BaseModel):
@ -281,14 +315,22 @@ class TestInference:
with pytest.raises(ValidationError):
AnswerFormat.model_validate_json(response.completion_message.content)
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:sample_messages",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_streaming(self, inference_model, inference_stack, common_params, sample_messages):
async def test_text_chat_completion_streaming(self, inference_model, inference_stack, common_params, test_case):
inference_impl, _ = inference_stack
tc = TestCase(test_case)
messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
response = [
r
async for r in await inference_impl.chat_completion(
model_id=inference_model,
messages=sample_messages,
messages=messages,
stream=True,
**common_params,
)
@ -304,26 +346,28 @@ class TestInference:
end = grouped[ChatCompletionResponseEventType.complete][0]
assert end.event.stop_reason == StopReason.end_of_turn
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:sample_messages_tool_calling",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_with_tool_calling(
async def test_text_chat_completion_with_tool_calling(
self,
inference_model,
inference_stack,
common_params,
sample_messages,
sample_tool_definition,
test_case,
):
inference_impl, _ = inference_stack
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
)
]
tc = TestCase(test_case)
messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
response = await inference_impl.chat_completion(
model_id=inference_model,
messages=messages,
tools=[sample_tool_definition],
tools=tc["tools"],
stream=False,
**common_params,
)
@ -339,32 +383,35 @@ class TestInference:
assert len(message.tool_calls) > 0
call = message.tool_calls[0]
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]
assert call.tool_name == tc["tools"][0]["tool_name"]
for name, value in tc["expected"].items():
assert name in call.arguments
assert value in call.arguments[name]
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:sample_messages_tool_calling",
],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_with_tool_calling_streaming(
async def test_text_chat_completion_with_tool_calling_streaming(
self,
inference_model,
inference_stack,
common_params,
sample_messages,
sample_tool_definition,
test_case,
):
inference_impl, _ = inference_stack
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
)
]
tc = TestCase(test_case)
messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
response = [
r
async for r in await inference_impl.chat_completion(
model_id=inference_model,
messages=messages,
tools=[sample_tool_definition],
tools=tc["tools"],
stream=True,
**common_params,
)
@ -397,6 +444,7 @@ class TestInference:
assert isinstance(last.event.delta.tool_call, ToolCall)
call = last.event.delta.tool_call
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]
assert call.tool_name == tc["tools"][0]["tool_name"]
for name, value in tc["expected"].items():
assert name in call.arguments
assert value in call.arguments[name]


@ -1,24 +0,0 @@
{
"01": {
"name": "structured output",
"data": {
"notes": "We include context about Michael Jordan in the prompt so that the test is focused on the funtionality of the model and not on the information embedded in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant. Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
},
{
"role": "user",
"content": "Please give me information about Michael Jordan."
}
],
"expected": {
"first_name": "Michael",
"last_name": "Jordan",
"year_of_birth": 1963,
"num_seasons_in_nba": 15
}
}
}
}


@ -1,13 +0,0 @@
{
"01": {
"name": "structured output",
"data": {
"user_input": "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003.",
"expected": {
"name": "Michael Jordan",
"year_born": "1963",
"year_retired": "2003"
}
}
}
}


@ -0,0 +1,171 @@
{
"non_streaming_01": {
"data": {
"question": "Which planet do humans live on?",
"expected": "Earth"
}
},
"non_streaming_02": {
"data": {
"question": "Which planet has rings around it with a name starting with letter S?",
"expected": "Saturn"
}
},
"sample_messages": {
"data": {
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What's the weather like today?"
}
]
}
},
"streaming_01": {
"data": {
"question": "What's the name of the Sun in latin?",
"expected": "Sol"
}
},
"streaming_02": {
"data": {
"question": "What is the name of the US captial?",
"expected": "Washington"
}
},
"tool_calling": {
"data": {
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"}
],
"tools": [
{
"tool_name": "get_weather",
"description": "Get the current weather",
"parameters": {
"location": {
"param_type": "string",
"description": "The city and state, e.g. San Francisco, CA"
}
}
}
],
"expected": {
"location": "San Francisco, CA"
}
}
},
"sample_messages_tool_calling": {
"data": {
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What's the weather like today?"
},
{
"role": "user",
"content": "What's the weather like in San Francisco?"
}
],
"tools": [
{
"tool_name": "get_weather",
"description": "Get the current weather",
"parameters": {
"location": {
"param_type": "string",
"description": "The city and state, e.g. San Francisco, CA",
"required": true
}
}
}
],
"expected": {
"location": "San Francisco"
}
}
},
"structured_output": {
"data": {
"notes": "We include context about Michael Jordan in the prompt so that the test is focused on the funtionality of the model and not on the information embedded in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant. Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
},
{
"role": "user",
"content": "Please give me information about Michael Jordan."
}
],
"expected": {
"first_name": "Michael",
"last_name": "Jordan",
"year_of_birth": 1963,
"num_seasons_in_nba": 15
}
}
},
"tool_calling_tools_absent": {
"data": {
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What pods are in the namespace openshift-lightspeed?"
},
{
"role": "assistant",
"content": "",
"stop_reason": "end_of_turn",
"tool_calls": [
{
"call_id": "1",
"tool_name": "get_object_namespace_list",
"arguments": {
"kind": "pod",
"namespace": "openshift-lightspeed"
}
}
]
},
{
"role": "tool",
"call_id": "1",
"tool_name": "get_object_namespace_list",
"content": "the objects are pod1, pod2, pod3"
}
],
"tools": [
{
"tool_name": "get_object_namespace_list",
"description": "Get the list of objects in a namespace",
"parameters": {
"kind": {
"param_type": "string",
"description": "the type of object",
"required": true
},
"namespace": {
"param_type": "string",
"description": "the name of the namespace",
"required": true
}
}
}
]
}
}
}


@ -0,0 +1,43 @@
{
"sanity": {
"data": {
"content": "Complete the sentence using one word: Roses are red, violets are "
}
},
"non_streaming": {
"data": {
"content": "Micheael Jordan is born in ",
"expected": "1963"
}
},
"streaming": {
"data": {
"content": "Roses are red,"
}
},
"log_probs": {
"data": {
"content": "Complete the sentence: Micheael Jordan is born in "
}
},
"logprobs_non_streaming": {
"data": {
"content": "Micheael Jordan is born in "
}
},
"logprobs_streaming": {
"data": {
"content": "Roses are red,"
}
},
"structured_output": {
"data": {
"user_input": "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003.",
"expected": {
"name": "Michael Jordan",
"year_born": "1963",
"year_retired": "2003"
}
}
}
}


@ -9,7 +9,10 @@ import pathlib
class TestCase:
_apis = ["chat_completion", "completion"]
_apis = [
"inference/chat_completion",
"inference/completion",
]
_jsonblob = {}
def __init__(self, name):
@ -17,7 +20,12 @@ class TestCase:
if self._jsonblob == {}:
for api in self._apis:
with open(pathlib.Path(__file__).parent / f"{api}.json", "r") as f:
TestCase._jsonblob.update({f"{api}-{k}": v for k, v in json.load(f).items()})
coloned = api.replace("/", ":")
try:
loaded = json.load(f)
except json.JSONDecodeError as e:
raise ValueError(f"There is a syntax error in {api}.json: {e}") from e
TestCase._jsonblob.update({f"{coloned}:{k}": v for k, v in loaded.items()})
# loading this test case
tc = self._jsonblob.get(name)
@ -25,7 +33,6 @@ class TestCase:
raise ValueError(f"Test case {name} not found")
# these are the only fields we need
self.name = tc.get("name")
self.data = tc.get("data")
def __getitem__(self, key):


@ -28,23 +28,17 @@ def provider_tool_format(inference_provider_type):
)
@pytest.fixture
def get_weather_tool_definition():
return {
"tool_name": "get_weather",
"description": "Get the current weather",
"parameters": {
"location": {
"param_type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
},
}
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:sanity",
],
)
def test_text_completion_non_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
def test_text_completion_non_streaming(client_with_models, text_model_id):
response = client_with_models.inference.completion(
content="Complete the sentence using one word: Roses are red, violets are ",
content=tc["content"],
stream=False,
model_id=text_model_id,
sampling_params={
@ -55,9 +49,17 @@ def test_text_completion_non_streaming(client_with_models, text_model_id):
# assert "blue" in response.content.lower().strip()
def test_text_completion_streaming(client_with_models, text_model_id):
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:sanity",
],
)
def test_text_completion_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
response = client_with_models.inference.completion(
content="Complete the sentence using one word: Roses are red, violets are ",
content=tc["content"],
stream=True,
model_id=text_model_id,
sampling_params={
@ -70,12 +72,20 @@ def test_text_completion_streaming(client_with_models, text_model_id):
assert len(content_str) > 10
def test_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type):
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:log_probs",
],
)
def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
tc = TestCase(test_case)
response = client_with_models.inference.completion(
content="Complete the sentence: Micheael Jordan is born in ",
content=tc["content"],
stream=False,
model_id=text_model_id,
sampling_params={
@ -90,12 +100,20 @@ def test_completion_log_probs_non_streaming(client_with_models, text_model_id, i
assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
def test_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type):
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:log_probs",
],
)
def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
tc = TestCase(test_case)
response = client_with_models.inference.completion(
content="Complete the sentence: Micheael Jordan is born in ",
content=tc["content"],
stream=True,
model_id=text_model_id,
sampling_params={
@ -114,7 +132,12 @@ def test_completion_log_probs_streaming(client_with_models, text_model_id, infer
assert not chunk.logprobs, "Logprobs should be empty"
@pytest.mark.parametrize("test_case", ["completion-01"])
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:structured_output",
],
)
def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
class AnswerFormat(BaseModel):
name: str
@ -144,16 +167,17 @@ def test_text_completion_structured_output(client_with_models, text_model_id, te
@pytest.mark.parametrize(
"question,expected",
"test_case",
[
("Which planet do humans live on?", "Earth"),
(
"Which planet has rings around it with a name starting with letter S?",
"Saturn",
),
"inference:chat_completion:non_streaming_01",
"inference:chat_completion:non_streaming_02",
],
)
def test_text_chat_completion_non_streaming(client_with_models, text_model_id, question, expected):
def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
@ -170,13 +194,17 @@ def test_text_chat_completion_non_streaming(client_with_models, text_model_id, q
@pytest.mark.parametrize(
"question,expected",
"test_case",
[
("What's the name of the Sun in latin?", "Sol"),
("What is the name of the US captial?", "Washington"),
"inference:chat_completion:streaming_01",
"inference:chat_completion:streaming_02",
],
)
def test_text_chat_completion_streaming(client_with_models, text_model_id, question, expected):
def test_text_chat_completion_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[{"role": "user", "content": question}],
@ -187,18 +215,26 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, quest
assert expected.lower() in "".join(streamed_content)
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_calling_and_non_streaming(
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
client_with_models, text_model_id, provider_tool_format, test_case
):
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"},
],
tools=[get_weather_tool_definition],
messages=tc["messages"],
tools=tc["tools"],
tool_choice="auto",
tool_prompt_format=provider_tool_format,
tool_prompt_format=tool_prompt_format,
stream=False,
)
# No content is returned for the system message since we expect the
@ -207,8 +243,8 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
assert response.completion_message.role == "assistant"
assert len(response.completion_message.tool_calls) == 1
assert response.completion_message.tool_calls[0].tool_name == "get_weather"
assert response.completion_message.tool_calls[0].arguments == {"location": "San Francisco, CA"}
assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"]
assert response.completion_message.tool_calls[0].arguments == tc["expected"]
# Will extract streamed text and separate it from tool invocation content
@ -224,57 +260,80 @@ def extract_tool_invocation_content(response):
return tool_invocation_content
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_calling_and_streaming(
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
client_with_models, text_model_id, provider_tool_format, test_case
):
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"},
],
tools=[get_weather_tool_definition],
messages=tc["messages"],
tools=tc["tools"],
tool_choice="auto",
tool_prompt_format=provider_tool_format,
tool_prompt_format=tool_prompt_format,
stream=True,
)
tool_invocation_content = extract_tool_invocation_content(response)
assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
expected_tool_name = tc["tools"][0]["tool_name"]
expected_argument = tc["expected"]
assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]"
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_choice_required(
client_with_models,
text_model_id,
get_weather_tool_definition,
provider_tool_format,
test_case,
):
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"},
],
tools=[get_weather_tool_definition],
messages=tc["messages"],
tools=tc["tools"],
tool_config={
"tool_choice": "required",
"tool_prompt_format": provider_tool_format,
"tool_prompt_format": tool_prompt_format,
},
stream=True,
)
tool_invocation_content = extract_tool_invocation_content(response)
assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
expected_tool_name = tc["tools"][0]["tool_name"]
expected_argument = tc["expected"]
assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]"
def test_text_chat_completion_with_tool_choice_none(
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
):
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, provider_tool_format, test_case):
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in San Francisco?"},
],
tools=[get_weather_tool_definition],
messages=tc["messages"],
tools=tc["tools"],
tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
stream=True,
)
@ -282,7 +341,12 @@ def test_text_chat_completion_with_tool_choice_none(
assert tool_invocation_content == ""
@pytest.mark.parametrize("test_case", ["chat_completion-01"])
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:structured_output",
],
)
def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
class AnswerFormat(BaseModel):
first_name: str
@ -309,64 +373,24 @@ def test_text_chat_completion_structured_output(client_with_models, text_model_i
assert answer.num_seasons_in_nba == expected["num_seasons_in_nba"]
@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize(
"streaming",
"test_case",
[
True,
False,
"inference:chat_completion:tool_calling_tools_absent",
],
)
def test_text_chat_completion_tool_calling_tools_not_in_request(client_with_models, text_model_id, streaming):
def test_text_chat_completion_tool_calling_tools_not_in_request(
client_with_models, text_model_id, test_case, streaming
):
tc = TestCase(test_case)
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
request = {
"model_id": text_model_id,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": "What pods are in the namespace openshift-lightspeed?",
},
{
"role": "assistant",
"content": "",
"stop_reason": "end_of_turn",
"tool_calls": [
{
"call_id": "1",
"tool_name": "get_object_namespace_list",
"arguments": {
"kind": "pod",
"namespace": "openshift-lightspeed",
},
}
],
},
{
"role": "tool",
"call_id": "1",
"tool_name": "get_object_namespace_list",
"content": "the objects are pod1, pod2, pod3",
},
],
"tools": [
{
"tool_name": "get_object_namespace_list",
"description": "Get the list of objects in a namespace",
"parameters": {
"kind": {
"param_type": "string",
"description": "the type of object",
"required": True,
},
"namespace": {
"param_type": "string",
"description": "the name of the namespace",
"required": True,
},
},
}
],
"messages": tc["messages"],
"tools": tc["tools"],
"tool_choice": "auto",
"tool_prompt_format": tool_prompt_format,
"stream": streaming,