mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-31 07:13:53 +00:00
# What does this PR do?
This PR proposes updates to the tools API in Inference and Agent.
Goals:
1. Agent's tool specification should be consistent with Inference's tool spec, but with add-ons.
2. Formal types should be defined for built-in tools. Currently, Agent tool args are untyped: e.g. how does one know that `builtin::rag_tool` takes a `vector_db_ids` param, or even that `builtin::rag_tool` is available at all (in code, outside of docs)?
Inference:
1. BuiltinTool is to be removed and replaced by a formal `type` parameter.
2. 'brave_search' is replaced by 'web_search' to be more generic. It will still be translated back to brave_search when the prompt is constructed to be consistent with model training.
3. I'm not sure what `photogen` is. Maybe it can be removed?
Agent:
1. Uses the same format as in Inference for builtin tools.
2. New tool types are added: knowledge_search (currently rag_tool) and MCP tool.
3. Toolgroup as a concept will be removed since it's really only used for MCP.
4. Instead, MCPTool is its own type, and the tools provided by the server will be expanded by default. Users can specify a subset of tool names if desired.
Example snippet:
```
agent = Agent(
client,
model=model_id,
instructions="You are a helpful assistant. Use the tools you have access to for providing relevant answers.",
tools=[
KnowledgeSearchTool(vector_store_id="1234"),
KnowledgeSearchTool(vector_store_id="5678", name="paper_search", description="Search research papers"),
KnowledgeSearchTool(vector_store_id="1357", name="wiki_search", description="Search wiki pages"),
# no need to register toolgroup, just pass in the server uri
# all available tools will be used
MCPTool(server_uri="http://localhost:8000/sse"),
# can specify a subset of available tools
MCPTool(server_uri="http://localhost:8000/sse", tool_names=["list_directory"]),
MCPTool(server_uri="http://localhost:8000/sse", tool_names=["list_directory"]),
# custom tool
my_custom_tool,
]
)
```
## Test Plan
# What does this PR do?
## Test Plan
# What does this PR do?
## Test Plan
301 lines
9.1 KiB
Python
301 lines
9.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
|
|
import pytest
|
|
|
|
from llama_stack.apis.inference import (
|
|
ChatCompletionRequest,
|
|
CompletionMessage,
|
|
StopReason,
|
|
SystemMessage,
|
|
ToolCall,
|
|
ToolConfig,
|
|
UserMessage,
|
|
)
|
|
from llama_stack.models.llama.datatypes import (
|
|
CodeInterpreterTool,
|
|
FunctionTool,
|
|
ToolParamDefinition,
|
|
ToolPromptFormat,
|
|
WebSearchTool,
|
|
)
|
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
|
chat_completion_request_to_messages,
|
|
chat_completion_request_to_prompt,
|
|
)
|
|
|
|
# Model identifiers passed to the prompt-adapter helpers in these tests:
# a Llama 3.1 model and a Llama 3.2 model.
MODEL, MODEL3_2 = "Llama3.1-8B-Instruct", "Llama3.2-3B-Instruct"
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def setup_loop():
    """Fetch the current asyncio event loop and switch its debug mode off.

    Because the fixture is ``autouse``, it runs for every test in this module.
    The loop is returned so a test may request it by the name ``setup_loop``.
    """
    # NOTE(review): ``asyncio.get_event_loop()`` is deprecated when no current
    # loop is set (DeprecationWarning on recent Pythons, RuntimeError on 3.14);
    # confirm this fixture is still needed alongside pytest-asyncio.
    current_loop = asyncio.get_event_loop()
    current_loop.set_debug(False)
    return current_loop
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_system_default():
    """Check formatting of a tool-less Llama 3.1 request with only a user message.

    Two messages must come back: the first contains the knowledge-cutoff
    line and the last is the user's original message.
    """
    user_text = "Hello !"
    req = ChatCompletionRequest(model=MODEL, messages=[UserMessage(content=user_text)])

    formatted = chat_completion_request_to_messages(req, MODEL)
    assert len(formatted) == 2
    head, tail = formatted
    assert tail.content == user_text
    assert "Cutting Knowledge Date: December 2023" in head.content
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_system_builtin_only():
    """Check formatting of a Llama 3.1 request that declares only built-in tools.

    The request carries the code-interpreter and web-search tools. The first
    formatted message must contain the knowledge-cutoff line and list the
    search tool under the name ``brave_search``.
    """
    user_text = "Hello !"
    builtin_tools = [CodeInterpreterTool(), WebSearchTool()]
    req = ChatCompletionRequest(
        model=MODEL,
        messages=[UserMessage(content=user_text)],
        tools=builtin_tools,
    )

    formatted = chat_completion_request_to_messages(req, MODEL)
    assert len(formatted) == 2
    head, tail = formatted
    assert tail.content == user_text
    assert "Cutting Knowledge Date: December 2023" in head.content
    # The generic web-search tool is rendered with its brave_search name.
    assert "Tools: brave_search" in head.content
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_system_custom_only():
    """Check formatting of a Llama 3.1 request with one custom function tool.

    With the JSON tool-prompt format, three messages come back:
    - the first mentions the ``ipython`` environment;
    - the second carries the JSON function-calling instructions;
    - the last is the user's message.
    """
    user_text = "Hello !"
    param1_def = ToolParamDefinition(
        param_type="str",
        description="param1 description",
        required=True,
    )
    custom_tool = FunctionTool(
        name="custom1",
        description="custom1 tool",
        parameters={"param1": param1_def},
    )
    json_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
    req = ChatCompletionRequest(
        model=MODEL,
        messages=[UserMessage(content=user_text)],
        tools=[custom_tool],
        tool_config=json_config,
    )

    formatted = chat_completion_request_to_messages(req, MODEL)
    assert len(formatted) == 3
    assert "Environment: ipython" in formatted[0].content
    assert "Return function calls in JSON format" in formatted[1].content
    assert formatted[-1].content == user_text
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_system_custom_and_builtin():
    """Check formatting of a Llama 3.1 request mixing built-in and custom tools.

    No ``tool_config`` is supplied, and three messages come back:
    - the first mentions the ``ipython`` environment and lists
      ``brave_search``;
    - the second carries the JSON function-calling instructions;
    - the last is the user's message.
    """
    user_text = "Hello !"
    custom_tool = FunctionTool(
        name="custom1",
        description="custom1 tool",
        parameters={
            "param1": ToolParamDefinition(
                param_type="str",
                description="param1 description",
                required=True,
            ),
        },
    )
    all_tools = [CodeInterpreterTool(), WebSearchTool(), custom_tool]
    req = ChatCompletionRequest(
        model=MODEL,
        messages=[UserMessage(content=user_text)],
        tools=all_tools,
    )

    formatted = chat_completion_request_to_messages(req, MODEL)
    assert len(formatted) == 3
    system_text = formatted[0].content
    assert "Environment: ipython" in system_text
    assert "Tools: brave_search" in system_text
    assert "Return function calls in JSON format" in formatted[1].content
    assert formatted[-1].content == user_text
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_completion_message_encoding():
    """Check how an earlier assistant tool call is rendered into the prompt.

    The same request is encoded twice:
    - first for Llama 3.2 with the ``python_list`` tool-prompt format;
    - then, after mutating the request in place, for Llama 3.1 with the
      ``json`` format.

    Each rendering of the ``custom1`` call must appear in its prompt.
    """
    custom_tool = FunctionTool(
        name="custom1",
        description="custom1 tool",
        parameters={
            "param1": ToolParamDefinition(
                param_type="str",
                description="param1 description",
                required=True,
            ),
        },
    )
    earlier_call = ToolCall(
        type="function",
        tool_name="custom1",
        arguments={"param1": "value1"},
        call_id="123",
    )
    assistant_turn = CompletionMessage(
        content="",
        stop_reason=StopReason.end_of_turn,
        tool_calls=[earlier_call],
    )
    request = ChatCompletionRequest(
        model=MODEL3_2,
        messages=[UserMessage(content="hello"), assistant_turn],
        tools=[custom_tool],
        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
    )

    python_list_prompt = await chat_completion_request_to_prompt(request, request.model)
    assert '[custom1(param1="value1")]' in python_list_prompt

    # Re-target the very same request object at Llama 3.1 with JSON tool formatting.
    request.model = MODEL
    request.tool_config.tool_prompt_format = ToolPromptFormat.json
    json_prompt = await chat_completion_request_to_prompt(request, request.model)
    assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in json_prompt
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_user_provided_system_message():
    """Check that a caller-supplied system prompt survives formatting.

    The request declares the code-interpreter tool and uses the default tool
    configuration. Formatting for Llama 3.1 must yield two messages: the
    first ends with the caller's system prompt, and the last is the user's
    message.
    """
    user_text = "Hello !"
    pirate_prompt = "You are a pirate"
    conversation = [
        SystemMessage(content=pirate_prompt),
        UserMessage(content=user_text),
    ]
    req = ChatCompletionRequest(
        model=MODEL,
        messages=conversation,
        tools=[CodeInterpreterTool()],
    )

    formatted = chat_completion_request_to_messages(req, MODEL)
    assert len(formatted) == 2
    head, tail = formatted
    assert head.content.endswith(pirate_prompt)
    assert tail.content == user_text
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_repalce_system_message_behavior_builtin_tools():
    """Check ``system_message_behavior="replace"`` when only a built-in tool is declared.

    The request names the Llama 3.1 model but is formatted for Llama 3.2.
    The first formatted message must end with the caller's system prompt and
    still mention the ``ipython`` environment.
    """
    # NOTE(review): "repalce" in the test name is a typo; kept as-is so any
    # existing ``-k`` selectors / node IDs keep working.
    user_text = "Hello !"
    pirate_prompt = "You are a pirate"
    replace_config = ToolConfig(
        tool_choice="auto",
        tool_prompt_format="python_list",
        system_message_behavior="replace",
    )
    req = ChatCompletionRequest(
        model=MODEL,
        messages=[SystemMessage(content=pirate_prompt), UserMessage(content=user_text)],
        tools=[CodeInterpreterTool()],
        tool_config=replace_config,
    )

    formatted = chat_completion_request_to_messages(req, MODEL3_2)
    assert len(formatted) == 2
    head, tail = formatted
    assert head.content.endswith(pirate_prompt)
    assert "Environment: ipython" in head.content
    assert tail.content == user_text
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_repalce_system_message_behavior_custom_tools():
    """Check ``system_message_behavior="replace"`` with built-in plus custom tools.

    The request names the Llama 3.1 model but is formatted for Llama 3.2.
    The first formatted message must end with the caller's system prompt and
    still mention the ``ipython`` environment.
    """
    # NOTE(review): "repalce" in the test name is a typo; kept as-is so any
    # existing ``-k`` selectors / node IDs keep working.
    user_text = "Hello !"
    pirate_prompt = "You are a pirate"
    custom_tool = FunctionTool(
        name="custom1",
        description="custom1 tool",
        parameters={
            "param1": ToolParamDefinition(
                param_type="str",
                description="param1 description",
                required=True,
            ),
        },
    )
    replace_config = ToolConfig(
        tool_choice="auto",
        tool_prompt_format="python_list",
        system_message_behavior="replace",
    )
    req = ChatCompletionRequest(
        model=MODEL,
        messages=[SystemMessage(content=pirate_prompt), UserMessage(content=user_text)],
        tools=[CodeInterpreterTool(), custom_tool],
        tool_config=replace_config,
    )

    formatted = chat_completion_request_to_messages(req, MODEL3_2)
    assert len(formatted) == 2
    head, tail = formatted
    assert head.content.endswith(pirate_prompt)
    assert "Environment: ipython" in head.content
    assert tail.content == user_text
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_replace_system_message_behavior_custom_tools_with_template():
    """Check ``system_message_behavior="replace"`` with a templated system prompt.

    The caller's system prompt contains a ``{{ function_description }}``
    placeholder. After formatting for Llama 3.2, the first message must:
    - mention the ``ipython`` environment;
    - keep the pirate text;
    - include the ``custom1`` tool's JSON name, presumably substituted for
      the placeholder.
    """
    user_text = "Hello !"
    templated_prompt = "You are a pirate {{ function_description }}"
    custom_tool = FunctionTool(
        name="custom1",
        description="custom1 tool",
        parameters={
            "param1": ToolParamDefinition(
                param_type="str",
                description="param1 description",
                required=True,
            ),
        },
    )
    replace_config = ToolConfig(
        tool_choice="auto",
        tool_prompt_format="python_list",
        system_message_behavior="replace",
    )
    req = ChatCompletionRequest(
        model=MODEL,
        messages=[SystemMessage(content=templated_prompt), UserMessage(content=user_text)],
        tools=[CodeInterpreterTool(), custom_tool],
        tool_config=replace_config,
    )

    formatted = chat_completion_request_to_messages(req, MODEL3_2)
    assert len(formatted) == 2
    system_text = formatted[0].content
    assert "Environment: ipython" in system_text
    assert "You are a pirate" in system_text
    assert '"name": "custom1"' in system_text
    assert formatted[-1].content == user_text
|