chore(test): migrate unit tests from unittest to pytest for prompt adapter (#2788)

This PR replaces unittest with pytest.

Part of https://github.com/meta-llama/llama-stack/issues/2680

cc @leseb
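
The change follows one pattern throughout the file: the PrepareMessagesTests(unittest.IsolatedAsyncioTestCase) class and its asyncSetUp hook are removed, each test method becomes a module-level async function decorated with @pytest.mark.asyncio, and self.assert* calls become plain assert statements. A condensed before/after sketch of the pattern (illustrative, not copied verbatim from the diff below):

# before: unittest style (request construction omitted for brevity)
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
    async def test_system_default(self):
        messages = chat_completion_request_to_messages(request, MODEL)
        self.assertEqual(len(messages), 2)

# after: pytest style
@pytest.mark.asyncio
async def test_system_default():
    messages = chat_completion_request_to_messages(request, MODEL)
    assert len(messages) == 2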

Co-authored-by: ehhuang <ehhuang@users.noreply.github.com>
Authored by Mustafa Elbehery on 2025-07-18 01:31:38 +02:00; committed by GitHub
parent 3ae4aeb344
commit bd8a3ae3cc

@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import asyncio
-import unittest
+import pytest
 
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionMessage,
     StopReason,
     SystemMessage,
+    SystemMessageBehavior,
     ToolCall,
     ToolConfig,
     UserMessage,
@@ -25,17 +25,15 @@ from llama_stack.models.llama.datatypes import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
     chat_completion_request_to_prompt,
+    interleaved_content_as_str,
 )
 
 MODEL = "Llama3.1-8B-Instruct"
 MODEL3_2 = "Llama3.2-3B-Instruct"
 
-class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
-    async def asyncSetUp(self):
-        asyncio.get_running_loop().set_debug(False)
-
-    async def test_system_default(self):
+
+@pytest.mark.asyncio
+async def test_system_default():
     content = "Hello !"
     request = ChatCompletionRequest(
         model=MODEL,
@@ -44,11 +42,13 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
     )
     messages = chat_completion_request_to_messages(request, MODEL)
-    self.assertEqual(len(messages), 2)
-    self.assertEqual(messages[-1].content, content)
-    self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
+    assert len(messages) == 2
+    assert messages[-1].content == content
+    assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
 
-    async def test_system_builtin_only(self):
+
+@pytest.mark.asyncio
+async def test_system_builtin_only():
     content = "Hello !"
     request = ChatCompletionRequest(
         model=MODEL,
@@ -61,12 +61,14 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
     )
     messages = chat_completion_request_to_messages(request, MODEL)
-    self.assertEqual(len(messages), 2)
-    self.assertEqual(messages[-1].content, content)
-    self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
-    self.assertTrue("Tools: brave_search" in messages[0].content)
+    assert len(messages) == 2
+    assert messages[-1].content == content
+    assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
+    assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
 
-    async def test_system_custom_only(self):
+
+@pytest.mark.asyncio
+async def test_system_custom_only():
     content = "Hello !"
     request = ChatCompletionRequest(
         model=MODEL,
@@ -89,13 +91,15 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
     )
     messages = chat_completion_request_to_messages(request, MODEL)
-    self.assertEqual(len(messages), 3)
-    self.assertTrue("Environment: ipython" in messages[0].content)
-    self.assertTrue("Return function calls in JSON format" in messages[1].content)
-    self.assertEqual(messages[-1].content, content)
+    assert len(messages) == 3
+    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
+    assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
+    assert messages[-1].content == content
 
-    async def test_system_custom_and_builtin(self):
+
+@pytest.mark.asyncio
+async def test_system_custom_and_builtin():
     content = "Hello !"
     request = ChatCompletionRequest(
         model=MODEL,
@@ -119,15 +123,17 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
     )
     messages = chat_completion_request_to_messages(request, MODEL)
-    self.assertEqual(len(messages), 3)
-    self.assertTrue("Environment: ipython" in messages[0].content)
-    self.assertTrue("Tools: brave_search" in messages[0].content)
-    self.assertTrue("Return function calls in JSON format" in messages[1].content)
-    self.assertEqual(messages[-1].content, content)
+    assert len(messages) == 3
+    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
+    assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
+    assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
+    assert messages[-1].content == content
 
-    async def test_completion_message_encoding(self):
+
+@pytest.mark.asyncio
+async def test_completion_message_encoding():
     request = ChatCompletionRequest(
         model=MODEL3_2,
         messages=[
@@ -160,17 +166,16 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
     )
     prompt = await chat_completion_request_to_prompt(request, request.model)
-    self.assertIn('[custom1(param1="value1")]', prompt)
+    assert '[custom1(param1="value1")]' in prompt
 
     request.model = MODEL
-    request.tool_config.tool_prompt_format = ToolPromptFormat.json
+    request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
     prompt = await chat_completion_request_to_prompt(request, request.model)
-    self.assertIn(
-        '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
-        prompt,
-    )
+    assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
 
-    async def test_user_provided_system_message(self):
+
+@pytest.mark.asyncio
+async def test_user_provided_system_message():
     content = "Hello !"
     system_prompt = "You are a pirate"
     request = ChatCompletionRequest(
@@ -184,12 +189,14 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
     )
     messages = chat_completion_request_to_messages(request, MODEL)
-    self.assertEqual(len(messages), 2, messages)
-    self.assertTrue(messages[0].content.endswith(system_prompt))
-    self.assertEqual(messages[-1].content, content)
+    assert len(messages) == 2
+    assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
+    assert messages[-1].content == content
 
-    async def test_repalce_system_message_behavior_builtin_tools(self):
+
+@pytest.mark.asyncio
+async def test_replace_system_message_behavior_builtin_tools():
     content = "Hello !"
     system_prompt = "You are a pirate"
     request = ChatCompletionRequest(
@@ -203,17 +210,19 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
         tool_config=ToolConfig(
             tool_choice="auto",
-            tool_prompt_format="python_list",
-            system_message_behavior="replace",
+            tool_prompt_format=ToolPromptFormat.python_list,
+            system_message_behavior=SystemMessageBehavior.replace,
         ),
     )
     messages = chat_completion_request_to_messages(request, MODEL3_2)
-    self.assertEqual(len(messages), 2, messages)
-    self.assertTrue(messages[0].content.endswith(system_prompt))
-    self.assertIn("Environment: ipython", messages[0].content)
-    self.assertEqual(messages[-1].content, content)
+    assert len(messages) == 2
+    assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
+    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
+    assert messages[-1].content == content
 
-    async def test_repalce_system_message_behavior_custom_tools(self):
+
+@pytest.mark.asyncio
+async def test_replace_system_message_behavior_custom_tools():
     content = "Hello !"
     system_prompt = "You are a pirate"
     request = ChatCompletionRequest(
@@ -238,18 +247,20 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
         tool_config=ToolConfig(
            tool_choice="auto",
-            tool_prompt_format="python_list",
-            system_message_behavior="replace",
+            tool_prompt_format=ToolPromptFormat.python_list,
+            system_message_behavior=SystemMessageBehavior.replace,
        ),
     )
     messages = chat_completion_request_to_messages(request, MODEL3_2)
-    self.assertEqual(len(messages), 2, messages)
-    self.assertTrue(messages[0].content.endswith(system_prompt))
-    self.assertIn("Environment: ipython", messages[0].content)
-    self.assertEqual(messages[-1].content, content)
+    assert len(messages) == 2
+    assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
+    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
+    assert messages[-1].content == content
 
-    async def test_replace_system_message_behavior_custom_tools_with_template(self):
+
+@pytest.mark.asyncio
+async def test_replace_system_message_behavior_custom_tools_with_template():
     content = "Hello !"
     system_prompt = "You are a pirate {{ function_description }}"
     request = ChatCompletionRequest(
@@ -274,15 +285,15 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
         tool_config=ToolConfig(
             tool_choice="auto",
-            tool_prompt_format="python_list",
-            system_message_behavior="replace",
+            tool_prompt_format=ToolPromptFormat.python_list,
+            system_message_behavior=SystemMessageBehavior.replace,
         ),
     )
     messages = chat_completion_request_to_messages(request, MODEL3_2)
-    self.assertEqual(len(messages), 2, messages)
-    self.assertIn("Environment: ipython", messages[0].content)
-    self.assertIn("You are a pirate", messages[0].content)
+    assert len(messages) == 2
+    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
+    assert "You are a pirate" in interleaved_content_as_str(messages[0].content)
     # function description is present in the system prompt
-    self.assertIn('"name": "custom1"', messages[0].content)
-    self.assertEqual(messages[-1].content, content)
+    assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content)
+    assert messages[-1].content == content
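
Two notes on the new test style: the @pytest.mark.asyncio marker is provided by the pytest-asyncio plugin, which therefore needs to be available in the test environment; and the substring assertions now go through interleaved_content_as_str, since a message's content may be interleaved content (a list of content items) rather than a plain string. A small sketch of that second point, reusing the names from the tests above:

# messages[0].content may be a str or a list of interleaved content items,
# so normalize it to a string before doing substring checks.
messages = chat_completion_request_to_messages(request, MODEL)
system_text = interleaved_content_as_str(messages[0].content)
assert "Cutting Knowledge Date: December 2023" in system_text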