# What does this PR do?

Addresses issue #679 by adding support for the `response_format` field for chat completions and completions, so users can get their outputs in JSON.

## Test Plan

<details>
<summary>Integration tests</summary>

`pytest llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output -k ollama -s -v`

```
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-ollama] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_3b-ollama] PASSED

================================== 2 passed, 18 deselected, 3 warnings in 41.41s ==================================
```

</details>

<details>
<summary>Manual Tests</summary>

```
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export OLLAMA_INFERENCE_MODEL=llama3.2:3b-instruct-fp16
export LLAMA_STACK_PORT=5000

ollama run $OLLAMA_INFERENCE_MODEL --keepalive 60m

llama stack build --template ollama --image-type conda
llama stack run ./run.yaml \
  --port $LLAMA_STACK_PORT \
  --env INFERENCE_MODEL=$INFERENCE_MODEL \
  --env OLLAMA_URL=http://localhost:11434
```

```python
import json
import os

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}")

MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
prompt = """
Create a step by step plan to complete the task of creating a codebase that is a web server that has an API endpoint that translates text from English to French.

You have 3 different operations you can perform. You can create a file, update a file, or delete a file. Limit your step by step plan to only these operations per step.
Don't create more than 10 steps.

Please ensure there's a README.md file in the root of the codebase that describes the codebase and how to run it.
Please ensure there's a requirements.txt file in the root of the codebase that describes the dependencies of the codebase.
"""

response = client.inference.chat_completion(
    model_id=MODEL_ID,
    messages=[
        {"role": "user", "content": prompt},
    ],
    sampling_params={
        "max_tokens": 200000,
    },
    response_format={
        "type": "json_schema",
        "json_schema": {
            "$schema": "http://json-schema.org/draft-07/schema#",
            "title": "Plan",
            "description": "A plan to complete the task of creating a codebase that is a web server that has an API endpoint that translates text from English to French.",
            "type": "object",
            "properties": {
                "steps": {
                    "type": "array",
                    "items": {"type": "string"},
                }
            },
            "required": ["steps"],
            "additionalProperties": False,
        },
    },
    stream=True,
)

content = ""
for chunk in response:
    if chunk.event.delta:
        print(chunk.event.delta, end="", flush=True)
        content += chunk.event.delta

try:
    plan = json.loads(content)
    print(plan)
except Exception as e:
    print(f"Error parsing plan into JSON: {e}")
    plan = {"steps": []}
```

Outputs:

```json
{
  "steps": [
    "Update the requirements.txt file to include the updated dependencies specified in the peer's feedback, including the Google Cloud Translation API key.",
    "Update the app.py file to address the code smells and incorporate the suggested improvements, such as handling errors and exceptions, initializing the Translator object correctly, adding input validation, using type hints and docstrings, and removing unnecessary logging statements.",
    "Create a README.md file that describes the codebase and how to run it.",
    "Ensure the README.md file is up-to-date and accurate.",
    "Update the requirements.txt file to reflect any additional dependencies specified by the peer's feedback.",
    "Add documentation for each function in the app.py file using docstrings.",
    "Implement logging statements throughout the app.py file to monitor application execution.",
    "Test the API endpoint to ensure it correctly translates text from English to French and handles errors properly.",
    "Refactor the code to follow PEP 8 style guidelines and ensure consistency in naming conventions, indentation, and spacing.",
    "Create a new folder for logs and add a logging configuration file (e.g., logconfig.json) that specifies the logging level and output destination.",
    "Deploy the web server on a production environment (e.g., AWS Elastic Beanstalk or Google Cloud Platform) to make it accessible to external users."
  ]
}
```

</details>

## Sources

- Ollama api docs: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
- Ollama structured output docs: https://github.com/ollama/ollama/blob/main/docs/api.md#request-structured-outputs

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [x] Wrote necessary unit or integration tests.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from llama_models.llama3.api.datatypes import (
    SamplingParams,
    StopReason,
    ToolCall,
    ToolDefinition,
    ToolParamDefinition,
    ToolPromptFormat,
)
from pydantic import BaseModel, ValidationError

from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionResponse,
    CompletionResponseStreamChunk,
    JsonSchemaResponseFormat,
    LogProbConfig,
    SystemMessage,
    ToolCallDelta,
    ToolCallParseStatus,
    ToolChoice,
    UserMessage,
)
from llama_stack.apis.models import Model

from .utils import group_chunks

# How to run this test:
#
# pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py
#   -m "(fireworks or ollama) and llama_3b"
#   --env FIREWORKS_API_KEY=<your_api_key>
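#
# For example, to run only the structured output test against ollama (matching
# the invocation in the PR test plan):
#
# pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output \
#   -k ollama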


def get_expected_stop_reason(model: str):
    return (
        StopReason.end_of_message
        if ("Llama3.1" in model or "Llama-3.1" in model)
        else StopReason.end_of_turn
    )
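
# For example (hypothetical model identifiers):
#   get_expected_stop_reason("Llama3.1-8B-Instruct")  -> StopReason.end_of_message
#   get_expected_stop_reason("Llama3.2-3B-Instruct")  -> StopReason.end_of_turn
# Llama 3.1 models stop with an end-of-message token when they expect a tool
# response; other models simply end the turn.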


@pytest.fixture
def common_params(inference_model):
    return {
        "tool_choice": ToolChoice.auto,
        "tool_prompt_format": (
            ToolPromptFormat.json
            if ("Llama3.1" in inference_model or "Llama-3.1" in inference_model)
            else ToolPromptFormat.python_list
        ),
    }


@pytest.fixture
def sample_messages():
    return [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What's the weather like today?"),
    ]


@pytest.fixture
def sample_tool_definition():
    return ToolDefinition(
        tool_name="get_weather",
        description="Get the current weather",
        parameters={
            "location": ToolParamDefinition(
                param_type="string",
                description="The city and state, e.g. San Francisco, CA",
            ),
        },
    )
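
# The tool-calling tests below expect the model to answer a weather question by
# emitting a call shaped roughly like this (illustrative values only):
#   ToolCall(call_id="...", tool_name="get_weather",
#            arguments={"location": "San Francisco, CA"})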


class TestInference:
    @pytest.mark.asyncio
    async def test_model_list(self, inference_model, inference_stack):
        _, models_impl = inference_stack
        response = await models_impl.list_models()
        assert isinstance(response, list)
        assert len(response) >= 1
        assert all(isinstance(model, Model) for model in response)

        model_def = None
        for model in response:
            if model.identifier == inference_model:
                model_def = model
                break

        assert model_def is not None

    @pytest.mark.asyncio
    async def test_completion(self, inference_model, inference_stack):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::ollama",
            "remote::tgi",
            "remote::together",
            "remote::fireworks",
            "remote::nvidia",
            "remote::cerebras",
        ):
            pytest.skip("Other inference providers don't support completion() yet")

        response = await inference_impl.completion(
            content="Michael Jordan was born in ",
            stream=False,
            model_id=inference_model,
            sampling_params=SamplingParams(
                max_tokens=50,
            ),
        )

        assert isinstance(response, CompletionResponse)
        assert "1963" in response.content

        chunks = [
            r
            async for r in await inference_impl.completion(
                content="Roses are red,",
                stream=True,
                model_id=inference_model,
                sampling_params=SamplingParams(
                    max_tokens=50,
                ),
            )
        ]

        assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
        assert len(chunks) >= 1
        last = chunks[-1]
        assert last.stop_reason == StopReason.out_of_tokens

    @pytest.mark.asyncio
    async def test_completion_logprobs(self, inference_model, inference_stack):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        # NOTE: the allow-list below is currently empty (its single entry is
        # commented out), so this test is skipped for every provider.
        if provider.__provider_spec__.provider_type not in (
            # "remote::nvidia",  -- provider doesn't provide all logprobs
        ):
            pytest.skip("Other inference providers don't support completion() yet")

        response = await inference_impl.completion(
            content="Michael Jordan was born in ",
            stream=False,
            model_id=inference_model,
            sampling_params=SamplingParams(
                max_tokens=5,
            ),
            logprobs=LogProbConfig(
                top_k=3,
            ),
        )

        assert isinstance(response, CompletionResponse)
        assert 1 <= len(response.logprobs) <= 5
        assert response.logprobs, "Logprobs should not be empty"
        assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)

        chunks = [
            r
            async for r in await inference_impl.completion(
                content="Roses are red,",
                stream=True,
                model_id=inference_model,
                sampling_params=SamplingParams(
                    max_tokens=5,
                ),
                logprobs=LogProbConfig(
                    top_k=3,
                ),
            )
        ]

        assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
        assert (
            1 <= len(chunks) <= 6
        )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
        for chunk in chunks:
            if chunk.delta:  # if there's a token, we expect logprobs
                assert chunk.logprobs, "Logprobs should not be empty"
                assert all(
                    len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
                )
            else:  # no token, no logprobs
                assert not chunk.logprobs, "Logprobs should be empty"
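
    # With LogProbConfig(top_k=3) above, each returned logprob entry carries a
    # logprobs_by_token mapping of the 3 most likely tokens to their
    # log-probabilities, roughly:
    #   {"19": -0.01, " 19": -4.2, "nineteen": -6.8}  # illustrative values only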

    @pytest.mark.asyncio
    @pytest.mark.skip("This test is not quite robust")
    async def test_completion_structured_output(self, inference_model, inference_stack):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::ollama",
            "remote::tgi",
            "remote::together",
            "remote::fireworks",
            "remote::nvidia",
            "remote::vllm",
            "remote::cerebras",
        ):
            pytest.skip(
                "Other inference providers don't support structured output in completions yet"
            )

        class Output(BaseModel):
            name: str
            year_born: str
            year_retired: str

        user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
        response = await inference_impl.completion(
            model_id=inference_model,
            content=user_input,
            stream=False,
            sampling_params=SamplingParams(
                max_tokens=50,
            ),
            response_format=JsonSchemaResponseFormat(
                json_schema=Output.model_json_schema(),
            ),
        )

        assert isinstance(response, CompletionResponse)
        assert isinstance(response.content, str)

        answer = Output.model_validate_json(response.content)
        assert answer.name == "Michael Jordan"
        assert answer.year_born == "1963"
        assert answer.year_retired == "2003"
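
    # Note: Output.model_json_schema() expands to a plain JSON Schema dict,
    # roughly:
    #   {"title": "Output", "type": "object",
    #    "properties": {"name": {"title": "Name", "type": "string"}, ...},
    #    "required": ["name", "year_born", "year_retired"]}
    # which is what JsonSchemaResponseFormat forwards to the provider.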

    @pytest.mark.asyncio
    async def test_chat_completion_non_streaming(
        self, inference_model, inference_stack, common_params, sample_messages
    ):
        inference_impl, _ = inference_stack
        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=sample_messages,
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert response.completion_message.role == "assistant"
        assert isinstance(response.completion_message.content, str)
        assert len(response.completion_message.content) > 0

    @pytest.mark.asyncio
    async def test_structured_output(
        self, inference_model, inference_stack, common_params
    ):
        inference_impl, _ = inference_stack

        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::ollama",
            "remote::fireworks",
            "remote::tgi",
            "remote::together",
            "remote::vllm",
            "remote::nvidia",
        ):
            pytest.skip("Other inference providers don't support structured output yet")

        class AnswerFormat(BaseModel):
            first_name: str
            last_name: str
            year_of_birth: int
            num_seasons_in_nba: int

        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=[
                # We include context about Michael Jordan in the prompt so that the test is
                # focused on the functionality of the model and not on the information embedded
                # in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.
                SystemMessage(
                    content=(
                        "You are a helpful assistant.\n\n"
                        "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
                    )
                ),
                UserMessage(content="Please give me information about Michael Jordan."),
            ],
            stream=False,
            response_format=JsonSchemaResponseFormat(
                json_schema=AnswerFormat.model_json_schema(),
            ),
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert response.completion_message.role == "assistant"
        assert isinstance(response.completion_message.content, str)

        answer = AnswerFormat.model_validate_json(response.completion_message.content)
        assert answer.first_name == "Michael"
        assert answer.last_name == "Jordan"
        assert answer.year_of_birth == 1963
        assert answer.num_seasons_in_nba == 15

        # Without a response_format, the model should answer in free-form prose,
        # which must not validate against the schema.
        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=[
                SystemMessage(content="You are a helpful assistant."),
                UserMessage(content="Please give me information about Michael Jordan."),
            ],
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert isinstance(response.completion_message.content, str)

        with pytest.raises(ValidationError):
            AnswerFormat.model_validate_json(response.completion_message.content)

    @pytest.mark.asyncio
    async def test_chat_completion_streaming(
        self, inference_model, inference_stack, common_params, sample_messages
    ):
        inference_impl, _ = inference_stack
        response = [
            r
            async for r in await inference_impl.chat_completion(
                model_id=inference_model,
                messages=sample_messages,
                stream=True,
                **common_params,
            )
        ]

        assert len(response) > 0
        assert all(
            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
        )
        grouped = group_chunks(response)
        assert len(grouped[ChatCompletionResponseEventType.start]) == 1
        assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
        assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

        end = grouped[ChatCompletionResponseEventType.complete][0]
        assert end.event.stop_reason == StopReason.end_of_turn
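
    # group_chunks buckets the stream by event type: a well-formed stream has
    # exactly one `start` chunk, one or more `progress` chunks (one per token
    # or tool-call fragment), and one `complete` chunk carrying the stop_reason.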

    @pytest.mark.asyncio
    async def test_chat_completion_with_tool_calling(
        self,
        inference_model,
        inference_stack,
        common_params,
        sample_messages,
        sample_tool_definition,
    ):
        inference_impl, _ = inference_stack
        messages = sample_messages + [
            UserMessage(
                content="What's the weather like in San Francisco?",
            )
        ]

        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=messages,
            tools=[sample_tool_definition],
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)

        message = response.completion_message

        # This is not supported in most providers :/ they don't return eom_id / eot_id
        # stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
        # assert message.stop_reason == stop_reason
        assert message.tool_calls is not None
        assert len(message.tool_calls) > 0

        call = message.tool_calls[0]
        assert call.tool_name == "get_weather"
        assert "location" in call.arguments
        assert "San Francisco" in call.arguments["location"]

    @pytest.mark.asyncio
    async def test_chat_completion_with_tool_calling_streaming(
        self,
        inference_model,
        inference_stack,
        common_params,
        sample_messages,
        sample_tool_definition,
    ):
        inference_impl, _ = inference_stack
        messages = sample_messages + [
            UserMessage(
                content="What's the weather like in San Francisco?",
            )
        ]

        response = [
            r
            async for r in await inference_impl.chat_completion(
                model_id=inference_model,
                messages=messages,
                tools=[sample_tool_definition],
                stream=True,
                **common_params,
            )
        ]

        assert len(response) > 0
        assert all(
            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
        )
        grouped = group_chunks(response)
        assert len(grouped[ChatCompletionResponseEventType.start]) == 1
        assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
        assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

        # This is not supported in most providers :/ they don't return eom_id / eot_id
        # expected_stop_reason = get_expected_stop_reason(
        #     inference_settings["common_params"]["model"]
        # )
        # end = grouped[ChatCompletionResponseEventType.complete][0]
        # assert end.event.stop_reason == expected_stop_reason

        if "Llama3.1" in inference_model:
            assert all(
                isinstance(chunk.event.delta, ToolCallDelta)
                for chunk in grouped[ChatCompletionResponseEventType.progress]
            )
            first = grouped[ChatCompletionResponseEventType.progress][0]
            if not isinstance(
                first.event.delta.content, ToolCall
            ):  # first chunk may contain entire call
                assert first.event.delta.parse_status == ToolCallParseStatus.started

        last = grouped[ChatCompletionResponseEventType.progress][-1]
        # assert last.event.stop_reason == expected_stop_reason
        assert last.event.delta.parse_status == ToolCallParseStatus.success
        assert isinstance(last.event.delta.content, ToolCall)

        call = last.event.delta.content
        assert call.tool_name == "get_weather"
        assert "location" in call.arguments
        assert "San Francisco" in call.arguments["location"]