mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:42:25 +00:00
test(models): convert decode_assistant_message test from unittest to pytest
- Convert test classes to pytest functions and fixtures - Replace unittest assertions with pytest assertions - Use pytest.mark.parametrize for parameterized tests - Remove unittest.TestCase inheritance and setUp methods - Maintain all test functionality and coverage Addresses review feedback from PR #2663
This commit is contained in:
parent
857496ea3e
commit
126d6698a7
1 changed files with 180 additions and 182 deletions
|
|
@ -4,214 +4,212 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import unittest
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from llama_stack.apis.inference import StopReason
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import StopReason, ToolCall
|
||||||
from llama_stack.models.llama.datatypes import RawMessage
|
from llama_stack.models.llama.datatypes import RawMessage
|
||||||
from llama_stack.models.llama.sku_types import ModelFamily
|
from llama_stack.models.llama.sku_types import ModelFamily
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
|
from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
|
||||||
|
|
||||||
|
|
||||||
class TestDecodeAssistantMessage(unittest.TestCase):
|
@pytest.fixture
|
||||||
"""Test the decode_assistant_message function with different model types and formats."""
|
def test_content():
|
||||||
|
"""Test content fixtures."""
|
||||||
|
return {
|
||||||
|
"llama3_content": """I'll help you get the weather information.
|
||||||
|
|
||||||
def setUp(self):
|
{"type": "function", "name": "get_weather", "parameters": {"location": "San Francisco, CA"}}""",
|
||||||
"""Set up test fixtures."""
|
"llama4_content": """I'll help you get the weather information.
|
||||||
self.llama3_content = """I'll help you get the weather information.
|
|
||||||
|
|
||||||
{"type": "function", "name": "get_weather", "parameters": {"location": "San Francisco, CA"}}"""
|
[get_weather(location="San Francisco, CA")]""",
|
||||||
|
"simple_content": "Hello! How can I help you today?",
|
||||||
|
}
|
||||||
|
|
||||||
self.llama4_content = """I'll help you get the weather information.
|
|
||||||
|
|
||||||
[get_weather(location="San Francisco, CA")]"""
|
def test_decode_with_no_model_id_defaults_to_llama3(test_content):
|
||||||
|
"""Test that decode_assistant_message defaults to Llama3 format when no model_id is provided."""
|
||||||
|
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
|
||||||
|
mock_formatter = Mock()
|
||||||
|
mock_chat_format.return_value = mock_formatter
|
||||||
|
expected_message = RawMessage(role="assistant", content=test_content["simple_content"])
|
||||||
|
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
||||||
|
|
||||||
self.simple_content = "Hello! How can I help you today?"
|
result = decode_assistant_message(test_content["simple_content"], StopReason.end_of_turn)
|
||||||
|
|
||||||
def test_decode_with_no_model_id_defaults_to_llama3(self):
|
mock_chat_format.assert_called_once()
|
||||||
"""Test that decode_assistant_message defaults to Llama3 format when no model_id is provided."""
|
mock_formatter.decode_assistant_message_from_content.assert_called_once_with(
|
||||||
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
|
test_content["simple_content"], StopReason.end_of_turn
|
||||||
mock_formatter = Mock()
|
|
||||||
mock_chat_format.return_value = mock_formatter
|
|
||||||
expected_message = RawMessage(role="assistant", content=self.simple_content)
|
|
||||||
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
|
||||||
|
|
||||||
result = decode_assistant_message(self.simple_content, StopReason.end_of_turn)
|
|
||||||
|
|
||||||
mock_chat_format.assert_called_once()
|
|
||||||
mock_formatter.decode_assistant_message_from_content.assert_called_once_with(
|
|
||||||
self.simple_content, StopReason.end_of_turn
|
|
||||||
)
|
|
||||||
self.assertEqual(result, expected_message)
|
|
||||||
|
|
||||||
@patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
|
|
||||||
def test_decode_with_nonexistent_model_uses_llama3(self):
|
|
||||||
"""Test that decode_assistant_message uses Llama3 format for non-existent models."""
|
|
||||||
with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
|
|
||||||
mock_resolve.return_value = None
|
|
||||||
|
|
||||||
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
|
|
||||||
mock_formatter = Mock()
|
|
||||||
mock_chat_format.return_value = mock_formatter
|
|
||||||
expected_message = RawMessage(role="assistant", content=self.simple_content)
|
|
||||||
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
|
||||||
|
|
||||||
result = decode_assistant_message(self.simple_content, StopReason.end_of_turn, "nonexistent-model")
|
|
||||||
|
|
||||||
mock_resolve.assert_called_once_with("nonexistent-model")
|
|
||||||
mock_chat_format.assert_called_once()
|
|
||||||
self.assertEqual(result, expected_message)
|
|
||||||
|
|
||||||
@patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
|
|
||||||
def test_decode_with_llama3_model_uses_llama3_format(self):
|
|
||||||
"""Test that decode_assistant_message uses Llama3 format for Llama3 models."""
|
|
||||||
mock_model = Mock()
|
|
||||||
mock_model.model_family = ModelFamily.llama3
|
|
||||||
|
|
||||||
with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
|
|
||||||
mock_resolve.return_value = mock_model
|
|
||||||
|
|
||||||
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
|
|
||||||
mock_formatter = Mock()
|
|
||||||
mock_chat_format.return_value = mock_formatter
|
|
||||||
expected_message = RawMessage(role="assistant", content=self.llama3_content)
|
|
||||||
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
|
||||||
|
|
||||||
result = decode_assistant_message(
|
|
||||||
self.llama3_content, StopReason.end_of_turn, "meta-llama/Llama-3.1-8B-Instruct"
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_resolve.assert_called_once_with("meta-llama/Llama-3.1-8B-Instruct")
|
|
||||||
mock_chat_format.assert_called_once()
|
|
||||||
self.assertEqual(result, expected_message)
|
|
||||||
|
|
||||||
@patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
|
|
||||||
def test_decode_with_llama4_model_uses_llama4_format(self):
|
|
||||||
"""Test that decode_assistant_message uses Llama4 format for Llama4 models when available."""
|
|
||||||
mock_model = Mock()
|
|
||||||
mock_model.model_family = ModelFamily.llama4
|
|
||||||
|
|
||||||
with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
|
|
||||||
mock_resolve.return_value = mock_model
|
|
||||||
|
|
||||||
# Mock the Llama4 components
|
|
||||||
with patch(
|
|
||||||
"llama_stack.providers.utils.inference.prompt_adapter.Llama4ChatFormat", create=True
|
|
||||||
) as mock_llama4_format:
|
|
||||||
with patch(
|
|
||||||
"llama_stack.providers.utils.inference.prompt_adapter.Llama4Tokenizer", create=True
|
|
||||||
) as mock_llama4_tokenizer:
|
|
||||||
mock_tokenizer_instance = Mock()
|
|
||||||
mock_llama4_tokenizer.get_instance.return_value = mock_tokenizer_instance
|
|
||||||
|
|
||||||
mock_formatter = Mock()
|
|
||||||
mock_llama4_format.return_value = mock_formatter
|
|
||||||
expected_message = RawMessage(role="assistant", content=self.llama4_content)
|
|
||||||
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
|
||||||
|
|
||||||
result = decode_assistant_message(
|
|
||||||
self.llama4_content, StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_resolve.assert_called_once_with("meta-llama/Llama-4-8B-Instruct")
|
|
||||||
mock_llama4_format.assert_called_once_with(mock_tokenizer_instance)
|
|
||||||
self.assertEqual(result, expected_message)
|
|
||||||
|
|
||||||
@patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", False)
|
|
||||||
def test_decode_with_llama4_model_falls_back_to_llama3_when_unavailable(self):
|
|
||||||
"""Test that decode_assistant_message falls back to Llama3 format when Llama4 is unavailable."""
|
|
||||||
mock_model = Mock()
|
|
||||||
mock_model.model_family = ModelFamily.llama4
|
|
||||||
|
|
||||||
with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
|
|
||||||
mock_resolve.return_value = mock_model
|
|
||||||
|
|
||||||
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
|
|
||||||
mock_formatter = Mock()
|
|
||||||
mock_chat_format.return_value = mock_formatter
|
|
||||||
expected_message = RawMessage(role="assistant", content=self.llama4_content)
|
|
||||||
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
|
||||||
|
|
||||||
result = decode_assistant_message(
|
|
||||||
self.llama4_content, StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should NOT call resolve_model since LLAMA4_AVAILABLE is False
|
|
||||||
mock_resolve.assert_not_called()
|
|
||||||
mock_chat_format.assert_called_once()
|
|
||||||
self.assertEqual(result, expected_message)
|
|
||||||
|
|
||||||
def test_decode_with_different_stop_reasons(self):
|
|
||||||
"""Test that decode_assistant_message handles different stop reasons correctly."""
|
|
||||||
stop_reasons = [
|
|
||||||
StopReason.end_of_turn,
|
|
||||||
StopReason.end_of_message,
|
|
||||||
StopReason.out_of_tokens,
|
|
||||||
]
|
|
||||||
|
|
||||||
for stop_reason in stop_reasons:
|
|
||||||
with self.subTest(stop_reason=stop_reason):
|
|
||||||
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
|
|
||||||
mock_formatter = Mock()
|
|
||||||
mock_chat_format.return_value = mock_formatter
|
|
||||||
expected_message = RawMessage(role="assistant", content=self.simple_content)
|
|
||||||
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
|
||||||
|
|
||||||
result = decode_assistant_message(self.simple_content, stop_reason)
|
|
||||||
|
|
||||||
mock_formatter.decode_assistant_message_from_content.assert_called_once_with(
|
|
||||||
self.simple_content, stop_reason
|
|
||||||
)
|
|
||||||
self.assertEqual(result, expected_message)
|
|
||||||
|
|
||||||
def test_decode_preserves_formatter_response(self):
|
|
||||||
"""Test that decode_assistant_message preserves the formatter's response including tool calls."""
|
|
||||||
from llama_stack.apis.inference import ToolCall
|
|
||||||
|
|
||||||
mock_tool_call = ToolCall(
|
|
||||||
tool_name="get_weather", arguments={"location": "San Francisco, CA"}, call_id="test_call_id"
|
|
||||||
)
|
)
|
||||||
|
assert result == expected_message
|
||||||
|
|
||||||
|
|
||||||
|
@patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
|
||||||
|
def test_decode_with_nonexistent_model_uses_llama3(test_content):
|
||||||
|
"""Test that decode_assistant_message uses Llama3 format for non-existent models."""
|
||||||
|
with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
|
||||||
|
mock_resolve.return_value = None
|
||||||
|
|
||||||
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
|
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
|
||||||
mock_formatter = Mock()
|
mock_formatter = Mock()
|
||||||
mock_chat_format.return_value = mock_formatter
|
mock_chat_format.return_value = mock_formatter
|
||||||
expected_message = RawMessage(
|
expected_message = RawMessage(role="assistant", content=test_content["simple_content"])
|
||||||
role="assistant", content="I'll help you get the weather.", tool_calls=[mock_tool_call]
|
|
||||||
)
|
|
||||||
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
||||||
|
|
||||||
result = decode_assistant_message(self.llama3_content, StopReason.end_of_turn)
|
result = decode_assistant_message(
|
||||||
|
test_content["simple_content"], StopReason.end_of_turn, "nonexistent-model"
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(result, expected_message)
|
mock_resolve.assert_called_once_with("nonexistent-model")
|
||||||
self.assertEqual(len(result.tool_calls), 1)
|
mock_chat_format.assert_called_once()
|
||||||
self.assertEqual(result.tool_calls[0].tool_name, "get_weather")
|
assert result == expected_message
|
||||||
|
|
||||||
|
|
||||||
class TestDecodeAssistantMessageIntegration(unittest.TestCase):
|
@patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
|
||||||
"""Integration tests for decode_assistant_message with real model resolution."""
|
def test_decode_with_llama3_model_uses_llama3_format(test_content):
|
||||||
|
"""Test that decode_assistant_message uses Llama3 format for Llama3 models."""
|
||||||
|
mock_model = Mock()
|
||||||
|
mock_model.model_family = ModelFamily.llama3
|
||||||
|
|
||||||
def test_model_resolution_integration(self):
|
with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
|
||||||
"""Test that model resolution works correctly with actual model IDs."""
|
mock_resolve.return_value = mock_model
|
||||||
# Test with actual model IDs that should resolve
|
|
||||||
test_cases = [
|
|
||||||
("meta-llama/Llama-3.1-8B-Instruct", "should resolve to Llama3"),
|
|
||||||
("meta-llama/Llama-4-8B-Instruct", "should resolve to Llama4 if available"),
|
|
||||||
("invalid-model-id", "should fallback to Llama3"),
|
|
||||||
]
|
|
||||||
|
|
||||||
for model_id, description in test_cases:
|
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
|
||||||
with self.subTest(model_id=model_id, description=description):
|
mock_formatter = Mock()
|
||||||
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
|
mock_chat_format.return_value = mock_formatter
|
||||||
mock_formatter = Mock()
|
expected_message = RawMessage(role="assistant", content=test_content["llama3_content"])
|
||||||
mock_chat_format.return_value = mock_formatter
|
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
||||||
expected_message = RawMessage(role="assistant", content="Test content")
|
|
||||||
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
|
||||||
|
|
||||||
# This should not raise an exception
|
result = decode_assistant_message(
|
||||||
result = decode_assistant_message("Test content", StopReason.end_of_turn, model_id)
|
test_content["llama3_content"], StopReason.end_of_turn, "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(result, expected_message)
|
mock_resolve.assert_called_once_with("meta-llama/Llama-3.1-8B-Instruct")
|
||||||
|
mock_chat_format.assert_called_once()
|
||||||
|
assert result == expected_message
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
@patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
|
||||||
unittest.main()
|
def test_decode_with_llama4_model_uses_llama4_format(test_content):
|
||||||
|
"""Test that decode_assistant_message uses Llama4 format for Llama4 models when available."""
|
||||||
|
mock_model = Mock()
|
||||||
|
mock_model.model_family = ModelFamily.llama4
|
||||||
|
|
||||||
|
with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
|
||||||
|
mock_resolve.return_value = mock_model
|
||||||
|
|
||||||
|
# Mock the Llama4 components
|
||||||
|
with patch(
|
||||||
|
"llama_stack.providers.utils.inference.prompt_adapter.Llama4ChatFormat", create=True
|
||||||
|
) as mock_llama4_format:
|
||||||
|
with patch(
|
||||||
|
"llama_stack.providers.utils.inference.prompt_adapter.Llama4Tokenizer", create=True
|
||||||
|
) as mock_llama4_tokenizer:
|
||||||
|
mock_tokenizer_instance = Mock()
|
||||||
|
mock_llama4_tokenizer.get_instance.return_value = mock_tokenizer_instance
|
||||||
|
|
||||||
|
mock_formatter = Mock()
|
||||||
|
mock_llama4_format.return_value = mock_formatter
|
||||||
|
expected_message = RawMessage(role="assistant", content=test_content["llama4_content"])
|
||||||
|
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
||||||
|
|
||||||
|
result = decode_assistant_message(
|
||||||
|
test_content["llama4_content"], StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_resolve.assert_called_once_with("meta-llama/Llama-4-8B-Instruct")
|
||||||
|
mock_llama4_format.assert_called_once_with(mock_tokenizer_instance)
|
||||||
|
assert result == expected_message
|
||||||
|
|
||||||
|
|
||||||
|
@patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", False)
|
||||||
|
def test_decode_with_llama4_model_falls_back_to_llama3_when_unavailable(test_content):
|
||||||
|
"""Test that decode_assistant_message falls back to Llama3 format when Llama4 is unavailable."""
|
||||||
|
mock_model = Mock()
|
||||||
|
mock_model.model_family = ModelFamily.llama4
|
||||||
|
|
||||||
|
with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
|
||||||
|
mock_resolve.return_value = mock_model
|
||||||
|
|
||||||
|
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
|
||||||
|
mock_formatter = Mock()
|
||||||
|
mock_chat_format.return_value = mock_formatter
|
||||||
|
expected_message = RawMessage(role="assistant", content=test_content["llama4_content"])
|
||||||
|
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
||||||
|
|
||||||
|
result = decode_assistant_message(
|
||||||
|
test_content["llama4_content"], StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should NOT call resolve_model since LLAMA4_AVAILABLE is False
|
||||||
|
mock_resolve.assert_not_called()
|
||||||
|
mock_chat_format.assert_called_once()
|
||||||
|
assert result == expected_message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"stop_reason",
|
||||||
|
[
|
||||||
|
StopReason.end_of_turn,
|
||||||
|
StopReason.end_of_message,
|
||||||
|
StopReason.out_of_tokens,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_decode_with_different_stop_reasons(test_content, stop_reason):
|
||||||
|
"""Test that decode_assistant_message handles different stop reasons correctly."""
|
||||||
|
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
|
||||||
|
mock_formatter = Mock()
|
||||||
|
mock_chat_format.return_value = mock_formatter
|
||||||
|
expected_message = RawMessage(role="assistant", content=test_content["simple_content"])
|
||||||
|
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
||||||
|
|
||||||
|
result = decode_assistant_message(test_content["simple_content"], stop_reason)
|
||||||
|
|
||||||
|
mock_formatter.decode_assistant_message_from_content.assert_called_once_with(
|
||||||
|
test_content["simple_content"], stop_reason
|
||||||
|
)
|
||||||
|
assert result == expected_message
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_preserves_formatter_response(test_content):
|
||||||
|
"""Test that decode_assistant_message preserves the formatter's response including tool calls."""
|
||||||
|
mock_tool_call = ToolCall(
|
||||||
|
tool_name="get_weather", arguments={"location": "San Francisco, CA"}, call_id="test_call_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
|
||||||
|
mock_formatter = Mock()
|
||||||
|
mock_chat_format.return_value = mock_formatter
|
||||||
|
expected_message = RawMessage(
|
||||||
|
role="assistant", content="I'll help you get the weather.", tool_calls=[mock_tool_call]
|
||||||
|
)
|
||||||
|
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
||||||
|
|
||||||
|
result = decode_assistant_message(test_content["llama3_content"], StopReason.end_of_turn)
|
||||||
|
|
||||||
|
assert result == expected_message
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
assert result.tool_calls[0].tool_name == "get_weather"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_id,description",
|
||||||
|
[
|
||||||
|
("meta-llama/Llama-3.1-8B-Instruct", "should resolve to Llama3"),
|
||||||
|
("meta-llama/Llama-4-8B-Instruct", "should resolve to Llama4 if available"),
|
||||||
|
("invalid-model-id", "should fallback to Llama3"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_model_resolution_integration(model_id, description):
|
||||||
|
"""Test that model resolution works correctly with actual model IDs."""
|
||||||
|
with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
|
||||||
|
mock_formatter = Mock()
|
||||||
|
mock_chat_format.return_value = mock_formatter
|
||||||
|
expected_message = RawMessage(role="assistant", content="Test content")
|
||||||
|
mock_formatter.decode_assistant_message_from_content.return_value = expected_message
|
||||||
|
|
||||||
|
# This should not raise an exception
|
||||||
|
result = decode_assistant_message("Test content", StopReason.end_of_turn, model_id)
|
||||||
|
|
||||||
|
assert result == expected_message
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue