llama-stack-mirror/tests/unit/models/test_prompt_adapter.py
Ashwin Bharambe ef0736527d
feat(tools)!: substantial clean up of "Tool" related datatypes (#3627)
This is a sweeping change to clean up some gunk around our "Tool"
definitions.

First, we had two types `Tool` and `ToolDef`. The first of these was a
"Resource" type for the registry, but we had stopped registering tools
inside the Registry long ago (and only registered ToolGroups). The
latter was for specifying tools for the Agents API. This PR removes the
former and adds an optional `toolgroup_id` field to the latter.

Secondly, as pointed out by @bbrowning in
https://github.com/llamastack/llama-stack/pull/3003#issuecomment-3245270132,
we were doing a lossy conversion from a full JSON schema from the MCP
tool specification into our ToolDefinition to send it to the model.
There is no necessity to do this -- we ourselves aren't doing any
execution at all but merely passing it to the chat completions API which
supports this. By doing this (and by doing it poorly), we encountered
limitations like not supporting array items, or not resolving $refs,
etc.

To fix this, we replaced the `parameters` field with `{ input_schema,
output_schema }`, which can be full-blown JSON schemas.

Finally, there were some types in our llama-related chat format
conversion which needed some cleanup. We are taking this opportunity to
clean those up.

This PR is a substantial breaking change to the API. However, given our
window for introducing breaking changes, this suits us just fine. I will
be landing a concurrent `llama-stack-client` change as well since API
shapes are changing.
2025-10-02 15:12:03 -07:00

303 lines
10 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionMessage,
StopReason,
SystemMessage,
SystemMessageBehavior,
ToolCall,
ToolConfig,
UserMessage,
)
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
chat_completion_request_to_prompt,
interleaved_content_as_str,
)
# Model identifiers used by the tests in this file, both when building requests
# and as the model argument passed to the prompt adapter helpers.
MODEL = "Llama3.1-8B-Instruct"
MODEL3_2 = "Llama3.2-3B-Instruct"
async def test_system_default():
    """Check the messages produced for a request holding only a user message.

    Two messages are expected: the first one's text must contain the
    knowledge-cutoff line and the last one must carry the user content
    unchanged.
    """
    user_text = "Hello !"
    request = ChatCompletionRequest(model=MODEL, messages=[UserMessage(content=user_text)])

    messages = chat_completion_request_to_messages(request, MODEL)

    assert len(messages) == 2
    assert messages[-1].content == user_text
    first_text = interleaved_content_as_str(messages[0].content)
    assert "Cutting Knowledge Date: December 2023" in first_text
async def test_system_builtin_only():
    """Check the messages produced when only builtin tools are supplied.

    With the code interpreter and Brave search builtin tools, two messages
    are expected. The first one's text must contain both the knowledge-cutoff
    line and the ``Tools: brave_search`` line; the last one is the user content.
    """
    user_text = "Hello !"
    builtin_tools = [
        ToolDefinition(tool_name=builtin)
        for builtin in (BuiltinTool.code_interpreter, BuiltinTool.brave_search)
    ]
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[UserMessage(content=user_text)],
        tools=builtin_tools,
    )

    messages = chat_completion_request_to_messages(request, MODEL)

    assert len(messages) == 2
    assert messages[-1].content == user_text
    first_text = interleaved_content_as_str(messages[0].content)
    assert "Cutting Knowledge Date: December 2023" in first_text
    assert "Tools: brave_search" in first_text
async def test_system_custom_only():
    """Check the messages produced for one custom tool with the JSON prompt format.

    Three messages are expected: the first mentions ``Environment: ipython``,
    the second carries the JSON function-calling instructions, and the last
    is the user content.
    """
    user_text = "Hello !"
    custom_tool = ToolDefinition(
        tool_name="custom1",
        description="custom1 tool",
        input_schema={
            "type": "object",
            "properties": {
                "param1": {"type": "str", "description": "param1 description"},
            },
            "required": ["param1"],
        },
    )
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[UserMessage(content=user_text)],
        tools=[custom_tool],
        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
    )

    messages = chat_completion_request_to_messages(request, MODEL)

    assert len(messages) == 3
    first_text = interleaved_content_as_str(messages[0].content)
    assert "Environment: ipython" in first_text
    second_text = interleaved_content_as_str(messages[1].content)
    assert "Return function calls in JSON format" in second_text
    assert messages[-1].content == user_text
async def test_system_custom_and_builtin():
    """Check the messages produced when builtin and custom tools are mixed.

    Three messages are expected: the first mentions both
    ``Environment: ipython`` and ``Tools: brave_search``, the second carries
    the JSON function-calling instructions, and the last is the user content.
    """
    user_text = "Hello !"
    custom_tool = ToolDefinition(
        tool_name="custom1",
        description="custom1 tool",
        input_schema={
            "type": "object",
            "properties": {
                "param1": {"type": "str", "description": "param1 description"},
            },
            "required": ["param1"],
        },
    )
    all_tools = [
        ToolDefinition(tool_name=BuiltinTool.code_interpreter),
        ToolDefinition(tool_name=BuiltinTool.brave_search),
        custom_tool,
    ]
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[UserMessage(content=user_text)],
        tools=all_tools,
    )

    messages = chat_completion_request_to_messages(request, MODEL)

    assert len(messages) == 3
    first_text = interleaved_content_as_str(messages[0].content)
    assert "Environment: ipython" in first_text
    assert "Tools: brave_search" in first_text
    second_text = interleaved_content_as_str(messages[1].content)
    assert "Return function calls in JSON format" in second_text
    assert messages[-1].content == user_text
async def test_completion_message_encoding():
    """Check how a tool call in a completion message is rendered into the prompt.

    The same request is encoded twice: first with ``MODEL3_2`` and the
    ``python_list`` tool prompt format, then, after the request's model and
    tool config are switched in place, with ``MODEL`` and the ``json`` format.
    """
    custom_tool = ToolDefinition(
        tool_name="custom1",
        description="custom1 tool",
        input_schema={
            "type": "object",
            "properties": {
                "param1": {"type": "str", "description": "param1 description"},
            },
            "required": ["param1"],
        },
    )
    assistant_turn = CompletionMessage(
        content="",
        stop_reason=StopReason.end_of_turn,
        tool_calls=[
            ToolCall(
                tool_name="custom1",
                arguments='{"param1": "value1"}',  # arguments must be a JSON string
                call_id="123",
            )
        ],
    )
    request = ChatCompletionRequest(
        model=MODEL3_2,
        messages=[UserMessage(content="hello"), assistant_turn],
        tools=[custom_tool],
        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
    )

    # First pass: python_list rendering of the tool call.
    python_list_prompt = await chat_completion_request_to_prompt(request, request.model)
    assert '[custom1(param1="value1")]' in python_list_prompt

    # Second pass: reuse the request with a different model and the JSON format.
    request.model = MODEL
    request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
    json_prompt = await chat_completion_request_to_prompt(request, request.model)
    assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in json_prompt
async def test_user_provided_system_message():
    """Check that a caller-supplied system prompt ends the first output message.

    With a system message, a user message and the code interpreter builtin
    tool, two messages are expected: the first one's text must end with the
    supplied system prompt and the last one is the user content.
    """
    user_text = "Hello !"
    pirate_prompt = "You are a pirate"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[SystemMessage(content=pirate_prompt), UserMessage(content=user_text)],
        tools=[ToolDefinition(tool_name=BuiltinTool.code_interpreter)],
    )

    messages = chat_completion_request_to_messages(request, MODEL)

    assert len(messages) == 2
    first_text = interleaved_content_as_str(messages[0].content)
    assert first_text.endswith(pirate_prompt)
    assert messages[-1].content == user_text
async def test_replace_system_message_behavior_builtin_tools():
    """Check ``SystemMessageBehavior.replace`` with only a builtin tool.

    Two messages are expected: the first one's text must end with the supplied
    system prompt and also contain ``Environment: ipython``; the last one is
    the user content.
    """
    user_text = "Hello !"
    pirate_prompt = "You are a pirate"
    replace_config = ToolConfig(
        tool_choice="auto",
        tool_prompt_format=ToolPromptFormat.python_list,
        system_message_behavior=SystemMessageBehavior.replace,
    )
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[SystemMessage(content=pirate_prompt), UserMessage(content=user_text)],
        tools=[ToolDefinition(tool_name=BuiltinTool.code_interpreter)],
        tool_config=replace_config,
    )

    # The request is built with MODEL but the adapter is invoked with MODEL3_2.
    messages = chat_completion_request_to_messages(request, MODEL3_2)

    assert len(messages) == 2
    first_text = interleaved_content_as_str(messages[0].content)
    assert first_text.endswith(pirate_prompt)
    assert "Environment: ipython" in first_text
    assert messages[-1].content == user_text
async def test_replace_system_message_behavior_custom_tools():
    """Check ``SystemMessageBehavior.replace`` with a builtin and a custom tool.

    Two messages are expected: the first one's text must end with the supplied
    system prompt and also contain ``Environment: ipython``; the last one is
    the user content.
    """
    user_text = "Hello !"
    pirate_prompt = "You are a pirate"
    custom_tool = ToolDefinition(
        tool_name="custom1",
        description="custom1 tool",
        input_schema={
            "type": "object",
            "properties": {
                "param1": {"type": "str", "description": "param1 description"},
            },
            "required": ["param1"],
        },
    )
    replace_config = ToolConfig(
        tool_choice="auto",
        tool_prompt_format=ToolPromptFormat.python_list,
        system_message_behavior=SystemMessageBehavior.replace,
    )
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[SystemMessage(content=pirate_prompt), UserMessage(content=user_text)],
        tools=[ToolDefinition(tool_name=BuiltinTool.code_interpreter), custom_tool],
        tool_config=replace_config,
    )

    # The request is built with MODEL but the adapter is invoked with MODEL3_2.
    messages = chat_completion_request_to_messages(request, MODEL3_2)

    assert len(messages) == 2
    first_text = interleaved_content_as_str(messages[0].content)
    assert first_text.endswith(pirate_prompt)
    assert "Environment: ipython" in first_text
    assert messages[-1].content == user_text
async def test_replace_system_message_behavior_custom_tools_with_template():
    """Check ``SystemMessageBehavior.replace`` with a templated system prompt.

    The supplied system prompt contains a ``{{ function_description }}``
    placeholder. Two messages are expected: the first one's text must contain
    ``Environment: ipython``, the pirate text and the custom tool's name entry;
    the last one is the user content.
    """
    user_text = "Hello !"
    templated_prompt = "You are a pirate {{ function_description }}"
    custom_tool = ToolDefinition(
        tool_name="custom1",
        description="custom1 tool",
        input_schema={
            "type": "object",
            "properties": {
                "param1": {"type": "str", "description": "param1 description"},
            },
            "required": ["param1"],
        },
    )
    replace_config = ToolConfig(
        tool_choice="auto",
        tool_prompt_format=ToolPromptFormat.python_list,
        system_message_behavior=SystemMessageBehavior.replace,
    )
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[SystemMessage(content=templated_prompt), UserMessage(content=user_text)],
        tools=[ToolDefinition(tool_name=BuiltinTool.code_interpreter), custom_tool],
        tool_config=replace_config,
    )

    # The request is built with MODEL but the adapter is invoked with MODEL3_2.
    messages = chat_completion_request_to_messages(request, MODEL3_2)

    assert len(messages) == 2
    first_text = interleaved_content_as_str(messages[0].content)
    assert "Environment: ipython" in first_text
    assert "You are a pirate" in first_text
    # The custom tool's description must have been rendered into the first message.
    assert '"name": "custom1"' in first_text
    assert messages[-1].content == user_text