llama-stack/llama_stack/providers/tests/inference/test_text_inference.py
Sixian Yi e41873f268
[ez] structured output for /completion ollama & enable tests (#822)
# What does this PR do?

1) enabled structured output for ollama /completion API. It seems we
missed this one.
2) fixed ollama structured output test in client sdk - ollama does not
support list format for structured output
3) enable structured output unit test as the result was stable on
Llama-3.1-8B-Instruct and ollama, fireworks, together.


## Test Plan
1) Run `test_completion_structured_output` on /completion API with 3
providers: ollama, fireworks, together.
pytest -v -s -k "together"
--inference-model="meta-llama/Llama-3.1-8B-Instruct"
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output


```
(base) sxyi@sxyi-mbp llama-stack % pytest -s -v llama_stack/providers/tests/inference --config=ci_test_config.yaml
/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
================================================================================================ test session starts =================================================================================================
platform darwin -- Python 3.13.0, pytest-8.3.4, pluggy-1.5.0 -- /Library/Frameworks/Python.framework/Versions/3.13/bin/python3.13
cachedir: .pytest_cache
metadata: {'Python': '3.13.0', 'Platform': 'macOS-15.1.1-arm64-arm-64bit-Mach-O', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'asyncio': '0.24.0', 'html': '4.1.1', 'metadata': '3.1.1', 'md': '0.2.0', 'dependency': '0.6.0', 'md-report': '0.6.3', 'anyio': '4.6.2.post1'}}
rootdir: /Users/sxyi/llama-stack
configfile: pyproject.toml
plugins: asyncio-0.24.0, html-4.1.1, metadata-3.1.1, md-0.2.0, dependency-0.6.0, md-report-0.6.3, anyio-4.6.2.post1
asyncio: mode=Mode.STRICT, default_loop_scope=None
collected 85 items / 82 deselected / 3 selected                                                                                                                                                                      

llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct-ollama] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct-fireworks]
PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct-together] PASSED

==================================================================================== 3 passed, 82 deselected, 8 warnings in 5.67s ====================================================================================
```
2)  
` LLAMA_STACK_CONFIG="./llama_stack/templates/ollama/run.yaml"
/opt/miniconda3/envs/stack/bin/pytest -s -v tests/client-sdk/inference`

Before: 
```
________________________________________________________________________________________ test_completion_structured_output __________________________________________________________________________________________
tests/client-sdk/inference/test_inference.py:174: in test_completion_structured_output
    answer = AnswerFormat.model_validate_json(response.content)
E   pydantic_core._pydantic_core.ValidationError: 1 validation error for AnswerFormat
E     Invalid JSON: expected value at line 1 column 2 [type=json_invalid, input_value=' The year he retired, he...5\n\nThe best answer is', input_type=str]
E       For further information visit https://errors.pydantic.dev/2.10/v/json_invalid
```

After: 
test consistently passes


## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
2025-01-21 21:10:24 -08:00

485 lines
17 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_models.llama3.api.datatypes import (
SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
)
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
JsonSchemaResponseFormat,
LogProbConfig,
SystemMessage,
ToolChoice,
UserMessage,
)
from llama_stack.apis.models import ListModelsResponse, Model
from .utils import group_chunks
# How to run this test:
#
# pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py
# -m "(fireworks or ollama) and llama_3b"
# --env FIREWORKS_API_KEY=<your_api_key>
def get_expected_stop_reason(model: str):
return (
StopReason.end_of_message
if ("Llama3.1" in model or "Llama-3.1" in model)
else StopReason.end_of_turn
)
@pytest.fixture
def common_params(inference_model):
return {
"tool_choice": ToolChoice.auto,
"tool_prompt_format": (
ToolPromptFormat.json
if ("Llama3.1" in inference_model or "Llama-3.1" in inference_model)
else ToolPromptFormat.python_list
),
}
@pytest.fixture
def sample_messages():
return [
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="What's the weather like today?"),
]
@pytest.fixture
def sample_tool_definition():
return ToolDefinition(
tool_name="get_weather",
description="Get the current weather",
parameters={
"location": ToolParamDefinition(
param_type="string",
description="The city and state, e.g. San Francisco, CA",
),
},
)
class TestInference:
# Session scope for asyncio because the tests in this class all
# share the same provider instance.
@pytest.mark.asyncio(loop_scope="session")
async def test_model_list(self, inference_model, inference_stack):
_, models_impl = inference_stack
response = await models_impl.list_models()
assert isinstance(response, ListModelsResponse)
assert isinstance(response.data, list)
assert len(response.data) >= 1
assert all(isinstance(model, Model) for model in response.data)
model_def = None
for model in response.data:
if model.identifier == inference_model:
model_def = model
break
assert model_def is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_completion(self, inference_model, inference_stack):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
"inline::meta-reference",
"remote::ollama",
"remote::tgi",
"remote::together",
"remote::fireworks",
"remote::nvidia",
"remote::cerebras",
):
pytest.skip("Other inference providers don't support completion() yet")
response = await inference_impl.completion(
content="Micheael Jordan is born in ",
stream=False,
model_id=inference_model,
sampling_params=SamplingParams(
max_tokens=50,
),
)
assert isinstance(response, CompletionResponse)
assert "1963" in response.content
chunks = [
r
async for r in await inference_impl.completion(
content="Roses are red,",
stream=True,
model_id=inference_model,
sampling_params=SamplingParams(
max_tokens=50,
),
)
]
assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
assert len(chunks) >= 1
last = chunks[-1]
assert last.stop_reason == StopReason.out_of_tokens
@pytest.mark.asyncio(loop_scope="session")
async def test_completion_logprobs(self, inference_model, inference_stack):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
# "remote::nvidia", -- provider doesn't provide all logprobs
):
pytest.skip("Other inference providers don't support completion() yet")
response = await inference_impl.completion(
content="Micheael Jordan is born in ",
stream=False,
model_id=inference_model,
sampling_params=SamplingParams(
max_tokens=5,
),
logprobs=LogProbConfig(
top_k=3,
),
)
assert isinstance(response, CompletionResponse)
assert 1 <= len(response.logprobs) <= 5
assert response.logprobs, "Logprobs should not be empty"
assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
chunks = [
r
async for r in await inference_impl.completion(
content="Roses are red,",
stream=True,
model_id=inference_model,
sampling_params=SamplingParams(
max_tokens=5,
),
logprobs=LogProbConfig(
top_k=3,
),
)
]
assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
assert (
1 <= len(chunks) <= 6
) # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
for chunk in chunks:
if (
chunk.delta.type == "text" and chunk.delta.text
): # if there's a token, we expect logprobs
assert chunk.logprobs, "Logprobs should not be empty"
assert all(
len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
)
else: # no token, no logprobs
assert not chunk.logprobs, "Logprobs should be empty"
@pytest.mark.asyncio(loop_scope="session")
async def test_completion_structured_output(self, inference_model, inference_stack):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
"inline::meta-reference",
"remote::ollama",
"remote::tgi",
"remote::together",
"remote::fireworks",
"remote::nvidia",
"remote::vllm",
"remote::cerebras",
):
pytest.skip(
"Other inference providers don't support structured output in completions yet"
)
class Output(BaseModel):
name: str
year_born: str
year_retired: str
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
response = await inference_impl.completion(
model_id=inference_model,
content=user_input,
stream=False,
sampling_params=SamplingParams(
max_tokens=50,
),
response_format=JsonSchemaResponseFormat(
json_schema=Output.model_json_schema(),
),
)
assert isinstance(response, CompletionResponse)
assert isinstance(response.content, str)
answer = Output.model_validate_json(response.content)
assert answer.name == "Michael Jordan"
assert answer.year_born == "1963"
assert answer.year_retired == "2003"
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_non_streaming(
self, inference_model, inference_stack, common_params, sample_messages
):
inference_impl, _ = inference_stack
response = await inference_impl.chat_completion(
model_id=inference_model,
messages=sample_messages,
stream=False,
**common_params,
)
assert isinstance(response, ChatCompletionResponse)
assert response.completion_message.role == "assistant"
assert isinstance(response.completion_message.content, str)
assert len(response.completion_message.content) > 0
@pytest.mark.asyncio(loop_scope="session")
async def test_structured_output(
self, inference_model, inference_stack, common_params
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
"inline::meta-reference",
"remote::ollama",
"remote::fireworks",
"remote::tgi",
"remote::together",
"remote::vllm",
"remote::nvidia",
):
pytest.skip("Other inference providers don't support structured output yet")
class AnswerFormat(BaseModel):
first_name: str
last_name: str
year_of_birth: int
num_seasons_in_nba: int
response = await inference_impl.chat_completion(
model_id=inference_model,
messages=[
# we include context about Michael Jordan in the prompt so that the test is
# focused on the funtionality of the model and not on the information embedded
# in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.
SystemMessage(
content=(
"You are a helpful assistant.\n\n"
"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
)
),
UserMessage(content="Please give me information about Michael Jordan."),
],
stream=False,
response_format=JsonSchemaResponseFormat(
json_schema=AnswerFormat.model_json_schema(),
),
**common_params,
)
assert isinstance(response, ChatCompletionResponse)
assert response.completion_message.role == "assistant"
assert isinstance(response.completion_message.content, str)
answer = AnswerFormat.model_validate_json(response.completion_message.content)
assert answer.first_name == "Michael"
assert answer.last_name == "Jordan"
assert answer.year_of_birth == 1963
assert answer.num_seasons_in_nba == 15
response = await inference_impl.chat_completion(
model_id=inference_model,
messages=[
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Please give me information about Michael Jordan."),
],
stream=False,
**common_params,
)
assert isinstance(response, ChatCompletionResponse)
assert isinstance(response.completion_message.content, str)
with pytest.raises(ValidationError):
AnswerFormat.model_validate_json(response.completion_message.content)
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_streaming(
self, inference_model, inference_stack, common_params, sample_messages
):
inference_impl, _ = inference_stack
response = [
r
async for r in await inference_impl.chat_completion(
model_id=inference_model,
messages=sample_messages,
stream=True,
**common_params,
)
]
assert len(response) > 0
assert all(
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
)
grouped = group_chunks(response)
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
end = grouped[ChatCompletionResponseEventType.complete][0]
assert end.event.stop_reason == StopReason.end_of_turn
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_with_tool_calling(
self,
inference_model,
inference_stack,
common_params,
sample_messages,
sample_tool_definition,
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if (
provider.__provider_spec__.provider_type == "remote::groq"
and "Llama-3.2" in inference_model
):
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
)
]
response = await inference_impl.chat_completion(
model_id=inference_model,
messages=messages,
tools=[sample_tool_definition],
stream=False,
**common_params,
)
assert isinstance(response, ChatCompletionResponse)
message = response.completion_message
# This is not supported in most providers :/ they don't return eom_id / eot_id
# stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
# assert message.stop_reason == stop_reason
assert message.tool_calls is not None
assert len(message.tool_calls) > 0
call = message.tool_calls[0]
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_with_tool_calling_streaming(
self,
inference_model,
inference_stack,
common_params,
sample_messages,
sample_tool_definition,
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if (
provider.__provider_spec__.provider_type == "remote::groq"
and "Llama-3.2" in inference_model
):
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
)
]
response = [
r
async for r in await inference_impl.chat_completion(
model_id=inference_model,
messages=messages,
tools=[sample_tool_definition],
stream=True,
**common_params,
)
]
assert len(response) > 0
assert all(
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
)
grouped = group_chunks(response)
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
# This is not supported in most providers :/ they don't return eom_id / eot_id
# expected_stop_reason = get_expected_stop_reason(
# inference_settings["common_params"]["model"]
# )
# end = grouped[ChatCompletionResponseEventType.complete][0]
# assert end.event.stop_reason == expected_stop_reason
if "Llama3.1" in inference_model:
assert all(
chunk.event.delta.type == "tool_call"
for chunk in grouped[ChatCompletionResponseEventType.progress]
)
first = grouped[ChatCompletionResponseEventType.progress][0]
if not isinstance(
first.event.delta.content, ToolCall
): # first chunk may contain entire call
assert first.event.delta.parse_status == ToolCallParseStatus.started
last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason
assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
assert isinstance(last.event.delta.content, ToolCall)
call = last.event.delta.content
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]