diff --git a/llama_stack/models/llama/llama4/interface.py b/llama_stack/models/llama/llama4/interface.py
new file mode 100644
index 000000000..850a0e239
--- /dev/null
+++ b/llama_stack/models/llama/llama4/interface.py
@@ -0,0 +1,220 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# top-level folder for each specific model found within the models/ directory at
+# the top-level of this source tree.
+
+from pathlib import Path
+
+from termcolor import colored
+
+from ..datatypes import (
+    BuiltinTool,
+    RawMessage,
+    StopReason,
+    ToolCall,
+    ToolDefinition,
+    ToolPromptFormat,
+)
+from ..llama3.prompt_templates import (
+    BuiltinToolGenerator,
+    ToolResponseGenerator,
+)
+from .chat_format import ChatFormat
+from .prompt_templates.system_prompts import PythonListCustomToolGenerator
+from .tokenizer import Tokenizer
+
+THIS_DIR = Path(__file__).parent
+
+
+class Template:
+    def __init__(
+        self,
+        role,
+        template_name,
+        data_provider=None,
+        notes=None,
+    ):
+        self.role = role
+        self.template_name = template_name
+        self.data_provider = data_provider or ""
+        self._notes = notes or ""
+
+    @property
+    def notes(self):
+        default = "↵ represents newline"
+        notes = default
+        if self._notes:
+            notes += "\n"
+            notes += self._notes
+        return notes
+
+
+# Llama4 templates - similar to Llama3 but with python_list format
+TEMPLATES = [
+    Template(
+        "user",
+        "user-default",
+        "user_default",
+    ),
+    Template(
+        "user",
+        "user-images",
+        "user_images",
+    ),
+    Template("user", "user-interleaved-images", "user_interleaved_images"),
+    Template(
+        "assistant",
+        "assistant-builtin-tool-call",
+        "assistant_builtin_tool_call",
+        "Notice <|python_tag|>",
+    ),
+    Template(
+        "assistant",
+        "assistant-custom-tool-call",
+        "assistant_custom_tool_call",
+        "Notice [func_name(param=value)] format",
+    ),
+    Template(
+        "assistant",
+        "assistant-default",
+        "assistant_default",
+    ),
+    Template(
+        "system",
+        "system-builtin-and-custom-tools",
+        "system_message_builtin_and_custom_tools",
+    ),
+    Template(
+        "system",
+        "system-builtin-tools-only",
+        "system_message_builtin_tools_only",
+    ),
+    Template(
+        "system",
+        "system-custom-tools-only",
+        "system_message_custom_tools_only",
+    ),
+    Template(
+        "system",
+        "system-default",
+        "system_default",
+    ),
+    Template(
+        "tool",
+        "tool-success",
+        "tool_success",
+        "Note ipython header and [stdout]",
+    ),
+    Template(
+        "tool",
+        "tool-failure",
+        "tool_failure",
+        "Note ipython header and [stderr]",
+    ),
+]
+
+
+class Llama4Interface:
+    def __init__(self, tool_prompt_format: ToolPromptFormat = ToolPromptFormat.python_list):
+        self.tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(self.tokenizer)
+        self.tool_prompt_format = tool_prompt_format
+
+    def get_tokens(self, messages: list[RawMessage]) -> list[int]:
+        model_input = self.formatter.encode_dialog_prompt(
+            messages,
+            self.tool_prompt_format,
+        )
+        return model_input.tokens
+
+    def tool_response_messages(self, *args, **kwargs):
+        template = ToolResponseGenerator().gen(*args, **kwargs)
+        return [
+            RawMessage(
+                role="tool",
+                content=template.render(),
+            )
+        ]
+
+    def system_messages(
+        self,
+        builtin_tools: list[BuiltinTool],
+        custom_tools: list[ToolDefinition],
+        instruction: str | None = None,
+    ) -> list[RawMessage]:
+        messages = []
+
+        sys_content = ""
+
+        # Handle builtin tools with builtin tool generator
+        if builtin_tools:
+            tool_gen = BuiltinToolGenerator()
+            tool_template = tool_gen.gen(builtin_tools)
+            sys_content += tool_template.render()
+            sys_content += "\n"
+
+        # Handle custom tools with Llama4's python list generator
+        if custom_tools:
+            if self.tool_prompt_format != ToolPromptFormat.python_list:
+                raise ValueError(f"Llama4 only supports python_list tool prompt format, got {self.tool_prompt_format}")
+
+            tool_gen = PythonListCustomToolGenerator()
+            tool_template = tool_gen.gen(custom_tools, instruction)
+            sys_content += tool_template.render()
+        else:
+            # If no custom tools but have instruction, add it
+            if instruction:
+                sys_content += instruction
+
+        messages.append(RawMessage(role="system", content=sys_content.strip()))
+
+        return messages
+
+    def assistant_response_messages(
+        self,
+        content: str,
+        stop_reason: StopReason,
+        tool_call: ToolCall | None = None,
+    ) -> list[RawMessage]:
+        tool_calls = []
+        if tool_call:
+            tool_calls.append(tool_call)
+
+        return [
+            RawMessage(
+                role="assistant",
+                content=content,
+                stop_reason=stop_reason,
+                tool_calls=tool_calls,
+            )
+        ]
+
+    def user_message(self, content: str) -> list[RawMessage]:
+        return [RawMessage(role="user", content=content)]
+
+    def display_message_as_tokens(self, message: RawMessage) -> None:
+        tokens = self.formatter.encode_message(message, self.tool_prompt_format)[0]
+        decoded = [self.tokenizer.decode([t]) for t in tokens]
+
+        print(f"\n{colored(f'Message ({message.role}):', 'yellow')}")
+        for i, (t, d) in enumerate(zip(tokens, decoded, strict=False)):
+            color = "light_blue" if d.startswith("<|") and d.endswith("|>") else "white"
+            print(f"{i:4d}: {t:6d} {colored(repr(d), color)}")
+
+
+def list_jinja_templates() -> list[Template]:
+    return TEMPLATES
+
+
+def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat):
+    # This would render templates - for now just return empty
+    # Can be implemented later if needed for Llama4-specific templates
+    return ""
diff --git a/llama_stack/models/llama/sku_list.py b/llama_stack/models/llama/sku_list.py
index 271cec63f..730e9f770 100644
--- a/llama_stack/models/llama/sku_list.py
+++ b/llama_stack/models/llama/sku_list.py
@@ -22,6 +22,20 @@ def resolve_model(descriptor: str) -> Model | None:
     for m in all_registered_models():
         if descriptor in (m.descriptor(), m.huggingface_repo):
             return m
+
+    # Check provider aliases by attempting to import and check common providers
+    try:
+        from llama_stack.providers.remote.inference.together.models import MODEL_ENTRIES as TOGETHER_ENTRIES
+
+        for entry in TOGETHER_ENTRIES:
+            if descriptor in entry.aliases and entry.llama_model:
+                # Find the model by its descriptor
+                for m in all_registered_models():
+                    if m.descriptor() == entry.llama_model:
+                        return m
+    except ImportError:
+        pass
+
     return None
 
 
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 47144ee0e..2febbb33e 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -299,7 +299,9 @@ def process_chat_completion_response(
 
     # TODO: This does not work well with tool calls for vLLM remote provider
     # Ref: https://github.com/meta-llama/llama-stack/issues/1058
-    raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
+    raw_message = decode_assistant_message(
+        text_from_choice(choice), get_stop_reason(choice.finish_reason), request.model
+    )
 
     # NOTE: If we do not set tools in chat-completion request, we should not
     # expect the ToolCall in the response. Instead, we should return the raw
@@ -448,7 +450,7 @@ async def process_chat_completion_stream_response(
             )
 
     # parse tool calls and report errors
-    message = decode_assistant_message(buffer, stop_reason)
+    message = decode_assistant_message(buffer, stop_reason, request.model)
     parsed_tool_calls = len(message.tool_calls) > 0
 
     if ipython and not parsed_tool_calls:
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index bb9a91b97..2b0192f7f 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -51,9 +51,19 @@ from llama_stack.models.llama.llama3.prompt_templates import (
     SystemDefaultGenerator,
 )
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
-from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
-    PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
-)
+
+# Conditional imports to avoid heavy dependencies during module loading
+try:
+    from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
+    from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
+        PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
+    )
+    from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
+
+    LLAMA4_AVAILABLE = True
+except ImportError:
+    # Llama4 dependencies not available (e.g., torch not installed)
+    LLAMA4_AVAILABLE = False
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
 from llama_stack.providers.utils.inference import supported_inference_models
@@ -69,8 +79,20 @@ class CompletionRequestWithRawContent(CompletionRequest):
     content: RawContent
 
 
-def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage:
-    formatter = ChatFormat(Tokenizer.get_instance())
+def decode_assistant_message(content: str, stop_reason: StopReason, model_id: str | None = None) -> RawMessage:
+    """Decode assistant message using the appropriate formatter for the model family."""
+    if model_id and LLAMA4_AVAILABLE:
+        model = resolve_model(model_id)
+        if model and model.model_family == ModelFamily.llama4:
+            # Use Llama4's ChatFormat for Llama4 models
+            formatter = Llama4ChatFormat(Llama4Tokenizer.get_instance())
+        else:
+            # Use Llama3's ChatFormat for all other models (default)
+            formatter = ChatFormat(Tokenizer.get_instance())
+    else:
+        # Default to Llama3 if no model specified or Llama4 not available
+        formatter = ChatFormat(Tokenizer.get_instance())
+
     return formatter.decode_assistant_message_from_content(content, stop_reason)
 
 
diff --git a/tests/unit/models/test_decode_assistant_message.py b/tests/unit/models/test_decode_assistant_message.py
new file mode 100644
index 000000000..feb0694a1
--- /dev/null
+++ b/tests/unit/models/test_decode_assistant_message.py
@@ -0,0 +1,217 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import unittest
+from unittest.mock import Mock, patch
+
+from llama_stack.apis.inference import StopReason
+from llama_stack.models.llama.datatypes import RawMessage
+from llama_stack.models.llama.sku_types import ModelFamily
+from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
+
+
+class TestDecodeAssistantMessage(unittest.TestCase):
+    """Test the decode_assistant_message function with different model types and formats."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.llama3_content = """I'll help you get the weather information.
+
+{"type": "function", "name": "get_weather", "parameters": {"location": "San Francisco, CA"}}"""
+
+        self.llama4_content = """I'll help you get the weather information.
+
+[get_weather(location="San Francisco, CA")]"""
+
+        self.simple_content = "Hello! How can I help you today?"
+
+    def test_decode_with_no_model_id_defaults_to_llama3(self):
+        """Test that decode_assistant_message defaults to Llama3 format when no model_id is provided."""
+        with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
+            mock_formatter = Mock()
+            mock_chat_format.return_value = mock_formatter
+            expected_message = RawMessage(role="assistant", content=self.simple_content)
+            mock_formatter.decode_assistant_message_from_content.return_value = expected_message
+
+            result = decode_assistant_message(self.simple_content, StopReason.end_of_turn)
+
+            mock_chat_format.assert_called_once()
+            mock_formatter.decode_assistant_message_from_content.assert_called_once_with(
+                self.simple_content, StopReason.end_of_turn
+            )
+            self.assertEqual(result, expected_message)
+
+    @patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
+    def test_decode_with_nonexistent_model_uses_llama3(self):
+        """Test that decode_assistant_message uses Llama3 format for non-existent models."""
+        with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
+            mock_resolve.return_value = None
+
+            with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
+                mock_formatter = Mock()
+                mock_chat_format.return_value = mock_formatter
+                expected_message = RawMessage(role="assistant", content=self.simple_content)
+                mock_formatter.decode_assistant_message_from_content.return_value = expected_message
+
+                result = decode_assistant_message(self.simple_content, StopReason.end_of_turn, "nonexistent-model")
+
+                mock_resolve.assert_called_once_with("nonexistent-model")
+                mock_chat_format.assert_called_once()
+                self.assertEqual(result, expected_message)
+
+    @patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
+    def test_decode_with_llama3_model_uses_llama3_format(self):
+        """Test that decode_assistant_message uses Llama3 format for Llama3 models."""
+        mock_model = Mock()
+        mock_model.model_family = ModelFamily.llama3
+
+        with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
+            mock_resolve.return_value = mock_model
+
+            with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
+                mock_formatter = Mock()
+                mock_chat_format.return_value = mock_formatter
+                expected_message = RawMessage(role="assistant", content=self.llama3_content)
+                mock_formatter.decode_assistant_message_from_content.return_value = expected_message
+
+                result = decode_assistant_message(
+                    self.llama3_content, StopReason.end_of_turn, "meta-llama/Llama-3.1-8B-Instruct"
+                )
+
+                mock_resolve.assert_called_once_with("meta-llama/Llama-3.1-8B-Instruct")
+                mock_chat_format.assert_called_once()
+                self.assertEqual(result, expected_message)
+
+    @patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", True)
+    def test_decode_with_llama4_model_uses_llama4_format(self):
+        """Test that decode_assistant_message uses Llama4 format for Llama4 models when available."""
+        mock_model = Mock()
+        mock_model.model_family = ModelFamily.llama4
+
+        with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
+            mock_resolve.return_value = mock_model
+
+            # Mock the Llama4 components
+            with patch(
+                "llama_stack.providers.utils.inference.prompt_adapter.Llama4ChatFormat", create=True
+            ) as mock_llama4_format:
+                with patch(
+                    "llama_stack.providers.utils.inference.prompt_adapter.Llama4Tokenizer", create=True
+                ) as mock_llama4_tokenizer:
+                    mock_tokenizer_instance = Mock()
+                    mock_llama4_tokenizer.get_instance.return_value = mock_tokenizer_instance
+
+                    mock_formatter = Mock()
+                    mock_llama4_format.return_value = mock_formatter
+                    expected_message = RawMessage(role="assistant", content=self.llama4_content)
+                    mock_formatter.decode_assistant_message_from_content.return_value = expected_message
+
+                    result = decode_assistant_message(
+                        self.llama4_content, StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
+                    )
+
+                    mock_resolve.assert_called_once_with("meta-llama/Llama-4-8B-Instruct")
+                    mock_llama4_format.assert_called_once_with(mock_tokenizer_instance)
+                    self.assertEqual(result, expected_message)
+
+    @patch("llama_stack.providers.utils.inference.prompt_adapter.LLAMA4_AVAILABLE", False)
+    def test_decode_with_llama4_model_falls_back_to_llama3_when_unavailable(self):
+        """Test that decode_assistant_message falls back to Llama3 format when Llama4 is unavailable."""
+        mock_model = Mock()
+        mock_model.model_family = ModelFamily.llama4
+
+        with patch("llama_stack.providers.utils.inference.prompt_adapter.resolve_model") as mock_resolve:
+            mock_resolve.return_value = mock_model
+
+            with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
+                mock_formatter = Mock()
+                mock_chat_format.return_value = mock_formatter
+                expected_message = RawMessage(role="assistant", content=self.llama4_content)
+                mock_formatter.decode_assistant_message_from_content.return_value = expected_message
+
+                result = decode_assistant_message(
+                    self.llama4_content, StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct"
+                )
+
+                # Should NOT call resolve_model since LLAMA4_AVAILABLE is False
+                mock_resolve.assert_not_called()
+                mock_chat_format.assert_called_once()
+                self.assertEqual(result, expected_message)
+
+    def test_decode_with_different_stop_reasons(self):
+        """Test that decode_assistant_message handles different stop reasons correctly."""
+        stop_reasons = [
+            StopReason.end_of_turn,
+            StopReason.end_of_message,
+            StopReason.out_of_tokens,
+        ]
+
+        for stop_reason in stop_reasons:
+            with self.subTest(stop_reason=stop_reason):
+                with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
+                    mock_formatter = Mock()
+                    mock_chat_format.return_value = mock_formatter
+                    expected_message = RawMessage(role="assistant", content=self.simple_content)
+                    mock_formatter.decode_assistant_message_from_content.return_value = expected_message
+
+                    result = decode_assistant_message(self.simple_content, stop_reason)
+
+                    mock_formatter.decode_assistant_message_from_content.assert_called_once_with(
+                        self.simple_content, stop_reason
+                    )
+                    self.assertEqual(result, expected_message)
+
+    def test_decode_preserves_formatter_response(self):
+        """Test that decode_assistant_message preserves the formatter's response including tool calls."""
+        from llama_stack.apis.inference import ToolCall
+
+        mock_tool_call = ToolCall(
+            tool_name="get_weather", arguments={"location": "San Francisco, CA"}, call_id="test_call_id"
+        )
+
+        with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
+            mock_formatter = Mock()
+            mock_chat_format.return_value = mock_formatter
+            expected_message = RawMessage(
+                role="assistant", content="I'll help you get the weather.", tool_calls=[mock_tool_call]
+            )
+            mock_formatter.decode_assistant_message_from_content.return_value = expected_message
+
+            result = decode_assistant_message(self.llama3_content, StopReason.end_of_turn)
+
+            self.assertEqual(result, expected_message)
+            self.assertEqual(len(result.tool_calls), 1)
+            self.assertEqual(result.tool_calls[0].tool_name, "get_weather")
+
+
+class TestDecodeAssistantMessageIntegration(unittest.TestCase):
+    """Integration tests for decode_assistant_message with real model resolution."""
+
+    def test_model_resolution_integration(self):
+        """Test that model resolution works correctly with actual model IDs."""
+        # Test with actual model IDs that should resolve
+        test_cases = [
+            ("meta-llama/Llama-3.1-8B-Instruct", "should resolve to Llama3"),
+            ("meta-llama/Llama-4-8B-Instruct", "should resolve to Llama4 if available"),
+            ("invalid-model-id", "should fallback to Llama3"),
+        ]
+
+        for model_id, description in test_cases:
+            with self.subTest(model_id=model_id, description=description):
+                with patch("llama_stack.providers.utils.inference.prompt_adapter.ChatFormat") as mock_chat_format:
+                    mock_formatter = Mock()
+                    mock_chat_format.return_value = mock_formatter
+                    expected_message = RawMessage(role="assistant", content="Test content")
+                    mock_formatter.decode_assistant_message_from_content.return_value = expected_message
+
+                    # This should not raise an exception
+                    result = decode_assistant_message("Test content", StopReason.end_of_turn, model_id)
+
+                    self.assertEqual(result, expected_message)
+
+
+if __name__ == "__main__":
+    unittest.main()
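
Illustrative usage sketch (not part of the patch above): a minimal example of how the model-aware decode path added in prompt_adapter.py is expected to be exercised, mirroring the fixtures in the new unit tests. The model ID strings are the placeholder IDs those tests use, and the Llama4 branch is only taken when the optional Llama4 imports resolve (LLAMA4_AVAILABLE is True); treat this as a sketch under those assumptions, not a definitive reference.

from llama_stack.apis.inference import StopReason
from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message

# Llama3-style completion text: the tool call is serialized as JSON.
llama3_text = '{"type": "function", "name": "get_weather", "parameters": {"location": "San Francisco, CA"}}'
# Llama4-style completion text: the tool call uses the python_list format.
llama4_text = '[get_weather(location="San Francisco, CA")]'

# No model_id: previous behavior, always decoded with the Llama3 ChatFormat.
message = decode_assistant_message(llama3_text, StopReason.end_of_turn)

# With a model_id, resolve_model() selects the formatter by model family:
# a Llama4 SKU routes to Llama4ChatFormat, everything else stays on Llama3.
message = decode_assistant_message(llama4_text, StopReason.end_of_turn, "meta-llama/Llama-4-8B-Instruct")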