# What does this PR do?

The current default system prompt for Llama 3.2 tends to over-index on tool calling and doesn't work well when the prompt does not require tool calling. This PR adds an option to override the default system prompt, and organizes tool-related configs into a new config object.

- [ ] Addresses issue (#issue)

## Test Plan

python -m unittest llama_stack.providers.tests.inference.test_prompt_adapter

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.

---

[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/937).
* #938
* __->__ #937
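For reviewers, a minimal sketch of the new override in use, mirroring the requests built in the test file below. All identifiers (`ChatCompletionRequest`, `ToolConfig`, `SystemMessage`, etc.) are used exactly as imported in the tests; the behavior described in the comments reflects what the tests assert, not a full specification.

```python
from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    SystemMessage,
    ToolConfig,
    UserMessage,
)

# Tool-related knobs now live in a single ToolConfig object. With
# system_message_behavior="replace", the user-supplied SystemMessage takes the
# place of the default system prompt, while built-in tool instructions such as
# "Environment: ipython" are still injected.
request = ChatCompletionRequest(
    model="Llama3.2-3B-Instruct",
    messages=[
        SystemMessage(content="You are a pirate"),
        UserMessage(content="Hello !"),
    ],
    tools=[ToolDefinition(tool_name=BuiltinTool.code_interpreter)],
    tool_config=ToolConfig(
        tool_choice="auto",
        tool_prompt_format="python_list",
        system_message_behavior="replace",
    ),
)
```

A `{{ function_description }}` placeholder in the user-provided system prompt is filled with the custom tool definitions, as exercised by the last test in the file.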
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import unittest

from llama_models.llama3.api.datatypes import (
    BuiltinTool,
    ToolDefinition,
    ToolParamDefinition,
    ToolPromptFormat,
)

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    SystemMessage,
    ToolConfig,
    UserMessage,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)

MODEL = "Llama3.1-8B-Instruct"
MODEL3_2 = "Llama3.2-3B-Instruct"


class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
    async def test_system_default(self):
        content = "Hello !"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                UserMessage(content=content),
            ],
        )
        messages = chat_completion_request_to_messages(request, MODEL)
        self.assertEqual(len(messages), 2)
        self.assertEqual(messages[-1].content, content)
        self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)

    async def test_system_builtin_only(self):
        content = "Hello !"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
                ToolDefinition(tool_name=BuiltinTool.brave_search),
            ],
        )
        messages = chat_completion_request_to_messages(request, MODEL)
        self.assertEqual(len(messages), 2)
        self.assertEqual(messages[-1].content, content)
        self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
        self.assertTrue("Tools: brave_search" in messages[0].content)

    async def test_system_custom_only(self):
        content = "Hello !"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(
                    tool_name="custom1",
                    description="custom1 tool",
                    parameters={
                        "param1": ToolParamDefinition(
                            param_type="str",
                            description="param1 description",
                            required=True,
                        ),
                    },
                )
            ],
            tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
        )
        messages = chat_completion_request_to_messages(request, MODEL)
        self.assertEqual(len(messages), 3)
        self.assertTrue("Environment: ipython" in messages[0].content)

        self.assertTrue("Return function calls in JSON format" in messages[1].content)
        self.assertEqual(messages[-1].content, content)

    async def test_system_custom_and_builtin(self):
        content = "Hello !"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
                ToolDefinition(tool_name=BuiltinTool.brave_search),
                ToolDefinition(
                    tool_name="custom1",
                    description="custom1 tool",
                    parameters={
                        "param1": ToolParamDefinition(
                            param_type="str",
                            description="param1 description",
                            required=True,
                        ),
                    },
                ),
            ],
        )
        messages = chat_completion_request_to_messages(request, MODEL)
        self.assertEqual(len(messages), 3)

        self.assertTrue("Environment: ipython" in messages[0].content)
        self.assertTrue("Tools: brave_search" in messages[0].content)

        self.assertTrue("Return function calls in JSON format" in messages[1].content)
        self.assertEqual(messages[-1].content, content)

    async def test_user_provided_system_message(self):
        content = "Hello !"
        system_prompt = "You are a pirate"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                SystemMessage(content=system_prompt),
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            ],
        )
        messages = chat_completion_request_to_messages(request, MODEL)
        self.assertEqual(len(messages), 2, messages)
        self.assertTrue(messages[0].content.endswith(system_prompt))

        self.assertEqual(messages[-1].content, content)

    # With system_message_behavior="replace", the user-supplied system prompt
    # replaces the default one; builtin-tool instructions are still injected.
    async def test_replace_system_message_behavior_builtin_tools(self):
        content = "Hello !"
        system_prompt = "You are a pirate"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                SystemMessage(content=system_prompt),
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            ],
            tool_config=ToolConfig(
                tool_choice="auto",
                tool_prompt_format="python_list",
                system_message_behavior="replace",
            ),
        )
        messages = chat_completion_request_to_messages(request, MODEL3_2)
        self.assertEqual(len(messages), 2, messages)
        self.assertTrue(messages[0].content.endswith(system_prompt))
        self.assertIn("Environment: ipython", messages[0].content)
        self.assertEqual(messages[-1].content, content)

    async def test_replace_system_message_behavior_custom_tools(self):
        content = "Hello !"
        system_prompt = "You are a pirate"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                SystemMessage(content=system_prompt),
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
                ToolDefinition(
                    tool_name="custom1",
                    description="custom1 tool",
                    parameters={
                        "param1": ToolParamDefinition(
                            param_type="str",
                            description="param1 description",
                            required=True,
                        ),
                    },
                ),
            ],
            tool_config=ToolConfig(
                tool_choice="auto",
                tool_prompt_format="python_list",
                system_message_behavior="replace",
            ),
        )
        messages = chat_completion_request_to_messages(request, MODEL3_2)

        self.assertEqual(len(messages), 2, messages)
        self.assertTrue(messages[0].content.endswith(system_prompt))
        self.assertIn("Environment: ipython", messages[0].content)
        self.assertEqual(messages[-1].content, content)

    async def test_replace_system_message_behavior_custom_tools_with_template(self):
        content = "Hello !"
        system_prompt = "You are a pirate {{ function_description }}"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                SystemMessage(content=system_prompt),
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
                ToolDefinition(
                    tool_name="custom1",
                    description="custom1 tool",
                    parameters={
                        "param1": ToolParamDefinition(
                            param_type="str",
                            description="param1 description",
                            required=True,
                        ),
                    },
                ),
            ],
            tool_config=ToolConfig(
                tool_choice="auto",
                tool_prompt_format="python_list",
                system_message_behavior="replace",
            ),
        )
        messages = chat_completion_request_to_messages(request, MODEL3_2)

        self.assertEqual(len(messages), 2, messages)
        self.assertIn("Environment: ipython", messages[0].content)
        self.assertIn("You are a pirate", messages[0].content)
        # function description is present in the system prompt
        self.assertIn('"name": "custom1"', messages[0].content)
        self.assertEqual(messages[-1].content, content)