chore(test): migrate unit tests from unittest to pytest for prompt adapter (#2788)

This PR replaces unittest with pytest.

Part of https://github.com/meta-llama/llama-stack/issues/2680

cc @leseb

Co-authored-by: ehhuang <ehhuang@users.noreply.github.com>
This commit is contained in:
Mustafa Elbehery 2025-07-18 01:31:38 +02:00 committed by GitHub
parent 3ae4aeb344
commit bd8a3ae3cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import pytest
import unittest
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionRequest, ChatCompletionRequest,
CompletionMessage, CompletionMessage,
StopReason, StopReason,
SystemMessage, SystemMessage,
SystemMessageBehavior,
ToolCall, ToolCall,
ToolConfig, ToolConfig,
UserMessage, UserMessage,
@ -25,264 +25,275 @@ from llama_stack.models.llama.datatypes import (
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages, chat_completion_request_to_messages,
chat_completion_request_to_prompt, chat_completion_request_to_prompt,
interleaved_content_as_str,
) )
MODEL = "Llama3.1-8B-Instruct" MODEL = "Llama3.1-8B-Instruct"
MODEL3_2 = "Llama3.2-3B-Instruct" MODEL3_2 = "Llama3.2-3B-Instruct"
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): @pytest.mark.asyncio
async def asyncSetUp(self): async def test_system_default():
asyncio.get_running_loop().set_debug(False) content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert messages[-1].content == content
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
async def test_system_default(self):
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 2)
self.assertEqual(messages[-1].content, content)
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
async def test_system_builtin_only(self): @pytest.mark.asyncio
content = "Hello !" async def test_system_builtin_only():
request = ChatCompletionRequest( content = "Hello !"
model=MODEL, request = ChatCompletionRequest(
messages=[ model=MODEL,
UserMessage(content=content), messages=[
], UserMessage(content=content),
tools=[ ],
ToolDefinition(tool_name=BuiltinTool.code_interpreter), tools=[
ToolDefinition(tool_name=BuiltinTool.brave_search), ToolDefinition(tool_name=BuiltinTool.code_interpreter),
], ToolDefinition(tool_name=BuiltinTool.brave_search),
) ],
messages = chat_completion_request_to_messages(request, MODEL) )
self.assertEqual(len(messages), 2) messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(messages[-1].content, content) assert len(messages) == 2
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) assert messages[-1].content == content
self.assertTrue("Tools: brave_search" in messages[0].content) assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
async def test_system_custom_only(self):
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
)
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 3)
self.assertTrue("Environment: ipython" in messages[0].content)
self.assertTrue("Return function calls in JSON format" in messages[1].content) @pytest.mark.asyncio
self.assertEqual(messages[-1].content, content) async def test_system_custom_only():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
)
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 3
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
async def test_system_custom_and_builtin(self): assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
content = "Hello !" assert messages[-1].content == content
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(tool_name=BuiltinTool.brave_search),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 3)
self.assertTrue("Environment: ipython" in messages[0].content)
self.assertTrue("Tools: brave_search" in messages[0].content)
self.assertTrue("Return function calls in JSON format" in messages[1].content) @pytest.mark.asyncio
self.assertEqual(messages[-1].content, content) async def test_system_custom_and_builtin():
content = "Hello !"
async def test_completion_message_encoding(self): request = ChatCompletionRequest(
request = ChatCompletionRequest( model=MODEL,
model=MODEL3_2, messages=[
messages=[ UserMessage(content=content),
UserMessage(content="hello"), ],
CompletionMessage( tools=[
content="", ToolDefinition(tool_name=BuiltinTool.code_interpreter),
stop_reason=StopReason.end_of_turn, ToolDefinition(tool_name=BuiltinTool.brave_search),
tool_calls=[ ToolDefinition(
ToolCall( tool_name="custom1",
tool_name="custom1", description="custom1 tool",
arguments={"param1": "value1"}, parameters={
call_id="123", "param1": ToolParamDefinition(
) param_type="str",
], description="param1 description",
), required=True,
], ),
tools=[ },
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
),
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
)
prompt = await chat_completion_request_to_prompt(request, request.model)
self.assertIn('[custom1(param1="value1")]', prompt)
request.model = MODEL
request.tool_config.tool_prompt_format = ToolPromptFormat.json
prompt = await chat_completion_request_to_prompt(request, request.model)
self.assertIn(
'{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
prompt,
)
async def test_user_provided_system_message(self):
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 2, messages)
self.assertTrue(messages[0].content.endswith(system_prompt))
self.assertEqual(messages[-1].content, content)
async def test_repalce_system_message_behavior_builtin_tools(self):
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format="python_list",
system_message_behavior="replace",
), ),
) ],
messages = chat_completion_request_to_messages(request, MODEL3_2) )
self.assertEqual(len(messages), 2, messages) messages = chat_completion_request_to_messages(request, MODEL)
self.assertTrue(messages[0].content.endswith(system_prompt)) assert len(messages) == 3
self.assertIn("Environment: ipython", messages[0].content)
self.assertEqual(messages[-1].content, content)
async def test_repalce_system_message_behavior_custom_tools(self): assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
content = "Hello !" assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
system_prompt = "You are a pirate"
request = ChatCompletionRequest( assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
model=MODEL, assert messages[-1].content == content
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content), @pytest.mark.asyncio
], async def test_completion_message_encoding():
tools=[ request = ChatCompletionRequest(
ToolDefinition(tool_name=BuiltinTool.code_interpreter), model=MODEL3_2,
ToolDefinition( messages=[
tool_name="custom1", UserMessage(content="hello"),
description="custom1 tool", CompletionMessage(
parameters={ content="",
"param1": ToolParamDefinition( stop_reason=StopReason.end_of_turn,
param_type="str", tool_calls=[
description="param1 description", ToolCall(
required=True, tool_name="custom1",
), arguments={"param1": "value1"},
}, call_id="123",
), )
], ],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format="python_list",
system_message_behavior="replace",
), ),
) ],
messages = chat_completion_request_to_messages(request, MODEL3_2) tools=[
ToolDefinition(
self.assertEqual(len(messages), 2, messages) tool_name="custom1",
self.assertTrue(messages[0].content.endswith(system_prompt)) description="custom1 tool",
self.assertIn("Environment: ipython", messages[0].content) parameters={
self.assertEqual(messages[-1].content, content) "param1": ToolParamDefinition(
param_type="str",
async def test_replace_system_message_behavior_custom_tools_with_template(self): description="param1 description",
content = "Hello !" required=True,
system_prompt = "You are a pirate {{ function_description }}" ),
request = ChatCompletionRequest( },
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format="python_list",
system_message_behavior="replace",
), ),
) ],
messages = chat_completion_request_to_messages(request, MODEL3_2) tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
)
prompt = await chat_completion_request_to_prompt(request, request.model)
assert '[custom1(param1="value1")]' in prompt
self.assertEqual(len(messages), 2, messages) request.model = MODEL
self.assertIn("Environment: ipython", messages[0].content) request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
self.assertIn("You are a pirate", messages[0].content) prompt = await chat_completion_request_to_prompt(request, request.model)
# function description is present in the system prompt assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
self.assertIn('"name": "custom1"', messages[0].content)
self.assertEqual(messages[-1].content, content)
@pytest.mark.asyncio
async def test_user_provided_system_message():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert messages[-1].content == content
@pytest.mark.asyncio
async def test_replace_system_message_behavior_builtin_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content
@pytest.mark.asyncio
async def test_replace_system_message_behavior_custom_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content
@pytest.mark.asyncio
async def test_replace_system_message_behavior_custom_tools_with_template():
content = "Hello !"
system_prompt = "You are a pirate {{ function_description }}"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "You are a pirate" in interleaved_content_as_str(messages[0].content)
# function description is present in the system prompt
assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content