mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
use templates for generating system prompts
This commit is contained in:
parent
68855ed218
commit
ab8193c88c
8 changed files with 410 additions and 218 deletions
|
@ -13,7 +13,10 @@ from llama_models.llama3.api.datatypes import (
|
|||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
|
||||
from llama_toolchain.inference.api.datatypes import (
|
||||
ChatCompletionResponseEventType,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
|
||||
from llama_toolchain.inference.ollama.config import OllamaImplConfig
|
||||
from llama_toolchain.inference.ollama.ollama import get_provider_impl
|
||||
|
@ -236,6 +239,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
|
|||
],
|
||||
stream=True,
|
||||
tools=[self.custom_tool_defn],
|
||||
tool_prompt_format=ToolPromptFormat.function_tag,
|
||||
)
|
||||
iterator = self.api.chat_completion(request)
|
||||
events = []
|
||||
|
|
|
@ -2,12 +2,12 @@ import unittest
|
|||
|
||||
from llama_models.llama3.api import * # noqa: F403
|
||||
from llama_toolchain.inference.api import * # noqa: F403
|
||||
from llama_toolchain.inference.prepare_messages import prepare_messages_for_tools
|
||||
from llama_toolchain.inference.prepare_messages import prepare_messages
|
||||
|
||||
MODEL = "Meta-Llama3.1-8B-Instruct"
|
||||
|
||||
|
||||
class ToolUtilsTests(unittest.IsolatedAsyncioTestCase):
|
||||
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_system_default(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
|
@ -16,12 +16,10 @@ class ToolUtilsTests(unittest.IsolatedAsyncioTestCase):
|
|||
UserMessage(content=content),
|
||||
],
|
||||
)
|
||||
request = prepare_messages_for_tools(request)
|
||||
self.assertEqual(len(request.messages), 2)
|
||||
self.assertEqual(request.messages[-1].content, content)
|
||||
self.assertTrue(
|
||||
"Cutting Knowledge Date: December 2023" in request.messages[0].content
|
||||
)
|
||||
messages = prepare_messages(request)
|
||||
self.assertEqual(len(messages), 2)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
|
||||
|
||||
async def test_system_builtin_only(self):
|
||||
content = "Hello !"
|
||||
|
@ -35,13 +33,11 @@ class ToolUtilsTests(unittest.IsolatedAsyncioTestCase):
|
|||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
],
|
||||
)
|
||||
request = prepare_messages_for_tools(request)
|
||||
self.assertEqual(len(request.messages), 2)
|
||||
self.assertEqual(request.messages[-1].content, content)
|
||||
self.assertTrue(
|
||||
"Cutting Knowledge Date: December 2023" in request.messages[0].content
|
||||
)
|
||||
self.assertTrue("Tools: brave_search" in request.messages[0].content)
|
||||
messages = prepare_messages(request)
|
||||
self.assertEqual(len(messages), 2)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
|
||||
self.assertTrue("Tools: brave_search" in messages[0].content)
|
||||
|
||||
async def test_system_custom_only(self):
|
||||
content = "Hello !"
|
||||
|
@ -65,14 +61,12 @@ class ToolUtilsTests(unittest.IsolatedAsyncioTestCase):
|
|||
],
|
||||
tool_prompt_format=ToolPromptFormat.json,
|
||||
)
|
||||
request = prepare_messages_for_tools(request)
|
||||
self.assertEqual(len(request.messages), 3)
|
||||
self.assertTrue("Environment: ipython" in request.messages[0].content)
|
||||
messages = prepare_messages(request)
|
||||
self.assertEqual(len(messages), 3)
|
||||
self.assertTrue("Environment: ipython" in messages[0].content)
|
||||
|
||||
self.assertTrue(
|
||||
"Return function calls in JSON format" in request.messages[1].content
|
||||
)
|
||||
self.assertEqual(request.messages[-1].content, content)
|
||||
self.assertTrue("Return function calls in JSON format" in messages[1].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
|
||||
async def test_system_custom_and_builtin(self):
|
||||
content = "Hello !"
|
||||
|
@ -97,16 +91,14 @@ class ToolUtilsTests(unittest.IsolatedAsyncioTestCase):
|
|||
),
|
||||
],
|
||||
)
|
||||
request = prepare_messages_for_tools(request)
|
||||
self.assertEqual(len(request.messages), 3)
|
||||
messages = prepare_messages(request)
|
||||
self.assertEqual(len(messages), 3)
|
||||
|
||||
self.assertTrue("Environment: ipython" in request.messages[0].content)
|
||||
self.assertTrue("Tools: brave_search" in request.messages[0].content)
|
||||
self.assertTrue("Environment: ipython" in messages[0].content)
|
||||
self.assertTrue("Tools: brave_search" in messages[0].content)
|
||||
|
||||
self.assertTrue(
|
||||
"Return function calls in JSON format" in request.messages[1].content
|
||||
)
|
||||
self.assertEqual(request.messages[-1].content, content)
|
||||
self.assertTrue("Return function calls in JSON format" in messages[1].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
|
||||
async def test_user_provided_system_message(self):
|
||||
content = "Hello !"
|
||||
|
@ -121,8 +113,8 @@ class ToolUtilsTests(unittest.IsolatedAsyncioTestCase):
|
|||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
)
|
||||
request = prepare_messages_for_tools(request)
|
||||
self.assertEqual(len(request.messages), 2, request.messages)
|
||||
self.assertTrue(request.messages[0].content.endswith(system_prompt))
|
||||
messages = prepare_messages(request)
|
||||
self.assertEqual(len(messages), 2, messages)
|
||||
self.assertTrue(messages[0].content.endswith(system_prompt))
|
||||
|
||||
self.assertEqual(request.messages[-1].content, content)
|
||||
self.assertEqual(messages[-1].content, content)
|
101
tests/test_prompt_templates.py
Normal file
101
tests/test_prompt_templates.py
Normal file
|
@ -0,0 +1,101 @@
|
|||
import textwrap
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
|
||||
from llama_toolchain.common.prompt_templates.system_prompts import (
|
||||
BuiltinToolGenerator,
|
||||
FunctionTagCustomToolGenerator,
|
||||
JsonCustomToolGenerator,
|
||||
SystemDefaultGenerator,
|
||||
)
|
||||
|
||||
|
||||
class PromptTemplateTests(unittest.TestCase):
|
||||
|
||||
def check_generator_output(self, generator, expected_text):
|
||||
example = generator.data_examples()[0]
|
||||
|
||||
pt = generator.gen(example)
|
||||
text = pt.render()
|
||||
# print(text) # debugging
|
||||
self.assertEqual(text, expected_text)
|
||||
|
||||
def test_system_default(self):
|
||||
generator = SystemDefaultGenerator()
|
||||
today = datetime.now().strftime("%d %B %Y")
|
||||
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
|
||||
self.check_generator_output(generator, expected_text)
|
||||
|
||||
def test_system_builtin_only(self):
|
||||
generator = BuiltinToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Environment: ipython
|
||||
Tools: brave_search, wolfram_alpha
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_system_custom_only(self):
|
||||
self.maxDiff = None
|
||||
generator = JsonCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Answer the user's question by making use of the following functions if needed.
|
||||
If none of the function can be used, please say so.
|
||||
Here is a list of functions in JSON format:
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "trending_songs",
|
||||
"description": "Returns the trending songs on a Music site",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": [
|
||||
{
|
||||
"n": {
|
||||
"type": "object",
|
||||
"description": "The number of songs to return"
|
||||
}
|
||||
},
|
||||
{
|
||||
"genre": {
|
||||
"type": "object",
|
||||
"description": "The genre of the songs to return"
|
||||
}
|
||||
}
|
||||
],
|
||||
"required": ["n"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Return function calls in JSON format.
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_system_custom_function_tag(self):
|
||||
self.maxDiff = None
|
||||
generator = FunctionTagCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
You have access to the following functions:
|
||||
|
||||
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
|
||||
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
|
||||
|
||||
Think very carefully before calling functions.
|
||||
If a you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||
|
||||
<function=example_function_name>{"example_name": "example_value"}</function>
|
||||
|
||||
Reminder:
|
||||
- If looking for real time information use relevant functions before falling back to brave_search
|
||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
Loading…
Add table
Add a link
Reference in a new issue