test(models): convert decode_assistant_message test from unittest to pytest

- Convert test classes to pytest functions and fixtures
- Replace unittest assertions with pytest assertions
- Use pytest.mark.parametrize for parameterized tests
- Remove unittest.TestCase inheritance and setUp methods
- Maintain all test functionality and coverage

Addresses review feedback from PR #2663
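The conversion follows the usual unittest-to-pytest mapping; a rough sketch of the pattern (illustrative names only, not the actual tests in this diff):

    import unittest

    import pytest


    # Before: TestCase with setUp state and subTest loops
    class TestExample(unittest.TestCase):
        def setUp(self):
            self.content = "Hello!"

        def test_lengths(self):
            for n in (1, 2, 3):
                with self.subTest(n=n):
                    self.assertEqual(len(self.content * n), n * len(self.content))


    # After: plain functions, a fixture for shared state, parametrize for the loop
    @pytest.fixture
    def content():
        return "Hello!"


    @pytest.mark.parametrize("n", [1, 2, 3])
    def test_lengths(content, n):
        assert len(content * n) == n * len(content)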
skamenan7 2025-07-09 08:57:48 -04:00
parent 857496ea3e
commit 126d6698a7

@@ -4,214 +4,212 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import Mock, patch

import pytest

from llama_stack.apis.inference import StopReason, ToolCall
from llama_stack.models.llama.datatypes import RawMessage
from llama_stack.models.llama.sku_types import ModelFamily
from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message


@pytest.fixture
def test_content():
    """Test content fixtures."""
    return {
        "llama3_content": """I'll help you get the weather information.
{"type": "function", "name": "get_weather", "parameters": {"location": "San Francisco, CA"}}""",
        "llama4_content": """I'll help you get the weather information.
[get_weather(location="San Francisco, CA")]""",
        "simple_content": "Hello! How can I help you today?",
    }


def test_decode_with_no_model_id_defaults_to_llama3(test_content):
    """Test that decode_assistant_message defaults to Llama3 format when no model_id is provided."""
    with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
        mock_formatter = Mock()
        mock_chat_format.return_value = mock_formatter
        expected_message = RawMessage(role="assistant", content=test_content["simple_content"])
        mock_formatter.decode_assistant_message_from_content.return_value = expected_message

        result = decode_assistant_message(test_content["simple_content"], StopReason.end_of_turn)

        mock_chat_format.assert_called_once()
        mock_formatter.decode_assistant_message_from_content.assert_called_once_with(
            test_content["simple_content"], StopReason.end_of_turn
        )
        assert result == expected_message
@patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
def test_decode_with_nonexistent_model_uses_llama3(test_content):
"""Test that decode_assistant_message uses Llama3 format for non-existent models."""
with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
mock_resolve.return_value = None
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format: with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
mock_formatter = Mock() mock_formatter = Mock()
mock_chat_format.return_value = mock_formatter mock_chat_format.return_value = mock_formatter
expected_message = RawMessage( expected_message = RawMessage(role="assistant", content=test_content["simple_content"])
role="assistant", content="I'll help you get the weather.", tool_calls=[mock_tool_call]
)
mock_formatter.decode_assistant_message_from_content.return_value = expected_message mock_formatter.decode_assistant_message_from_content.return_value = expected_message
result = decode_assistant_message(self.llama3_content, StopReason.end_of_turn) result = decode_assistant_message(
test_content["simple_content"], StopReason.end_of_turn, "nonexistent-model"
)
self.assertEqual(result, expected_message) mock_resolve.assert_called_once_with("nonexistent-model")
self.assertEqual(len(result.tool_calls), 1) mock_chat_format.assert_called_once()
self.assertEqual(result.tool_calls[0].tool_name, "get_weather") assert result == expected_message
class TestDecodeAssistantMessageIntegration(unittest.TestCase): @patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
"""Integration tests for decode_assistant_message with real model resolution.""" def test_decode_with_llama3_model_uses_llama3_format(test_content):
"""Test that decode_assistant_message uses Llama3 format for Llama3 models."""
mock_model = Mock()
mock_model.model_family = ModelFamily.llama3
def test_model_resolution_integration(self): with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
"""Test that model resolution works correctly with actual model IDs.""" mock_resolve.return_value = mock_model
# Test with actual model IDs that should resolve
test_cases = [
("meta-llama/Llama-3.1-8B-Instruct", "should resolve to Llama3"),
("meta-llama/Llama-4-8B-Instruct", "should resolve to Llama4 if available"),
("invalid-model-id", "should fallback to Llama3"),
]
for model_id, description in test_cases: with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
with self.subTest(model_id=model_id, description=description): mock_formatter = Mock()
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format: mock_chat_format.return_value = mock_formatter
mock_formatter = Mock() expected_message = RawMessage(role="assistant", content=test_content["llama3_content"])
mock_chat_format.return_value = mock_formatter mock_formatter.decode_assistant_message_from_content.return_value = expected_message
expected_message = RawMessage(role="assistant", content="Test content")
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
# This should not raise an exception result = decode_assistant_message(
result = decode_assistant_message("Test content", StopReason.end_of_turn, model_id) test_content["llama3_content"], StopReason.end_of_turn, "meta-llama/Llama-3.1-8B-Instruct"
)
self.assertEqual(result, expected_message) mock_resolve.assert_called_once_with("meta-llama/Llama-3.1-8B-Instruct")
mock_chat_format.assert_called_once()
assert result == expected_message
if __name__ == "__main__": @patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
unittest.main() def test_decode_with_llama4_model_uses_llama4_format(test_content):
"""Test that decode_assistant_message uses Llama4 format for Llama4 models when available."""
mock_model = Mock()
mock_model.model_family = ModelFamily.llama4
with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
mock_resolve.return_value = mock_model
# Mock the Llama4 components
with patch(
"llama_stack.providers.utils.inference.prompt_adapter.Llama4ChatFormat", create=True
) as mock_llama4_format:
with patch(
"llama_stack.providers.utils.inference.prompt_adapter.Llama4Tokenizer", create=True
) as mock_llama4_tokenizer:
mock_tokenizer_instance = Mock()
mock_llama4_tokenizer.get_instance.return_value = mock_tokenizer_instance
mock_formatter = Mock()
mock_llama4_format.return_value = mock_formatter
expected_message = RawMessage(role="assistant", content=test_content["llama4_content"])
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
result = decode_assistant_message(
test_content["llama4_content"], StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
)
mock_resolve.assert_called_once_with("meta-llama/Llama-4-8B-Instruct")
mock_llama4_format.assert_called_once_with(mock_tokenizer_instance)
assert result == expected_message
@patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", False)
def test_decode_with_llama4_model_falls_back_to_llama3_when_unavailable(test_content):
"""Test that decode_assistant_message falls back to Llama3 format when Llama4 is unavailable."""
mock_model = Mock()
mock_model.model_family = ModelFamily.llama4
with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
mock_resolve.return_value = mock_model
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
mock_formatter = Mock()
mock_chat_format.return_value = mock_formatter
expected_message = RawMessage(role="assistant", content=test_content["llama4_content"])
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
result = decode_assistant_message(
test_content["llama4_content"], StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
)
# Should NOT call resolve_model since LLAMA4_AVAILABLE is False
mock_resolve.assert_not_called()
mock_chat_format.assert_called_once()
assert result == expected_message
@pytest.mark.parametrize(
    "stop_reason",
    [
        StopReason.end_of_turn,
        StopReason.end_of_message,
        StopReason.out_of_tokens,
    ],
)
def test_decode_with_different_stop_reasons(test_content, stop_reason):
    """Test that decode_assistant_message handles different stop reasons correctly."""
    with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
        mock_formatter = Mock()
        mock_chat_format.return_value = mock_formatter
        expected_message = RawMessage(role="assistant", content=test_content["simple_content"])
        mock_formatter.decode_assistant_message_from_content.return_value = expected_message

        result = decode_assistant_message(test_content["simple_content"], stop_reason)

        mock_formatter.decode_assistant_message_from_content.assert_called_once_with(
            test_content["simple_content"], stop_reason
        )
        assert result == expected_message
def test_decode_preserves_formatter_response(test_content):
    """Test that decode_assistant_message preserves the formatter's response including tool calls."""
    mock_tool_call = ToolCall(
        tool_name="get_weather", arguments={"location": "San Francisco, CA"}, call_id="test_call_id"
    )

    with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
        mock_formatter = Mock()
        mock_chat_format.return_value = mock_formatter
        expected_message = RawMessage(
            role="assistant", content="I'll help you get the weather.", tool_calls=[mock_tool_call]
        )
        mock_formatter.decode_assistant_message_from_content.return_value = expected_message

        result = decode_assistant_message(test_content["llama3_content"], StopReason.end_of_turn)

        assert result == expected_message
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].tool_name == "get_weather"
@pytest.mark.parametrize(
    "model_id,description",
    [
        ("meta-llama/Llama-3.1-8B-Instruct", "should resolve to Llama3"),
        ("meta-llama/Llama-4-8B-Instruct", "should resolve to Llama4 if available"),
        ("invalid-model-id", "should fallback to Llama3"),
    ],
)
def test_model_resolution_integration(model_id, description):
    """Test that model resolution works correctly with actual model IDs."""
    with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
        mock_formatter = Mock()
        mock_chat_format.return_value = mock_formatter
        expected_message = RawMessage(role="assistant", content="Test content")
        mock_formatter.decode_assistant_message_from_content.return_value = expected_message

        # This should not raise an exception
        result = decode_assistant_message("Test content", StopReason.end_of_turn, model_id)

        assert result == expected_message