From 29072f40ab8bf8d47cb6867192e1b2f232f89321 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Fri, 25 Apr 2025 11:29:08 -0700 Subject: [PATCH] feat: new system prompt for llama4 (#2031) Tests: LLAMA_STACK_CONFIG=http://localhost:5002 pytest -s -v tests/integration/inference --safety-shield meta-llama/Llama-Guard-3-8B --vision-model meta-llama/Llama-4-Scout-17B-16E-Instruct --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct Co-authored-by: Eric Huang --- .../llama4/prompt_templates/system_prompts.py | 144 ++++++++++++++++++ .../utils/inference/prompt_adapter.py | 15 +- 2 files changed, 154 insertions(+), 5 deletions(-) create mode 100644 llama_stack/models/llama/llama4/prompt_templates/system_prompts.py diff --git a/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py new file mode 100644 index 000000000..139e204ad --- /dev/null +++ b/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import textwrap +from typing import List, Optional + +from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition +from llama_stack.models.llama.llama3.prompt_templates.base import ( + PromptTemplate, + PromptTemplateGeneratorBase, +) + + +class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 + DEFAULT_PROMPT = textwrap.dedent( + """ + You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines: + + 1. FUNCTION CALLS: + - ONLY use functions that are EXPLICITLY listed in the function list below + - If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information" + - If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information" + - If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s) + - Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)] + Examples: + CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list + INCORRECT: get_weather(location="New York") + INCORRECT: Let me check the weather: [get_weather(location="New York")] + INCORRECT: [get_events(location="Singapore")] <- If function not in list + + 2. RESPONSE RULES: + - For pure function requests matching a listed function: ONLY output the function call(s) + - For knowledge questions: ONLY output text + - For missing parameters: ONLY request the specific missing parameters + - For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call. + - If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations + - NEVER combine text and function calls in the same response + - NEVER suggest alternative functions when the requested service is unavailable + - NEVER create or invent new functions not listed below + + 3. STRICT BOUNDARIES: + - ONLY use functions from the list below - no exceptions + - NEVER use a function as an alternative to unavailable information + - NEVER call functions not present in the function list + - NEVER add explanatory text to function calls + - NEVER respond with empty brackets + - Use proper Python/JSON syntax for function calls + - Check the function list carefully before responding + + 4. TOOL RESPONSE HANDLING: + - When receiving tool responses: provide concise, natural language responses + - Don't repeat tool response verbatim + - Don't add supplementary information + + + {{ function_description }} + """.strip("\n") + ) + + def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate: + system_prompt = system_prompt or self.DEFAULT_PROMPT + return PromptTemplate( + system_prompt, + {"function_description": self._gen_function_description(custom_tools)}, + ) + + def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + template_str = textwrap.dedent( + """ + Here is a list of functions in JSON format that you can invoke. + + [ + {% for t in tools -%} + {# manually setting up JSON because jinja sorts keys in unexpected ways -#} + {%- set tname = t.tool_name -%} + {%- set tdesc = t.description -%} + {%- set tparams = t.parameters -%} + {%- set required_params = [] -%} + {%- for name, param in tparams.items() if param.required == true -%} + {%- set _ = required_params.append(name) -%} + {%- endfor -%} + { + "name": "{{tname}}", + "description": "{{tdesc}}", + "parameters": { + "type": "dict", + "required": {{ required_params | tojson }}, + "properties": { + {%- for name, param in tparams.items() %} + "{{name}}": { + "type": "{{param.param_type}}", + "description": "{{param.description}}"{% if param.default %}, + "default": "{{param.default}}"{% endif %} + }{% if not loop.last %},{% endif %} + {%- endfor %} + } + } + }{% if not loop.last %}, + {% endif -%} + {%- endfor %} + ] + + You can answer general questions or invoke tools when necessary. + In addition to tool calls, you should also augment your responses by using the tool outputs. + + """ + ) + return PromptTemplate( + template_str.strip("\n"), + {"tools": [t.model_dump() for t in custom_tools]}, + ).render() + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + [ + ToolDefinition( + tool_name="get_weather", + description="Get weather info for places", + parameters={ + "city": ToolParamDefinition( + param_type="string", + description="The name of the city to get the weather for", + required=True, + ), + "metric": ToolParamDefinition( + param_type="string", + description="The metric for weather. Options are: celsius, fahrenheit", + required=False, + default="celsius", + ), + }, + ), + ] + ] diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 4f9c4927a..657dc4b86 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -52,6 +52,9 @@ from llama_stack.models.llama.llama3.prompt_templates import ( SystemDefaultGenerator, ) from llama_stack.models.llama.llama3.tokenizer import Tokenizer +from llama_stack.models.llama.llama4.prompt_templates.system_prompts import ( + PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4, +) from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal from llama_stack.providers.utils.inference import supported_inference_models @@ -306,10 +309,11 @@ def chat_completion_request_to_messages( elif model.model_family in ( ModelFamily.llama3_2, ModelFamily.llama3_3, - ModelFamily.llama4, ): - # llama3.2, llama3.3 and llama4 models follow the same tool prompt format - messages = augment_messages_for_tools_llama_3_2(request) + # llama3.2, llama3.3 follow the same tool prompt format + messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator) + elif model.model_family == ModelFamily.llama4: + messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4) else: messages = request.messages @@ -399,8 +403,9 @@ def augment_messages_for_tools_llama_3_1( return messages -def augment_messages_for_tools_llama_3_2( +def augment_messages_for_tools_llama( request: ChatCompletionRequest, + custom_tool_prompt_generator, ) -> List[Message]: existing_messages = request.messages existing_system_message = None @@ -434,7 +439,7 @@ def augment_messages_for_tools_llama_3_2( if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace: system_prompt = existing_system_message.content - tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt) + tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt) sys_content += tool_template.render() sys_content += "\n"