mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-19 19:30:05 +00:00
refactor: move more tests, delete some providers tests (#1382)
Move unittests to tests/unittests. Gradually nuking tests from providers/tests/ and unifying them into tests/api (which are e2e tests using SDK types) ## Test Plan `pytest -s -v tests/unittests/`
This commit is contained in:
parent
e5ec68f66e
commit
86fc514abb
11 changed files with 6 additions and 142 deletions
281
tests/unittests/models/test_prompt_adapter.py
Normal file
281
tests/unittests/models/test_prompt_adapter.py
Normal file
|
@ -0,0 +1,281 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionMessage,
|
||||
StopReason,
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
ToolConfig,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_messages,
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
MODEL = "Llama3.1-8B-Instruct"
|
||||
MODEL3_2 = "Llama3.2-3B-Instruct"
|
||||
|
||||
|
||||
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_system_default(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 2)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
|
||||
|
||||
async def test_system_builtin_only(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 2)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
|
||||
self.assertTrue("Tools: brave_search" in messages[0].content)
|
||||
|
||||
async def test_system_custom_only(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
],
|
||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 3)
|
||||
self.assertTrue("Environment: ipython" in messages[0].content)
|
||||
|
||||
self.assertTrue("Return function calls in JSON format" in messages[1].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
|
||||
async def test_system_custom_and_builtin(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 3)
|
||||
|
||||
self.assertTrue("Environment: ipython" in messages[0].content)
|
||||
self.assertTrue("Tools: brave_search" in messages[0].content)
|
||||
|
||||
self.assertTrue("Return function calls in JSON format" in messages[1].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
|
||||
async def test_completion_message_encoding(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL3_2,
|
||||
messages=[
|
||||
UserMessage(content="hello"),
|
||||
CompletionMessage(
|
||||
content="",
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_name="custom1",
|
||||
arguments={"param1": "value1"},
|
||||
call_id="123",
|
||||
)
|
||||
],
|
||||
),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
|
||||
)
|
||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
||||
self.assertIn('[custom1(param1="value1")]', prompt)
|
||||
|
||||
request.model = MODEL
|
||||
request.tool_config.tool_prompt_format = ToolPromptFormat.json
|
||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
||||
self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
|
||||
|
||||
async def test_user_provided_system_message(self):
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 2, messages)
|
||||
self.assertTrue(messages[0].content.endswith(system_prompt))
|
||||
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
|
||||
async def test_repalce_system_message_behavior_builtin_tools(self):
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="python_list",
|
||||
system_message_behavior="replace",
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
self.assertEqual(len(messages), 2, messages)
|
||||
self.assertTrue(messages[0].content.endswith(system_prompt))
|
||||
self.assertIn("Environment: ipython", messages[0].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
|
||||
async def test_repalce_system_message_behavior_custom_tools(self):
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="python_list",
|
||||
system_message_behavior="replace",
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
|
||||
self.assertEqual(len(messages), 2, messages)
|
||||
self.assertTrue(messages[0].content.endswith(system_prompt))
|
||||
self.assertIn("Environment: ipython", messages[0].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
|
||||
async def test_replace_system_message_behavior_custom_tools_with_template(self):
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate {{ function_description }}"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="python_list",
|
||||
system_message_behavior="replace",
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
|
||||
self.assertEqual(len(messages), 2, messages)
|
||||
self.assertIn("Environment: ipython", messages[0].content)
|
||||
self.assertIn("You are a pirate", messages[0].content)
|
||||
# function description is present in the system prompt
|
||||
self.assertIn('"name": "custom1"', messages[0].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
198
tests/unittests/models/test_system_prompts.py
Normal file
198
tests/unittests/models/test_system_prompts.py
Normal file
|
@ -0,0 +1,198 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
|
||||
from llama_stack.models.llama.llama3.prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
FunctionTagCustomToolGenerator,
|
||||
JsonCustomToolGenerator,
|
||||
PythonListCustomToolGenerator,
|
||||
SystemDefaultGenerator,
|
||||
)
|
||||
|
||||
|
||||
class PromptTemplateTests(unittest.TestCase):
|
||||
def check_generator_output(self, generator, expected_text):
|
||||
example = generator.data_examples()[0]
|
||||
|
||||
pt = generator.gen(example)
|
||||
text = pt.render()
|
||||
# print(text) # debugging
|
||||
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
|
||||
|
||||
def test_system_default(self):
|
||||
generator = SystemDefaultGenerator()
|
||||
today = datetime.now().strftime("%d %B %Y")
|
||||
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
|
||||
self.check_generator_output(generator, expected_text)
|
||||
|
||||
def test_system_builtin_only(self):
|
||||
generator = BuiltinToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Environment: ipython
|
||||
Tools: brave_search, wolfram_alpha
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_system_custom_only(self):
|
||||
self.maxDiff = None
|
||||
generator = JsonCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Answer the user's question by making use of the following functions if needed.
|
||||
If none of the function can be used, please say so.
|
||||
Here is a list of functions in JSON format:
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "trending_songs",
|
||||
"description": "Returns the trending songs on a Music site",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": [
|
||||
{
|
||||
"n": {
|
||||
"type": "object",
|
||||
"description": "The number of songs to return"
|
||||
}
|
||||
},
|
||||
{
|
||||
"genre": {
|
||||
"type": "object",
|
||||
"description": "The genre of the songs to return"
|
||||
}
|
||||
}
|
||||
],
|
||||
"required": ["n"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Return function calls in JSON format.
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_system_custom_function_tag(self):
|
||||
self.maxDiff = None
|
||||
generator = FunctionTagCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
You have access to the following functions:
|
||||
|
||||
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
|
||||
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
|
||||
|
||||
Think very carefully before calling functions.
|
||||
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||
|
||||
<function=example_function_name>{"example_name": "example_value"}</function>
|
||||
|
||||
Reminder:
|
||||
- If looking for real time information use relevant functions before falling back to brave_search
|
||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_llama_3_2_system_zero_shot(self):
|
||||
generator = PythonListCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
You are a helpful assistant. You have access to functions, but you should only use them if they are required.
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you may or may not need to make one function/tool call to achieve the purpose.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info for places",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": ["city"],
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_llama_3_2_provided_system_prompt(self):
|
||||
generator = PythonListCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Overriding message.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info for places",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": ["city"],
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]"""
|
||||
)
|
||||
user_system_prompt = textwrap.dedent(
|
||||
"""
|
||||
Overriding message.
|
||||
|
||||
{{ function_description }}
|
||||
"""
|
||||
)
|
||||
example = generator.data_examples()[0]
|
||||
|
||||
pt = generator.gen(example, user_system_prompt)
|
||||
text = pt.render()
|
||||
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
|
Loading…
Add table
Add a link
Reference in a new issue