test(models): convert decode_assistant_message test from unittest to pytest

- Convert test classes to pytest functions and fixtures
- Replace unittest assertions with pytest assertions
- Use pytest.mark.parametrize for parameterized tests
- Remove unittest.TestCase inheritance and setUp methods
- Maintain all test functionality and coverage (a minimal sketch of the general pattern follows below)

Addresses review feedback from PR #2663
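
For reviewers unfamiliar with the pattern, here is a minimal before/after sketch of the conversion style described above. It is a toy example with a made-up greet() function, not the actual decode_assistant_message tests:

import pytest

# unittest style (before): a TestCase with setUp state, a subTest loop, and assertEqual.
# pytest style (after): module-level functions, a fixture for shared state,
# parametrize for the loop, and plain assert statements.


def greet(name: str, punctuation: str = "!") -> str:
    # Stand-in for the code under test in this sketch.
    return f"hello {name}{punctuation}"


@pytest.fixture
def name() -> str:
    # Replaces unittest's setUp(): any test that names this argument gets the value injected.
    return "world"


@pytest.mark.parametrize("punctuation", ["!", "?"])
def test_greets(name, punctuation):
    # Replaces a for-loop with self.subTest(); each parameter becomes its own test case,
    # and plain assert replaces self.assertEqual().
    assert greet(name, punctuation) == f"hello world{punctuation}"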
skamenan7 2025-07-09 08:57:48 -04:00
parent 857496ea3e
commit 126d6698a7

@@ -4,48 +4,49 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import unittest
 from unittest.mock import Mock, patch
-from llama_stack.apis.inference import StopReason
+import pytest
+from llama_stack.apis.inference import StopReason, ToolCall
 from llama_stack.models.llama.datatypes import RawMessage
 from llama_stack.models.llama.sku_types import ModelFamily
 from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
-class TestDecodeAssistantMessage(unittest.TestCase):
-    """Test the decode_assistant_message function with different model types and formats."""
-    def setUp(self):
-        """Set up test fixtures."""
-        self.llama3_content = """I'll help you get the weather information.
-{"type": "function", "name": "get_weather", "parameters": {"location": "San Francisco, CA"}}"""
-        self.llama4_content = """I'll help you get the weather information.
-[get_weather(location="San Francisco, CA")]"""
-        self.simple_content = "Hello! How can I help you today?"
+@pytest.fixture
+def test_content():
+    """Test content fixtures."""
+    return {
+        "llama3_content": """I'll help you get the weather information.
+{"type": "function", "name": "get_weather", "parameters": {"location": "San Francisco, CA"}}""",
+        "llama4_content": """I'll help you get the weather information.
+[get_weather(location="San Francisco, CA")]""",
+        "simple_content": "Hello! How can I help you today?",
+    }
-def test_decode_with_no_model_id_defaults_to_llama3(self):
+def test_decode_with_no_model_id_defaults_to_llama3(test_content):
     """Test that decode_assistant_message defaults to Llama3 format when no model_id is provided."""
     with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
         mock_formatter = Mock()
         mock_chat_format.return_value = mock_formatter
-        expected_message = RawMessage(role="assistant", content=self.simple_content)
+        expected_message = RawMessage(role="assistant", content=test_content["simple_content"])
         mock_formatter.decode_assistant_message_from_content.return_value = expected_message
-        result = decode_assistant_message(self.simple_content, StopReason.end_of_turn)
+        result = decode_assistant_message(test_content["simple_content"], StopReason.end_of_turn)
         mock_chat_format.assert_called_once()
         mock_formatter.decode_assistant_message_from_content.assert_called_once_with(
-            self.simple_content, StopReason.end_of_turn
+            test_content["simple_content"], StopReason.end_of_turn
         )
-        self.assertEqual(result, expected_message)
+        assert result == expected_message
 @patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
-def test_decode_with_nonexistent_model_uses_llama3(self):
+def test_decode_with_nonexistent_model_uses_llama3(test_content):
     """Test that decode_assistant_message uses Llama3 format for non-existent models."""
     with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
         mock_resolve.return_value = None
@@ -53,17 +54,20 @@ class TestDecodeAssistantMessage(unittest.TestCase):
         with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
             mock_formatter = Mock()
             mock_chat_format.return_value = mock_formatter
-            expected_message = RawMessage(role="assistant", content=self.simple_content)
+            expected_message = RawMessage(role="assistant", content=test_content["simple_content"])
             mock_formatter.decode_assistant_message_from_content.return_value = expected_message
-            result = decode_assistant_message(self.simple_content, StopReason.end_of_turn, "nonexistent-model")
+            result = decode_assistant_message(
+                test_content["simple_content"], StopReason.end_of_turn, "nonexistent-model"
+            )
             mock_resolve.assert_called_once_with("nonexistent-model")
             mock_chat_format.assert_called_once()
-            self.assertEqual(result, expected_message)
+            assert result == expected_message
 @patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
-def test_decode_with_llama3_model_uses_llama3_format(self):
+def test_decode_with_llama3_model_uses_llama3_format(test_content):
     """Test that decode_assistant_message uses Llama3 format for Llama3 models."""
     mock_model = Mock()
     mock_model.model_family = ModelFamily.llama3
@@ -74,19 +78,20 @@ class TestDecodeAssistantMessage(unittest.TestCase):
         with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
             mock_formatter = Mock()
             mock_chat_format.return_value = mock_formatter
-            expected_message = RawMessage(role="assistant", content=self.llama3_content)
+            expected_message = RawMessage(role="assistant", content=test_content["llama3_content"])
             mock_formatter.decode_assistant_message_from_content.return_value = expected_message
             result = decode_assistant_message(
-                self.llama3_content, StopReason.end_of_turn, "meta-llama/Llama-3.1-8B-Instruct"
+                test_content["llama3_content"], StopReason.end_of_turn, "meta-llama/Llama-3.1-8B-Instruct"
             )
             mock_resolve.assert_called_once_with("meta-llama/Llama-3.1-8B-Instruct")
             mock_chat_format.assert_called_once()
-            self.assertEqual(result, expected_message)
+            assert result == expected_message
 @patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
-def test_decode_with_llama4_model_uses_llama4_format(self):
+def test_decode_with_llama4_model_uses_llama4_format(test_content):
     """Test that decode_assistant_message uses Llama4 format for Llama4 models when available."""
     mock_model = Mock()
     mock_model.model_family = ModelFamily.llama4
@@ -106,19 +111,20 @@ class TestDecodeAssistantMessage(unittest.TestCase):
                 mock_formatter = Mock()
                 mock_llama4_format.return_value = mock_formatter
-                expected_message = RawMessage(role="assistant", content=self.llama4_content)
+                expected_message = RawMessage(role="assistant", content=test_content["llama4_content"])
                 mock_formatter.decode_assistant_message_from_content.return_value = expected_message
                 result = decode_assistant_message(
-                    self.llama4_content, StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
+                    test_content["llama4_content"], StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
                 )
                 mock_resolve.assert_called_once_with("meta-llama/Llama-4-8B-Instruct")
                 mock_llama4_format.assert_called_once_with(mock_tokenizer_instance)
-                self.assertEqual(result, expected_message)
+                assert result == expected_message
 @patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", False)
-def test_decode_with_llama4_model_falls_back_to_llama3_when_unavailable(self):
+def test_decode_with_llama4_model_falls_back_to_llama3_when_unavailable(test_content):
     """Test that decode_assistant_message falls back to Llama3 format when Llama4 is unavailable."""
     mock_model = Mock()
     mock_model.model_family = ModelFamily.llama4
@@ -129,45 +135,45 @@ class TestDecodeAssistantMessage(unittest.TestCase):
         with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
             mock_formatter = Mock()
             mock_chat_format.return_value = mock_formatter
-            expected_message = RawMessage(role="assistant", content=self.llama4_content)
+            expected_message = RawMessage(role="assistant", content=test_content["llama4_content"])
             mock_formatter.decode_assistant_message_from_content.return_value = expected_message
             result = decode_assistant_message(
-                self.llama4_content, StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
+                test_content["llama4_content"], StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
             )
             # Should NOT call resolve_model since LLAMA4_AVAILABLE is False
             mock_resolve.assert_not_called()
             mock_chat_format.assert_called_once()
-            self.assertEqual(result, expected_message)
+            assert result == expected_message
-    def test_decode_with_different_stop_reasons(self):
-        """Test that decode_assistant_message handles different stop reasons correctly."""
-        stop_reasons = [
+@pytest.mark.parametrize(
+    "stop_reason",
+    [
         StopReason.end_of_turn,
         StopReason.end_of_message,
         StopReason.out_of_tokens,
-        ]
-        for stop_reason in stop_reasons:
-            with self.subTest(stop_reason=stop_reason):
+    ],
+)
+def test_decode_with_different_stop_reasons(test_content, stop_reason):
+    """Test that decode_assistant_message handles different stop reasons correctly."""
     with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
         mock_formatter = Mock()
         mock_chat_format.return_value = mock_formatter
-        expected_message = RawMessage(role="assistant", content=self.simple_content)
+        expected_message = RawMessage(role="assistant", content=test_content["simple_content"])
         mock_formatter.decode_assistant_message_from_content.return_value = expected_message
-        result = decode_assistant_message(self.simple_content, stop_reason)
+        result = decode_assistant_message(test_content["simple_content"], stop_reason)
         mock_formatter.decode_assistant_message_from_content.assert_called_once_with(
-            self.simple_content, stop_reason
+            test_content["simple_content"], stop_reason
         )
-        self.assertEqual(result, expected_message)
+        assert result == expected_message
-    def test_decode_preserves_formatter_response(self):
+def test_decode_preserves_formatter_response(test_content):
     """Test that decode_assistant_message preserves the formatter's response including tool calls."""
-        from llama_stack.apis.inference import ToolCall
     mock_tool_call = ToolCall(
         tool_name="get_weather", arguments={"location": "San Francisco, CA"}, call_id="test_call_id"
     )
@@ -180,27 +186,23 @@ class TestDecodeAssistantMessage(unittest.TestCase):
         )
         mock_formatter.decode_assistant_message_from_content.return_value = expected_message
-        result = decode_assistant_message(self.llama3_content, StopReason.end_of_turn)
+        result = decode_assistant_message(test_content["llama3_content"], StopReason.end_of_turn)
-        self.assertEqual(result, expected_message)
-        self.assertEqual(len(result.tool_calls), 1)
-        self.assertEqual(result.tool_calls[0].tool_name, "get_weather")
+        assert result == expected_message
+        assert len(result.tool_calls) == 1
+        assert result.tool_calls[0].tool_name == "get_weather"
-class TestDecodeAssistantMessageIntegration(unittest.TestCase):
-    """Integration tests for decode_assistant_message with real model resolution."""
-    def test_model_resolution_integration(self):
-        """Test that model resolution works correctly with actual model IDs."""
-        # Test with actual model IDs that should resolve
-        test_cases = [
+@pytest.mark.parametrize(
+    "model_id,description",
+    [
         ("meta-llama/Llama-3.1-8B-Instruct", "should resolve to Llama3"),
         ("meta-llama/Llama-4-8B-Instruct", "should resolve to Llama4 if available"),
         ("invalid-model-id", "should fallback to Llama3"),
-        ]
-        for model_id, description in test_cases:
-            with self.subTest(model_id=model_id, description=description):
+    ],
+)
+def test_model_resolution_integration(model_id, description):
+    """Test that model resolution works correctly with actual model IDs."""
     with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
         mock_formatter = Mock()
         mock_chat_format.return_value = mock_formatter
@@ -210,8 +212,4 @@ class TestDecodeAssistantMessageIntegration(unittest.TestCase):
         # This should not raise an exception
         result = decode_assistant_message("Test content", StopReason.end_of_turn, model_id)
-        self.assertEqual(result, expected_message)
+        assert result == expected_message
-if __name__ == "__main__":
-    unittest.main()