mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
This is a sweeping change to clean up some gunk around our "Tool" definitions. First, we had two types `Tool` and `ToolDef`. The first of these was a "Resource" type for the registry but we had stopped registering tools inside the Registry long back (and only registered ToolGroups.) The latter was for specifying tools for the Agents API. This PR removes the former and adds an optional `toolgroup_id` field to the latter. Secondly, as pointed out by @bbrowning in https://github.com/llamastack/llama-stack/pull/3003#issuecomment-3245270132, we were doing a lossy conversion from a full JSON schema from the MCP tool specification into our ToolDefinition to send it to the model. There is no necessity to do this -- we ourselves aren't doing any execution at all but merely passing it to the chat completions API which supports this. By doing this (and by doing it poorly), we encountered limitations like not supporting array items, or not resolving $refs, etc. To fix this, we replaced the `parameters` field by `{ input_schema, output_schema }` which can be full blown JSON schemas. Finally, there were some types in our llama-related chat format conversion which needed some cleanup. We are taking this opportunity to clean those up. This PR is a substantial breaking change to the API. However, given our window for introducing breaking changes, this suits us just fine. I will be landing a concurrent `llama-stack-client` change as well since API shapes are changing.
303 lines · 10 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
|
|
from llama_stack.apis.inference import (
|
|
ChatCompletionRequest,
|
|
CompletionMessage,
|
|
StopReason,
|
|
SystemMessage,
|
|
SystemMessageBehavior,
|
|
ToolCall,
|
|
ToolConfig,
|
|
UserMessage,
|
|
)
|
|
from llama_stack.models.llama.datatypes import (
|
|
BuiltinTool,
|
|
ToolDefinition,
|
|
ToolPromptFormat,
|
|
)
|
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
|
chat_completion_request_to_messages,
|
|
chat_completion_request_to_prompt,
|
|
interleaved_content_as_str,
|
|
)
|
|
|
|
# Model identifiers exercised by the prompt-encoding tests below.
MODEL = "Llama3.1-8B-Instruct"  # Llama 3.1 family model id
MODEL3_2 = "Llama3.2-3B-Instruct"  # Llama 3.2 family model id
|
|
|
|
|
|
async def test_system_default():
    """With no tools, the adapter prepends a single default system message."""
    user_text = "Hello !"
    req = ChatCompletionRequest(
        model=MODEL,
        messages=[UserMessage(content=user_text)],
    )

    rendered = chat_completion_request_to_messages(req, MODEL)

    # One synthesized system message followed by the original user turn.
    assert len(rendered) == 2
    assert rendered[-1].content == user_text
    # The default system prompt carries the model's knowledge-cutoff banner.
    assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(rendered[0].content)
|
|
|
|
|
|
async def test_system_builtin_only():
    """Builtin-only tools are advertised inside the single system message."""
    user_text = "Hello !"
    req = ChatCompletionRequest(
        model=MODEL,
        messages=[UserMessage(content=user_text)],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            ToolDefinition(tool_name=BuiltinTool.brave_search),
        ],
    )

    rendered = chat_completion_request_to_messages(req, MODEL)

    # Builtin tools do not add an extra message; they live in the system prompt.
    assert len(rendered) == 2
    assert rendered[-1].content == user_text
    sys_text = interleaved_content_as_str(rendered[0].content)
    assert "Cutting Knowledge Date: December 2023" in sys_text
    assert "Tools: brave_search" in sys_text
|
|
|
|
|
|
async def test_system_custom_only():
    """A custom tool in JSON format yields a second, tool-describing system message."""
    user_text = "Hello !"
    # NOTE(review): "type": "str" is not a standard JSON-schema type ("string" is);
    # kept as-is since the test pins how the adapter renders whatever it is given.
    custom_tool = ToolDefinition(
        tool_name="custom1",
        description="custom1 tool",
        input_schema={
            "type": "object",
            "properties": {
                "param1": {
                    "type": "str",
                    "description": "param1 description",
                },
            },
            "required": ["param1"],
        },
    )
    req = ChatCompletionRequest(
        model=MODEL,
        messages=[UserMessage(content=user_text)],
        tools=[custom_tool],
        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
    )

    rendered = chat_completion_request_to_messages(req, MODEL)

    # System message, tool-definition message, then the user turn.
    assert len(rendered) == 3
    assert "Environment: ipython" in interleaved_content_as_str(rendered[0].content)
    assert "Return function calls in JSON format" in interleaved_content_as_str(rendered[1].content)
    assert rendered[-1].content == user_text
|
|
|
|
|
|
async def test_system_custom_and_builtin():
    """Mixing builtin and custom tools produces both the builtin banner and a tool message."""
    user_text = "Hello !"
    custom_tool = ToolDefinition(
        tool_name="custom1",
        description="custom1 tool",
        input_schema={
            "type": "object",
            "properties": {
                "param1": {
                    "type": "str",
                    "description": "param1 description",
                },
            },
            "required": ["param1"],
        },
    )
    req = ChatCompletionRequest(
        model=MODEL,
        messages=[UserMessage(content=user_text)],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            ToolDefinition(tool_name=BuiltinTool.brave_search),
            custom_tool,
        ],
    )

    rendered = chat_completion_request_to_messages(req, MODEL)
    assert len(rendered) == 3

    # Builtins are announced in the first system message...
    sys_text = interleaved_content_as_str(rendered[0].content)
    assert "Environment: ipython" in sys_text
    assert "Tools: brave_search" in sys_text

    # ...while the custom tool gets its own definition message.
    assert "Return function calls in JSON format" in interleaved_content_as_str(rendered[1].content)
    assert rendered[-1].content == user_text
|
|
|
|
|
|
async def test_completion_message_encoding():
    """Tool calls in an assistant turn are re-encoded per the active tool prompt format."""
    schema = {
        "type": "object",
        "properties": {
            "param1": {
                "type": "str",
                "description": "param1 description",
            },
        },
        "required": ["param1"],
    }
    req = ChatCompletionRequest(
        model=MODEL3_2,
        messages=[
            UserMessage(content="hello"),
            CompletionMessage(
                content="",
                stop_reason=StopReason.end_of_turn,
                tool_calls=[
                    ToolCall(
                        tool_name="custom1",
                        arguments='{"param1": "value1"}',  # arguments must be a JSON string
                        call_id="123",
                    )
                ],
            ),
        ],
        tools=[
            ToolDefinition(
                tool_name="custom1",
                description="custom1 tool",
                input_schema=schema,
            ),
        ],
        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
    )

    # Llama 3.2 + python_list format renders the call as a Python call expression.
    prompt = await chat_completion_request_to_prompt(req, req.model)
    assert '[custom1(param1="value1")]' in prompt

    # Re-render the very same request for Llama 3.1 with the JSON tool format.
    req.model = MODEL
    req.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
    prompt = await chat_completion_request_to_prompt(req, req.model)
    assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
|
|
|
|
|
|
async def test_user_provided_system_message():
    """A caller-supplied system message is appended to (not replaced by) the default prompt."""
    user_text = "Hello !"
    custom_sys = "You are a pirate"
    req = ChatCompletionRequest(
        model=MODEL,
        messages=[
            SystemMessage(content=custom_sys),
            UserMessage(content=user_text),
        ],
        tools=[ToolDefinition(tool_name=BuiltinTool.code_interpreter)],
    )

    rendered = chat_completion_request_to_messages(req, MODEL)
    assert len(rendered) == 2

    # Default behavior: the user's system prompt trails the generated one.
    assert interleaved_content_as_str(rendered[0].content).endswith(custom_sys)
    assert rendered[-1].content == user_text
|
|
|
|
|
|
async def test_replace_system_message_behavior_builtin_tools():
    """With SystemMessageBehavior.replace, the user's system prompt still keeps builtin banners."""
    user_text = "Hello !"
    custom_sys = "You are a pirate"
    req = ChatCompletionRequest(
        model=MODEL,
        messages=[
            SystemMessage(content=custom_sys),
            UserMessage(content=user_text),
        ],
        tools=[ToolDefinition(tool_name=BuiltinTool.code_interpreter)],
        tool_config=ToolConfig(
            tool_choice="auto",
            tool_prompt_format=ToolPromptFormat.python_list,
            system_message_behavior=SystemMessageBehavior.replace,
        ),
    )

    # Render against the 3.2 model family.
    rendered = chat_completion_request_to_messages(req, MODEL3_2)
    assert len(rendered) == 2

    sys_text = interleaved_content_as_str(rendered[0].content)
    assert sys_text.endswith(custom_sys)
    # Builtin environment line survives the replacement.
    assert "Environment: ipython" in sys_text
    assert rendered[-1].content == user_text
|
|
|
|
|
|
async def test_replace_system_message_behavior_custom_tools():
    """replace behavior folds custom-tool handling into the caller's single system message."""
    user_text = "Hello !"
    custom_sys = "You are a pirate"
    custom_tool = ToolDefinition(
        tool_name="custom1",
        description="custom1 tool",
        input_schema={
            "type": "object",
            "properties": {
                "param1": {
                    "type": "str",
                    "description": "param1 description",
                },
            },
            "required": ["param1"],
        },
    )
    req = ChatCompletionRequest(
        model=MODEL,
        messages=[
            SystemMessage(content=custom_sys),
            UserMessage(content=user_text),
        ],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            custom_tool,
        ],
        tool_config=ToolConfig(
            tool_choice="auto",
            tool_prompt_format=ToolPromptFormat.python_list,
            system_message_behavior=SystemMessageBehavior.replace,
        ),
    )

    rendered = chat_completion_request_to_messages(req, MODEL3_2)

    # No extra tool-definition message is emitted under replace behavior.
    assert len(rendered) == 2
    sys_text = interleaved_content_as_str(rendered[0].content)
    assert sys_text.endswith(custom_sys)
    assert "Environment: ipython" in sys_text
    assert rendered[-1].content == user_text
|
|
|
|
|
|
async def test_replace_system_message_behavior_custom_tools_with_template():
    """A {{ function_description }} placeholder in the system prompt is expanded with tool defs."""
    user_text = "Hello !"
    templated_sys = "You are a pirate {{ function_description }}"
    custom_tool = ToolDefinition(
        tool_name="custom1",
        description="custom1 tool",
        input_schema={
            "type": "object",
            "properties": {
                "param1": {
                    "type": "str",
                    "description": "param1 description",
                },
            },
            "required": ["param1"],
        },
    )
    req = ChatCompletionRequest(
        model=MODEL,
        messages=[
            SystemMessage(content=templated_sys),
            UserMessage(content=user_text),
        ],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            custom_tool,
        ],
        tool_config=ToolConfig(
            tool_choice="auto",
            tool_prompt_format=ToolPromptFormat.python_list,
            system_message_behavior=SystemMessageBehavior.replace,
        ),
    )

    rendered = chat_completion_request_to_messages(req, MODEL3_2)

    assert len(rendered) == 2
    sys_text = interleaved_content_as_str(rendered[0].content)
    assert "Environment: ipython" in sys_text
    assert "You are a pirate" in sys_text
    # function description is present in the system prompt
    assert '"name": "custom1"' in sys_text
    assert rendered[-1].content == user_text
|