llama-stack-mirror/llama_stack/providers/tests/inference/test_text_inference.py
Aidan Do 5d7b611336
Add JSON structured outputs to Ollama Provider (#680)
# What does this PR do?

Addresses issue #679

- Adds support for the `response_format` field in chat completions and
completions, so users can request their output as JSON
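
As a rough illustration of the idea (not the provider's actual code), a `response_format` of type `json_schema` could be translated into the `format` field that Ollama's API accepts for structured outputs. The helper name `build_ollama_payload` and the dict-shaped `response_format` are assumptions made for this sketch.

```python
from typing import Any, Dict, Optional


def build_ollama_payload(
    model: str,
    prompt: str,
    response_format: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: build a request body for Ollama's generate endpoint."""
    payload: Dict[str, Any] = {"model": model, "prompt": prompt, "stream": False}
    if response_format is None:
        # No structured output requested: send the plain request.
        return payload
    if response_format.get("type") == "json_schema":
        # Ollama's structured-output docs describe passing a JSON schema in `format`.
        payload["format"] = response_format["json_schema"]
    else:
        raise ValueError(
            f"Unsupported response_format type: {response_format.get('type')!r}"
        )
    return payload


schema = {
    "type": "object",
    "properties": {"steps": {"type": "array", "items": {"type": "string"}}},
    "required": ["steps"],
}
payload = build_ollama_payload(
    "llama3.2:3b-instruct-fp16",
    "Plan a web server.",
    {"type": "json_schema", "json_schema": schema},
)
```

Here the schema is forwarded unchanged, so the calling code and the Ollama request share a single schema definition.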

## Test Plan

<details>

<summary>Integration tests</summary>

`pytest llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output -k ollama -s -v`

```text
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-ollama] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_3b-ollama] PASSED

================================== 2 passed, 18 deselected, 3 warnings in 41.41s ==================================
```

</details>

<details>
<summary>Manual Tests</summary>

```
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export OLLAMA_INFERENCE_MODEL=llama3.2:3b-instruct-fp16
export LLAMA_STACK_PORT=5000

ollama run $OLLAMA_INFERENCE_MODEL --keepalive 60m
llama stack build --template ollama --image-type conda
llama stack run ./run.yaml \
  --port $LLAMA_STACK_PORT \
  --env INFERENCE_MODEL=$INFERENCE_MODEL \
  --env OLLAMA_URL=http://localhost:11434
```

```python
import json
import os

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}")

MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
prompt = """
    Create a step by step plan to complete the task of creating a codebase that is a web server that has an API endpoint that translates text from English to French.
    You have 3 different operations you can perform. You can create a file, update a file, or delete a file.
    Limit your step by step plan to only these operations per step.
    Don't create more than 10 steps.

    Please ensure there's a README.md file in the root of the codebase that describes the codebase and how to run it.
    Please ensure there's a requirements.txt file in the root of the codebase that describes the dependencies of the codebase.
    """
response = client.inference.chat_completion(
    model_id=MODEL_ID,
    messages=[
        {"role": "user", "content": prompt},
    ],
    sampling_params={
        "max_tokens": 200000,
    },
    # Ask the provider to constrain the output to this JSON schema.
    response_format={
        "type": "json_schema",
        "json_schema": {
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Plan",
            "description": "A plan to complete the task of creating a codebase that is a web server that has an API endpoint that translates text from English to French.",
            "type": "object",
            "properties": {
                "steps": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                }
            },
            "required": ["steps"],
            "additionalProperties": False,
        }
    },
    stream=True,
)

# Accumulate the streamed deltas into one string before parsing.
content = ""
for chunk in response:
    if chunk.event.delta:
        print(chunk.event.delta, end="", flush=True)
        content += chunk.event.delta

try:
    plan = json.loads(content)
    print(plan)
except Exception as e:
    print(f"Error parsing plan into JSON: {e}")
    plan = {"steps": []}
```

Outputs:

```json
{
    "steps": [
        "Update the requirements.txt file to include the updated dependencies specified in the peer's feedback, including the Google Cloud Translation API key.",
        "Update the app.py file to address the code smells and incorporate the suggested improvements, such as handling errors and exceptions, initializing the Translator object correctly, adding input validation, using type hints and docstrings, and removing unnecessary logging statements.",
        "Create a README.md file that describes the codebase and how to run it.",
        "Ensure the README.md file is up-to-date and accurate.",
        "Update the requirements.txt file to reflect any additional dependencies specified by the peer's feedback.",
        "Add documentation for each function in the app.py file using docstrings.",
        "Implement logging statements throughout the app.py file to monitor application execution.",
        "Test the API endpoint to ensure it correctly translates text from English to French and handles errors properly.",
        "Refactor the code to follow PEP 8 style guidelines and ensure consistency in naming conventions, indentation, and spacing.",
        "Create a new folder for logs and add a logging configuration file (e.g., logconfig.json) that specifies the logging level and output destination.",
        "Deploy the web server on a production environment (e.g., AWS Elastic Beanstalk or Google Cloud Platform) to make it accessible to external users."
    ]
}
```


</details>
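
The manual test above only checks that the streamed text parses as JSON. A slightly stricter check, sketched here with the standard library only, would also confirm the shape the schema asks for: an object whose only key is `steps`, holding a list of strings. The function name `validate_plan` is hypothetical and not part of this PR.

```python
import json
from typing import List


def validate_plan(raw: str) -> List[str]:
    """Hypothetical check mirroring the Plan schema: {"steps": [str, ...]} only."""
    plan = json.loads(raw)  # raises json.JSONDecodeError on malformed output
    if not isinstance(plan, dict):
        raise ValueError("plan must be a JSON object")
    if set(plan) != {"steps"}:
        # Covers both a missing "steps" key and any additional properties.
        raise ValueError(f"unexpected keys: {sorted(plan)}")
    steps = plan["steps"]
    if not isinstance(steps, list) or not all(isinstance(s, str) for s in steps):
        raise ValueError("'steps' must be a list of strings")
    return steps


steps = validate_plan('{"steps": ["Create app.py", "Create README.md"]}')
print(len(steps))  # prints 2
```

A full JSON Schema validator would be the more general choice; this hand-written check simply avoids an extra dependency for a schema this small.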

## Sources

- Ollama API docs:
https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
- Ollama structured output docs:
https://github.com/ollama/ollama/blob/main/docs/api.md#request-structured-outputs

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [x] Wrote necessary unit or integration tests.
2025-01-02 09:05:51 -08:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from llama_models.llama3.api.datatypes import (
    SamplingParams,
    StopReason,
    ToolCall,
    ToolDefinition,
    ToolParamDefinition,
    ToolPromptFormat,
)
from pydantic import BaseModel, ValidationError

from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionResponse,
    CompletionResponseStreamChunk,
    JsonSchemaResponseFormat,
    LogProbConfig,
    SystemMessage,
    ToolCallDelta,
    ToolCallParseStatus,
    ToolChoice,
    UserMessage,
)
from llama_stack.apis.models import Model

from .utils import group_chunks

# How to run this test:
#
# pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py
#   -m "(fireworks or ollama) and llama_3b"
#   --env FIREWORKS_API_KEY=<your_api_key>


def get_expected_stop_reason(model: str):
    return (
        StopReason.end_of_message
        if ("Llama3.1" in model or "Llama-3.1" in model)
        else StopReason.end_of_turn
    )


@pytest.fixture
def common_params(inference_model):
    return {
        "tool_choice": ToolChoice.auto,
        "tool_prompt_format": (
            ToolPromptFormat.json
            if ("Llama3.1" in inference_model or "Llama-3.1" in inference_model)
            else ToolPromptFormat.python_list
        ),
    }


@pytest.fixture
def sample_messages():
    return [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What's the weather like today?"),
    ]


@pytest.fixture
def sample_tool_definition():
    return ToolDefinition(
        tool_name="get_weather",
        description="Get the current weather",
        parameters={
            "location": ToolParamDefinition(
                param_type="string",
                description="The city and state, e.g. San Francisco, CA",
            ),
        },
    )


class TestInference:
    @pytest.mark.asyncio
    async def test_model_list(self, inference_model, inference_stack):
        _, models_impl = inference_stack
        response = await models_impl.list_models()
        assert isinstance(response, list)
        assert len(response) >= 1
        assert all(isinstance(model, Model) for model in response)

        model_def = None
        for model in response:
            if model.identifier == inference_model:
                model_def = model
                break

        assert model_def is not None

    @pytest.mark.asyncio
    async def test_completion(self, inference_model, inference_stack):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::ollama",
            "remote::tgi",
            "remote::together",
            "remote::fireworks",
            "remote::nvidia",
            "remote::cerebras",
        ):
            pytest.skip("Other inference providers don't support completion() yet")

        response = await inference_impl.completion(
            content="Michael Jordan is born in ",
            stream=False,
            model_id=inference_model,
            sampling_params=SamplingParams(
                max_tokens=50,
            ),
        )

        assert isinstance(response, CompletionResponse)
        assert "1963" in response.content

        chunks = [
            r
            async for r in await inference_impl.completion(
                content="Roses are red,",
                stream=True,
                model_id=inference_model,
                sampling_params=SamplingParams(
                    max_tokens=50,
                ),
            )
        ]

        assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
        assert len(chunks) >= 1
        last = chunks[-1]
        assert last.stop_reason == StopReason.out_of_tokens

    @pytest.mark.asyncio
    async def test_completion_logprobs(self, inference_model, inference_stack):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            # "remote::nvidia",  -- provider doesn't provide all logprobs
        ):
            pytest.skip("Other inference providers don't support completion() yet")

        response = await inference_impl.completion(
            content="Michael Jordan is born in ",
            stream=False,
            model_id=inference_model,
            sampling_params=SamplingParams(
                max_tokens=5,
            ),
            logprobs=LogProbConfig(
                top_k=3,
            ),
        )

        assert isinstance(response, CompletionResponse)
        assert 1 <= len(response.logprobs) <= 5
        assert response.logprobs, "Logprobs should not be empty"
        assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)

        chunks = [
            r
            async for r in await inference_impl.completion(
                content="Roses are red,",
                stream=True,
                model_id=inference_model,
                sampling_params=SamplingParams(
                    max_tokens=5,
                ),
                logprobs=LogProbConfig(
                    top_k=3,
                ),
            )
        ]

        assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
        assert (
            1 <= len(chunks) <= 6
        )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
        for chunk in chunks:
            if chunk.delta:  # if there's a token, we expect logprobs
                assert chunk.logprobs, "Logprobs should not be empty"
                assert all(
                    len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
                )
            else:  # no token, no logprobs
                assert not chunk.logprobs, "Logprobs should be empty"

    @pytest.mark.asyncio
    @pytest.mark.skip("This test is not quite robust")
    async def test_completion_structured_output(self, inference_model, inference_stack):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::ollama",
            "remote::tgi",
            "remote::together",
            "remote::fireworks",
            "remote::nvidia",
            "remote::vllm",
            "remote::cerebras",
        ):
            pytest.skip(
                "Other inference providers don't support structured output in completions yet"
            )

        class Output(BaseModel):
            name: str
            year_born: str
            year_retired: str

        user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
        response = await inference_impl.completion(
            model_id=inference_model,
            content=user_input,
            stream=False,
            sampling_params=SamplingParams(
                max_tokens=50,
            ),
            response_format=JsonSchemaResponseFormat(
                json_schema=Output.model_json_schema(),
            ),
        )
        assert isinstance(response, CompletionResponse)
        assert isinstance(response.content, str)

        answer = Output.model_validate_json(response.content)
        assert answer.name == "Michael Jordan"
        assert answer.year_born == "1963"
        assert answer.year_retired == "2003"

    @pytest.mark.asyncio
    async def test_chat_completion_non_streaming(
        self, inference_model, inference_stack, common_params, sample_messages
    ):
        inference_impl, _ = inference_stack
        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=sample_messages,
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert response.completion_message.role == "assistant"
        assert isinstance(response.completion_message.content, str)
        assert len(response.completion_message.content) > 0

    @pytest.mark.asyncio
    async def test_structured_output(
        self, inference_model, inference_stack, common_params
    ):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::ollama",
            "remote::fireworks",
            "remote::tgi",
            "remote::together",
            "remote::vllm",
            "remote::nvidia",
        ):
            pytest.skip("Other inference providers don't support structured output yet")

        class AnswerFormat(BaseModel):
            first_name: str
            last_name: str
            year_of_birth: int
            num_seasons_in_nba: int

        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=[
                # we include context about Michael Jordan in the prompt so that the test is
                # focused on the functionality of the model and not on the information embedded
                # in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.
                SystemMessage(
                    content=(
                        "You are a helpful assistant.\n\n"
                        "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
                    )
                ),
                UserMessage(content="Please give me information about Michael Jordan."),
            ],
            stream=False,
            response_format=JsonSchemaResponseFormat(
                json_schema=AnswerFormat.model_json_schema(),
            ),
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert response.completion_message.role == "assistant"
        assert isinstance(response.completion_message.content, str)

        answer = AnswerFormat.model_validate_json(response.completion_message.content)
        assert answer.first_name == "Michael"
        assert answer.last_name == "Jordan"
        assert answer.year_of_birth == 1963
        assert answer.num_seasons_in_nba == 15

        # Without response_format, the reply is expected to be free-form text
        # that does not validate against the schema.
        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=[
                SystemMessage(content="You are a helpful assistant."),
                UserMessage(content="Please give me information about Michael Jordan."),
            ],
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert isinstance(response.completion_message.content, str)

        with pytest.raises(ValidationError):
            AnswerFormat.model_validate_json(response.completion_message.content)

    @pytest.mark.asyncio
    async def test_chat_completion_streaming(
        self, inference_model, inference_stack, common_params, sample_messages
    ):
        inference_impl, _ = inference_stack
        response = [
            r
            async for r in await inference_impl.chat_completion(
                model_id=inference_model,
                messages=sample_messages,
                stream=True,
                **common_params,
            )
        ]

        assert len(response) > 0
        assert all(
            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
        )
        grouped = group_chunks(response)
        assert len(grouped[ChatCompletionResponseEventType.start]) == 1
        assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
        assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

        end = grouped[ChatCompletionResponseEventType.complete][0]
        assert end.event.stop_reason == StopReason.end_of_turn

    @pytest.mark.asyncio
    async def test_chat_completion_with_tool_calling(
        self,
        inference_model,
        inference_stack,
        common_params,
        sample_messages,
        sample_tool_definition,
    ):
        inference_impl, _ = inference_stack
        messages = sample_messages + [
            UserMessage(
                content="What's the weather like in San Francisco?",
            )
        ]

        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=messages,
            tools=[sample_tool_definition],
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)

        message = response.completion_message

        # This is not supported in most providers :/ they don't return eom_id / eot_id
        # stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
        # assert message.stop_reason == stop_reason
        assert message.tool_calls is not None
        assert len(message.tool_calls) > 0

        call = message.tool_calls[0]
        assert call.tool_name == "get_weather"
        assert "location" in call.arguments
        assert "San Francisco" in call.arguments["location"]

    @pytest.mark.asyncio
    async def test_chat_completion_with_tool_calling_streaming(
        self,
        inference_model,
        inference_stack,
        common_params,
        sample_messages,
        sample_tool_definition,
    ):
        inference_impl, _ = inference_stack
        messages = sample_messages + [
            UserMessage(
                content="What's the weather like in San Francisco?",
            )
        ]

        response = [
            r
            async for r in await inference_impl.chat_completion(
                model_id=inference_model,
                messages=messages,
                tools=[sample_tool_definition],
                stream=True,
                **common_params,
            )
        ]

        assert len(response) > 0
        assert all(
            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
        )
        grouped = group_chunks(response)
        assert len(grouped[ChatCompletionResponseEventType.start]) == 1
        assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
        assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

        # This is not supported in most providers :/ they don't return eom_id / eot_id
        # expected_stop_reason = get_expected_stop_reason(
        #     inference_settings["common_params"]["model"]
        # )
        # end = grouped[ChatCompletionResponseEventType.complete][0]
        # assert end.event.stop_reason == expected_stop_reason

        if "Llama3.1" in inference_model:
            assert all(
                isinstance(chunk.event.delta, ToolCallDelta)
                for chunk in grouped[ChatCompletionResponseEventType.progress]
            )
            first = grouped[ChatCompletionResponseEventType.progress][0]
            if not isinstance(
                first.event.delta.content, ToolCall
            ):  # first chunk may contain entire call
                assert first.event.delta.parse_status == ToolCallParseStatus.started

        last = grouped[ChatCompletionResponseEventType.progress][-1]
        # assert last.event.stop_reason == expected_stop_reason
        assert last.event.delta.parse_status == ToolCallParseStatus.success
        assert isinstance(last.event.delta.content, ToolCall)

        call = last.event.delta.content
        assert call.tool_name == "get_weather"
        assert "location" in call.arguments
        assert "San Francisco" in call.arguments["location"]