test(models): convert decode_assistant_message test from unittest to pytest

- Convert test classes to pytest functions and fixtures
- Replace unittest assertions with pytest assertions
- Use pytest.mark.parametrize for parameterized tests
- Remove unittest.TestCase inheritance and setUp methods
- Maintain all test functionality and coverage

Addresses review feedback from PR #2663
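
The conversion pattern behind the bullet points above, as a minimal illustrative sketch (generic names such as `numbers` and `test_scaled_sum` are made up for this example and are not taken from this commit): a `setUp` method becomes a `@pytest.fixture`, `self.assert*` calls become plain `assert` statements, and `subTest` loops become `@pytest.mark.parametrize`.

    # Illustrative sketch only -- not code from this commit.
    import pytest


    @pytest.fixture
    def numbers():
        """Plays the role of setUp(): provides shared test data."""
        return [1, 2, 3]


    @pytest.mark.parametrize("factor,expected", [(1, 6), (2, 12)])
    def test_scaled_sum(numbers, factor, expected):
        """A plain function with pytest asserts replaces a TestCase method with subTest loops."""
        assert sum(n * factor for n in numbers) == expected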
skamenan7 2025-07-09 08:57:48 -04:00
parent c37d831911
commit bdf251b870


@@ -4,48 +4,49 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import unittest
 from unittest.mock import Mock, patch
 
-from llama_stack.apis.inference import StopReason
+import pytest
+
+from llama_stack.apis.inference import StopReason, ToolCall
 from llama_stack.models.llama.datatypes import RawMessage
 from llama_stack.models.llama.sku_types import ModelFamily
 from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
 
 
-class TestDecodeAssistantMessage(unittest.TestCase):
-    """Test the decode_assistant_message function with different model types and formats."""
-
-    def setUp(self):
-        """Set up test fixtures."""
-        self.llama3_content = """I'll help you get the weather information.
-{"type": "function", "name": "get_weather", "parameters": {"location": "San Francisco, CA"}}"""
-        self.llama4_content = """I'll help you get the weather information.
-[get_weather(location="San Francisco, CA")]"""
-        self.simple_content = "Hello! How can I help you today?"
-
-    def test_decode_with_no_model_id_defaults_to_llama3(self):
+@pytest.fixture
+def test_content():
+    """Test content fixtures."""
+    return {
+        "llama3_content": """I'll help you get the weather information.
+{"type": "function", "name": "get_weather", "parameters": {"location": "San Francisco, CA"}}""",
+        "llama4_content": """I'll help you get the weather information.
+[get_weather(location="San Francisco, CA")]""",
+        "simple_content": "Hello! How can I help you today?",
+    }
+
+
+def test_decode_with_no_model_id_defaults_to_llama3(test_content):
     """Test that decode_assistant_message defaults to Llama3 format when no model_id is provided."""
     with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
         mock_formatter = Mock()
         mock_chat_format.return_value = mock_formatter
-        expected_message = RawMessage(role="assistant", content=self.simple_content)
+        expected_message = RawMessage(role="assistant", content=test_content["simple_content"])
         mock_formatter.decode_assistant_message_from_content.return_value = expected_message
 
-        result = decode_assistant_message(self.simple_content, StopReason.end_of_turn)
+        result = decode_assistant_message(test_content["simple_content"], StopReason.end_of_turn)
 
         mock_chat_format.assert_called_once()
         mock_formatter.decode_assistant_message_from_content.assert_called_once_with(
-            self.simple_content, StopReason.end_of_turn
+            test_content["simple_content"], StopReason.end_of_turn
         )
-        self.assertEqual(result, expected_message)
+        assert result == expected_message
 
 
-    @patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
-    def test_decode_with_nonexistent_model_uses_llama3(self):
+@patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
+def test_decode_with_nonexistent_model_uses_llama3(test_content):
     """Test that decode_assistant_message uses Llama3 format for non-existent models."""
     with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
         mock_resolve.return_value = None
@@ -53,17 +54,20 @@ class TestDecodeAssistantMessage(unittest.TestCase):
         with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
             mock_formatter = Mock()
             mock_chat_format.return_value = mock_formatter
-            expected_message = RawMessage(role="assistant", content=self.simple_content)
+            expected_message = RawMessage(role="assistant", content=test_content["simple_content"])
             mock_formatter.decode_assistant_message_from_content.return_value = expected_message
 
-            result = decode_assistant_message(self.simple_content, StopReason.end_of_turn, "nonexistent-model")
+            result = decode_assistant_message(
+                test_content["simple_content"], StopReason.end_of_turn, "nonexistent-model"
+            )
 
             mock_resolve.assert_called_once_with("nonexistent-model")
             mock_chat_format.assert_called_once()
-            self.assertEqual(result, expected_message)
+            assert result == expected_message
 
 
-    @patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
-    def test_decode_with_llama3_model_uses_llama3_format(self):
+@patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
+def test_decode_with_llama3_model_uses_llama3_format(test_content):
     """Test that decode_assistant_message uses Llama3 format for Llama3 models."""
     mock_model = Mock()
     mock_model.model_family = ModelFamily.llama3
@@ -74,19 +78,20 @@ class TestDecodeAssistantMessage(unittest.TestCase):
         with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
             mock_formatter = Mock()
             mock_chat_format.return_value = mock_formatter
-            expected_message = RawMessage(role="assistant", content=self.llama3_content)
+            expected_message = RawMessage(role="assistant", content=test_content["llama3_content"])
             mock_formatter.decode_assistant_message_from_content.return_value = expected_message
 
             result = decode_assistant_message(
-                self.llama3_content, StopReason.end_of_turn, "meta-llama/Llama-3.1-8B-Instruct"
+                test_content["llama3_content"], StopReason.end_of_turn, "meta-llama/Llama-3.1-8B-Instruct"
             )
 
             mock_resolve.assert_called_once_with("meta-llama/Llama-3.1-8B-Instruct")
             mock_chat_format.assert_called_once()
-            self.assertEqual(result, expected_message)
+            assert result == expected_message
 
 
-    @patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
-    def test_decode_with_llama4_model_uses_llama4_format(self):
+@patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
+def test_decode_with_llama4_model_uses_llama4_format(test_content):
     """Test that decode_assistant_message uses Llama4 format for Llama4 models when available."""
     mock_model = Mock()
     mock_model.model_family = ModelFamily.llama4
@@ -106,19 +111,20 @@ class TestDecodeAssistantMessage(unittest.TestCase):
             mock_formatter = Mock()
             mock_llama4_format.return_value = mock_formatter
-            expected_message = RawMessage(role="assistant", content=self.llama4_content)
+            expected_message = RawMessage(role="assistant", content=test_content["llama4_content"])
             mock_formatter.decode_assistant_message_from_content.return_value = expected_message
 
             result = decode_assistant_message(
-                self.llama4_content, StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
+                test_content["llama4_content"], StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
            )
 
             mock_resolve.assert_called_once_with("meta-llama/Llama-4-8B-Instruct")
             mock_llama4_format.assert_called_once_with(mock_tokenizer_instance)
-            self.assertEqual(result, expected_message)
+            assert result == expected_message
 
 
-    @patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", False)
-    def test_decode_with_llama4_model_falls_back_to_llama3_when_unavailable(self):
+@patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", False)
+def test_decode_with_llama4_model_falls_back_to_llama3_when_unavailable(test_content):
     """Test that decode_assistant_message falls back to Llama3 format when Llama4 is unavailable."""
     mock_model = Mock()
     mock_model.model_family = ModelFamily.llama4
@@ -129,45 +135,45 @@ class TestDecodeAssistantMessage(unittest.TestCase):
         with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
             mock_formatter = Mock()
             mock_chat_format.return_value = mock_formatter
-            expected_message = RawMessage(role="assistant", content=self.llama4_content)
+            expected_message = RawMessage(role="assistant", content=test_content["llama4_content"])
             mock_formatter.decode_assistant_message_from_content.return_value = expected_message
 
             result = decode_assistant_message(
-                self.llama4_content, StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
+                test_content["llama4_content"], StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
             )
 
             # Should NOT call resolve_model since LLAMA4_AVAILABLE is False
             mock_resolve.assert_not_called()
             mock_chat_format.assert_called_once()
-            self.assertEqual(result, expected_message)
+            assert result == expected_message
 
 
-    def test_decode_with_different_stop_reasons(self):
-        """Test that decode_assistant_message handles different stop reasons correctly."""
-        stop_reasons = [
-            StopReason.end_of_turn,
-            StopReason.end_of_message,
-            StopReason.out_of_tokens,
-        ]
-
-        for stop_reason in stop_reasons:
-            with self.subTest(stop_reason=stop_reason):
-                with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
-                    mock_formatter = Mock()
-                    mock_chat_format.return_value = mock_formatter
-                    expected_message = RawMessage(role="assistant", content=self.simple_content)
-                    mock_formatter.decode_assistant_message_from_content.return_value = expected_message
-
-                    result = decode_assistant_message(self.simple_content, stop_reason)
-
-                    mock_formatter.decode_assistant_message_from_content.assert_called_once_with(
-                        self.simple_content, stop_reason
-                    )
-                    self.assertEqual(result, expected_message)
+@pytest.mark.parametrize(
+    "stop_reason",
+    [
+        StopReason.end_of_turn,
+        StopReason.end_of_message,
+        StopReason.out_of_tokens,
+    ],
+)
+def test_decode_with_different_stop_reasons(test_content, stop_reason):
+    """Test that decode_assistant_message handles different stop reasons correctly."""
+    with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
+        mock_formatter = Mock()
+        mock_chat_format.return_value = mock_formatter
+        expected_message = RawMessage(role="assistant", content=test_content["simple_content"])
+        mock_formatter.decode_assistant_message_from_content.return_value = expected_message
+
+        result = decode_assistant_message(test_content["simple_content"], stop_reason)
+
+        mock_formatter.decode_assistant_message_from_content.assert_called_once_with(
+            test_content["simple_content"], stop_reason
+        )
+        assert result == expected_message
 
 
-    def test_decode_preserves_formatter_response(self):
-        """Test that decode_assistant_message preserves the formatter's response including tool calls."""
-        from llama_stack.apis.inference import ToolCall
-
-        mock_tool_call = ToolCall(
-            tool_name="get_weather", arguments={"location": "San Francisco, CA"}, call_id="test_call_id"
-        )
+def test_decode_preserves_formatter_response(test_content):
+    """Test that decode_assistant_message preserves the formatter's response including tool calls."""
+    mock_tool_call = ToolCall(
+        tool_name="get_weather", arguments={"location": "San Francisco, CA"}, call_id="test_call_id"
+    )
@@ -180,27 +186,23 @@ class TestDecodeAssistantMessage(unittest.TestCase):
         )
         mock_formatter.decode_assistant_message_from_content.return_value = expected_message
 
-            result = decode_assistant_message(self.llama3_content, StopReason.end_of_turn)
+        result = decode_assistant_message(test_content["llama3_content"], StopReason.end_of_turn)
 
-            self.assertEqual(result, expected_message)
-            self.assertEqual(len(result.tool_calls), 1)
-            self.assertEqual(result.tool_calls[0].tool_name, "get_weather")
+        assert result == expected_message
+        assert len(result.tool_calls) == 1
+        assert result.tool_calls[0].tool_name == "get_weather"
 
 
-class TestDecodeAssistantMessageIntegration(unittest.TestCase):
-    """Integration tests for decode_assistant_message with real model resolution."""
-
-    def test_model_resolution_integration(self):
-        """Test that model resolution works correctly with actual model IDs."""
-        # Test with actual model IDs that should resolve
-        test_cases = [
-            ("meta-llama/Llama-3.1-8B-Instruct", "should resolve to Llama3"),
-            ("meta-llama/Llama-4-8B-Instruct", "should resolve to Llama4 if available"),
-            ("invalid-model-id", "should fallback to Llama3"),
-        ]
-
-        for model_id, description in test_cases:
-            with self.subTest(model_id=model_id, description=description):
-                with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
-                    mock_formatter = Mock()
-                    mock_chat_format.return_value = mock_formatter
+@pytest.mark.parametrize(
+    "model_id,description",
+    [
+        ("meta-llama/Llama-3.1-8B-Instruct", "should resolve to Llama3"),
+        ("meta-llama/Llama-4-8B-Instruct", "should resolve to Llama4 if available"),
+        ("invalid-model-id", "should fallback to Llama3"),
+    ],
+)
+def test_model_resolution_integration(model_id, description):
+    """Test that model resolution works correctly with actual model IDs."""
+    with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
+        mock_formatter = Mock()
+        mock_chat_format.return_value = mock_formatter
@@ -210,8 +212,4 @@ class TestDecodeAssistantMessageIntegration(unittest.TestCase):
         # This should not raise an exception
         result = decode_assistant_message("Test content", StopReason.end_of_turn, model_id)
 
-                self.assertEqual(result, expected_message)
-
-
-if __name__ == "__main__":
-    unittest.main()
+        assert result == expected_message