From 8e9217774ab8ab4fec1e9c8852f0998eeb3fd1e1 Mon Sep 17 00:00:00 2001
From: Eric Huang
Date: Thu, 24 Apr 2025 13:16:42 -0700
Subject: [PATCH] new prompt

---
 .../llama4/prompt_templates/system_prompts.py | 151 ++++++++++++++++++
 .../utils/inference/prompt_adapter.py         | 102 ++++++++----
 2 files changed, 226 insertions(+), 27 deletions(-)
 create mode 100644 llama_stack/models/llama/llama4/prompt_templates/system_prompts.py

diff --git a/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py
new file mode 100644
index 000000000..b9cc91d3b
--- /dev/null
+++ b/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py
@@ -0,0 +1,151 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# top-level folder for each specific model found within the models/ directory at
+# the top-level of this source tree.
+
+import textwrap
+from datetime import datetime
+from typing import Any, List, Optional
+
+from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
+from llama_stack.models.llama.llama3.prompt_templates.base import (
+    PromptTemplate,
+    PromptTemplateGeneratorBase,
+)
+
+
+class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
+    DEFAULT_PROMPT = textwrap.dedent(
+        """
+        You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:
+
+        1. FUNCTION CALLS:
+        - ONLY use functions that are EXPLICITLY listed in the function list below
+        - If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
+        - If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
+        - If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
+        - Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
+        Examples:
+        CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
+        INCORRECT: get_weather(location="New York")
+        INCORRECT: Let me check the weather: [get_weather(location="New York")]
+        INCORRECT: [get_events(location="Singapore")] <- If function not in list
+
+        2. RESPONSE RULES:
+        - For pure function requests matching a listed function: ONLY output the function call(s)
+        - For knowledge questions: ONLY output text
+        - For missing parameters: ONLY request the specific missing parameters
+        - For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
+        - If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
+        - NEVER combine text and function calls in the same response
+        - NEVER suggest alternative functions when the requested service is unavailable
+        - NEVER create or invent new functions not listed below
+
+        3. STRICT BOUNDARIES:
+        - ONLY use functions from the list below - no exceptions
+        - NEVER use a function as an alternative to unavailable information
+        - NEVER call functions not present in the function list
+        - NEVER add explanatory text to function calls
+        - NEVER respond with empty brackets
+        - Use proper Python/JSON syntax for function calls
+        - Check the function list carefully before responding
+
+        4. TOOL RESPONSE HANDLING:
+        - When receiving tool responses: provide concise, natural language responses
+        - Don't repeat tool response verbatim
+        - Don't add supplementary information
+
+
+        {{ function_description }}
+        """.strip(
+            "\n"
+        )
+    )
+
+    def gen(
+        self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None
+    ) -> PromptTemplate:
+        system_prompt = system_prompt or self.DEFAULT_PROMPT
+        return PromptTemplate(
+            system_prompt,
+            {"function_description": self._gen_function_description(custom_tools)},
+        )
+
+    def _gen_function_description(
+        self, custom_tools: List[ToolDefinition]
+    ) -> PromptTemplate:
+        template_str = textwrap.dedent(
+            """
+            Here is a list of functions in JSON format that you can invoke.
+
+            [
+            {% for t in tools -%}
+            {# manually setting up JSON because jinja sorts keys in unexpected ways -#}
+            {%- set tname = t.tool_name -%}
+            {%- set tdesc = t.description -%}
+            {%- set tparams = t.parameters -%}
+            {%- set required_params = [] -%}
+            {%- for name, param in tparams.items() if param.required == true -%}
+            {%- set _ = required_params.append(name) -%}
+            {%- endfor -%}
+            {
+                "name": "{{tname}}",
+                "description": "{{tdesc}}",
+                "parameters": {
+                    "type": "dict",
+                    "required": {{ required_params | tojson }},
+                    "properties": {
+                        {%- for name, param in tparams.items() %}
+                        "{{name}}": {
+                            "type": "{{param.param_type}}",
+                            "description": "{{param.description}}"{% if param.default %},
+                            "default": "{{param.default}}"{% endif %}
+                        }{% if not loop.last %},{% endif %}
+                        {%- endfor %}
+                    }
+                }
+            }{% if not loop.last %},
+            {% endif -%}
+            {%- endfor %}
+            ]
+
+            You can answer general questions or invoke tools when necessary.
+            In addition to tool calls, you should also augment your responses by using the tool outputs.
+
+            """
+        )
+        return PromptTemplate(
+            template_str.strip("\n"),
+            {"tools": [t.model_dump() for t in custom_tools]},
+        ).render()
+
+    def data_examples(self) -> List[List[ToolDefinition]]:
+        return [
+            [
+                ToolDefinition(
+                    tool_name="get_weather",
+                    description="Get weather info for places",
+                    parameters={
+                        "city": ToolParamDefinition(
+                            param_type="string",
+                            description="The name of the city to get the weather for",
+                            required=True,
+                        ),
+                        "metric": ToolParamDefinition(
+                            param_type="string",
+                            description="The metric for weather. Options are: celsius, fahrenheit",
+                            required=False,
+                            default="celsius",
+                        ),
+                    },
+                ),
+            ]
+        ]
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 4f9c4927a..2568a4732 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -12,7 +12,6 @@ import re
 from typing import List, Optional, Tuple, Union
 
 import httpx
-from PIL import Image as PIL_Image
 
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
@@ -52,9 +51,13 @@ from llama_stack.models.llama.llama3.prompt_templates import (
     SystemDefaultGenerator,
 )
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
+from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
+    PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
+)
 from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
+from llama_stack.models.llama.sku_types import is_multimodal, ModelFamily
 from llama_stack.providers.utils.inference import supported_inference_models
+from PIL import Image as PIL_Image
 
 log = get_logger(name=__name__, category="inference")
 
@@ -128,7 +131,9 @@ async def interleaved_content_convert_to_raw(
                 if image.url.uri.startswith("data"):
                     match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
                     if not match:
-                        raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
+                        raise ValueError(
+                            f"Invalid data URL format, {image.url.uri[:40]}..."
+                        )
                     _, image_data = match.groups()
                     data = base64.b64decode(image_data)
                 elif image.url.uri.startswith("file://"):
@@ -208,13 +213,17 @@ async def convert_image_content_to_url(
 
     content, format = await localize_image_content(media)
     if include_format:
-        return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
+        return f"data:image/{format};base64," + base64.b64encode(content).decode(
+            "utf-8"
+        )
     else:
         return base64.b64encode(content).decode("utf-8")
 
 
 async def completion_request_to_prompt(request: CompletionRequest) -> str:
-    content = augment_content_with_response_format_prompt(request.response_format, request.content)
+    content = augment_content_with_response_format_prompt(
+        request.response_format, request.content
+    )
     request.content = content
     request = await convert_request_to_raw(request)
 
@@ -226,7 +235,9 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:
 async def completion_request_to_prompt_model_input_info(
     request: CompletionRequest,
 ) -> Tuple[str, int]:
-    content = augment_content_with_response_format_prompt(request.response_format, request.content)
+    content = augment_content_with_response_format_prompt(
+        request.response_format, request.content
+    )
     request.content = content
     request = await convert_request_to_raw(request)
 
@@ -247,7 +258,9 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content
 
 
-async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
+async def chat_completion_request_to_prompt(
+    request: ChatCompletionRequest, llama_model: str
+) -> str:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
@@ -255,7 +268,8 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
         request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
+        tool_prompt_format=request.tool_config.tool_prompt_format
+        or get_default_tool_prompt_format(llama_model),
     )
     return formatter.tokenizer.decode(model_input.tokens)
 
@@ -270,7 +284,8 @@ async def chat_completion_request_to_model_input_info(
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
         request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
+        tool_prompt_format=request.tool_config.tool_prompt_format
+        or get_default_tool_prompt_format(llama_model),
     )
     return (
         formatter.tokenizer.decode(model_input.tokens),
@@ -299,17 +314,23 @@ def chat_completion_request_to_messages(
         return request.messages
 
     if model.model_family == ModelFamily.llama3_1 or (
-        model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
+        model.model_family == ModelFamily.llama3_2
+        and is_multimodal(model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         messages = augment_messages_for_tools_llama_3_1(request)
     elif model.model_family in (
         ModelFamily.llama3_2,
         ModelFamily.llama3_3,
-        ModelFamily.llama4,
     ):
-        # llama3.2, llama3.3 and llama4 models follow the same tool prompt format
-        messages = augment_messages_for_tools_llama_3_2(request)
+        # llama3.2, llama3.3 follow the same tool prompt format
+        messages = augment_messages_for_tools_llama_3_2(
+            request, PythonListCustomToolGenerator
+        )
+    elif model.model_family == ModelFamily.llama4:
+        messages = augment_messages_for_tools_llama_3_2(
+            request, PythonListCustomToolGeneratorLlama4
+        )
     else:
         messages = request.messages
 
@@ -339,7 +360,9 @@ def augment_messages_for_tools_llama_3_1(
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)
 
-    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
+    assert (
+        existing_messages[0].role != Role.system.value
+    ), "Should only have 1 system message"
 
     messages = []
 
@@ -371,9 +394,13 @@ def augment_messages_for_tools_llama_3_1(
         if isinstance(existing_system_message.content, str):
             sys_content += _process(existing_system_message.content)
         elif isinstance(existing_system_message.content, list):
-            sys_content += "\n".join([_process(c) for c in existing_system_message.content])
+            sys_content += "\n".join(
+                [_process(c) for c in existing_system_message.content]
+            )
 
-    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
+    tool_choice_prompt = _get_tool_choice_prompt(
+        request.tool_config.tool_choice, request.tools
+    )
     if tool_choice_prompt:
         sys_content += "\n" + tool_choice_prompt
 
@@ -401,13 +428,16 @@ def augment_messages_for_tools_llama_3_1(
 def augment_messages_for_tools_llama_3_2(
     request: ChatCompletionRequest,
+    custom_tool_prompt_generator,
 ) -> List[Message]:
     existing_messages = request.messages
     existing_system_message = None
 
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)
 
-    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
+    assert (
+        existing_messages[0].role != Role.system.value
+    ), "Should only have 1 system message"
 
     sys_content = ""
     custom_tools, builtin_tools = [], []
@@ -428,23 +458,34 @@ def augment_messages_for_tools_llama_3_2(
     if custom_tools:
         fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
         if fmt != ToolPromptFormat.python_list:
-            raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
+            raise ValueError(
+                f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}"
+            )
 
         system_prompt = None
-        if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
+        if (
+            existing_system_message
+            and request.tool_config.system_message_behavior
+            == SystemMessageBehavior.replace
+        ):
             system_prompt = existing_system_message.content
 
-        tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
+        tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
 
         sys_content += tool_template.render()
         sys_content += "\n"
 
     if existing_system_message and (
-        request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
+        request.tool_config.system_message_behavior == SystemMessageBehavior.append
+        or not custom_tools
    ):
-        sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
+        sys_content += interleaved_content_as_str(
+            existing_system_message.content, sep="\n"
+        )
 
-    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
+    tool_choice_prompt = _get_tool_choice_prompt(
+        request.tool_config.tool_choice, request.tools
+    )
     if tool_choice_prompt:
         sys_content += "\n" + tool_choice_prompt
 
@@ -452,11 +493,15 @@ def augment_messages_for_tools_llama_3_2(
     return messages
 
 
-def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
+def _get_tool_choice_prompt(
+    tool_choice: ToolChoice | str, tools: List[ToolDefinition]
+) -> str:
     if tool_choice == ToolChoice.auto:
         return ""
     elif tool_choice == ToolChoice.required:
-        return "You MUST use one of the provided functions/tools to answer the user query."
+        return (
+            "You MUST use one of the provided functions/tools to answer the user query."
+        )
     elif tool_choice == ToolChoice.none:
         # tools are already not passed in
         return ""
@@ -468,11 +513,14 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
 def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
     llama_model = resolve_model(model)
     if llama_model is None:
-        log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
+        log.warning(
+            f"Could not resolve model {model}, defaulting to json tool prompt format"
+        )
         return ToolPromptFormat.json
 
     if llama_model.model_family == ModelFamily.llama3_1 or (
-        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
+        llama_model.model_family == ModelFamily.llama3_2
+        and is_multimodal(llama_model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         return ToolPromptFormat.json
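
A minimal usage sketch (not part of the patch) of the new Llama 4 generator, assuming this patch is applied and llama_stack is importable; the tool definition mirrors data_examples() above.

from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
    PythonListCustomToolGenerator,
)

# Build one custom tool, matching the shape used in data_examples().
tools = [
    ToolDefinition(
        tool_name="get_weather",
        description="Get weather info for places",
        parameters={
            "city": ToolParamDefinition(
                param_type="string",
                description="The name of the city to get the weather for",
                required=True,
            ),
        },
    )
]

# gen() substitutes the rendered JSON function list into DEFAULT_PROMPT (or into a
# caller-supplied system prompt); render() returns the final system-prompt string.
print(PythonListCustomToolGenerator().gen(tools).render())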