chore(test): migrate unit tests from unittest to pytest for prompt adapter (#2788)

This PR replaces unittest with pytest.

Part of https://github.com/meta-llama/llama-stack/issues/2680

cc @leseb

Co-authored-by: ehhuang <ehhuang@users.noreply.github.com>
This commit is contained in:
Mustafa Elbehery 2025-07-18 01:31:38 +02:00 committed by GitHub
parent 3ae4aeb344
commit bd8a3ae3cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import unittest
import pytest
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionMessage,
StopReason,
SystemMessage,
SystemMessageBehavior,
ToolCall,
ToolConfig,
UserMessage,
@ -25,17 +25,15 @@ from llama_stack.models.llama.datatypes import (
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
chat_completion_request_to_prompt,
interleaved_content_as_str,
)
MODEL = "Llama3.1-8B-Instruct"
MODEL3_2 = "Llama3.2-3B-Instruct"
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
asyncio.get_running_loop().set_debug(False)
async def test_system_default(self):
@pytest.mark.asyncio
async def test_system_default():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
@ -44,11 +42,13 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
],
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 2)
self.assertEqual(messages[-1].content, content)
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
assert len(messages) == 2
assert messages[-1].content == content
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
async def test_system_builtin_only(self):
@pytest.mark.asyncio
async def test_system_builtin_only():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
@ -61,12 +61,14 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
],
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 2)
self.assertEqual(messages[-1].content, content)
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
self.assertTrue("Tools: brave_search" in messages[0].content)
assert len(messages) == 2
assert messages[-1].content == content
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
async def test_system_custom_only(self):
@pytest.mark.asyncio
async def test_system_custom_only():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
@ -89,13 +91,15 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 3)
self.assertTrue("Environment: ipython" in messages[0].content)
assert len(messages) == 3
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
self.assertTrue("Return function calls in JSON format" in messages[1].content)
self.assertEqual(messages[-1].content, content)
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
assert messages[-1].content == content
async def test_system_custom_and_builtin(self):
@pytest.mark.asyncio
async def test_system_custom_and_builtin():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
@ -119,15 +123,17 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
],
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 3)
assert len(messages) == 3
self.assertTrue("Environment: ipython" in messages[0].content)
self.assertTrue("Tools: brave_search" in messages[0].content)
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
self.assertTrue("Return function calls in JSON format" in messages[1].content)
self.assertEqual(messages[-1].content, content)
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
assert messages[-1].content == content
async def test_completion_message_encoding(self):
@pytest.mark.asyncio
async def test_completion_message_encoding():
request = ChatCompletionRequest(
model=MODEL3_2,
messages=[
@ -160,17 +166,16 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
)
prompt = await chat_completion_request_to_prompt(request, request.model)
self.assertIn('[custom1(param1="value1")]', prompt)
assert '[custom1(param1="value1")]' in prompt
request.model = MODEL
request.tool_config.tool_prompt_format = ToolPromptFormat.json
request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
prompt = await chat_completion_request_to_prompt(request, request.model)
self.assertIn(
'{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
prompt,
)
assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
async def test_user_provided_system_message(self):
@pytest.mark.asyncio
async def test_user_provided_system_message():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
@ -184,12 +189,14 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
],
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 2, messages)
self.assertTrue(messages[0].content.endswith(system_prompt))
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
self.assertEqual(messages[-1].content, content)
assert messages[-1].content == content
async def test_repalce_system_message_behavior_builtin_tools(self):
@pytest.mark.asyncio
async def test_replace_system_message_behavior_builtin_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
@ -203,17 +210,19 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format="python_list",
system_message_behavior="replace",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
self.assertEqual(len(messages), 2, messages)
self.assertTrue(messages[0].content.endswith(system_prompt))
self.assertIn("Environment: ipython", messages[0].content)
self.assertEqual(messages[-1].content, content)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content
async def test_repalce_system_message_behavior_custom_tools(self):
@pytest.mark.asyncio
async def test_replace_system_message_behavior_custom_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
@ -238,18 +247,20 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format="python_list",
system_message_behavior="replace",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
self.assertEqual(len(messages), 2, messages)
self.assertTrue(messages[0].content.endswith(system_prompt))
self.assertIn("Environment: ipython", messages[0].content)
self.assertEqual(messages[-1].content, content)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content
async def test_replace_system_message_behavior_custom_tools_with_template(self):
@pytest.mark.asyncio
async def test_replace_system_message_behavior_custom_tools_with_template():
content = "Hello !"
system_prompt = "You are a pirate {{ function_description }}"
request = ChatCompletionRequest(
@ -274,15 +285,15 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format="python_list",
system_message_behavior="replace",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
self.assertEqual(len(messages), 2, messages)
self.assertIn("Environment: ipython", messages[0].content)
self.assertIn("You are a pirate", messages[0].content)
assert len(messages) == 2
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "You are a pirate" in interleaved_content_as_str(messages[0].content)
# function description is present in the system prompt
self.assertIn('"name": "custom1"', messages[0].content)
self.assertEqual(messages[-1].content, content)
assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content