mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: new system prompt for llama4 (#2031)
Tests: LLAMA_STACK_CONFIG=http://localhost:5002 pytest -s -v tests/integration/inference --safety-shield meta-llama/Llama-Guard-3-8B --vision-model meta-llama/Llama-4-Scout-17B-16E-Instruct --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct Co-authored-by: Eric Huang <erichuang@fb.com>
This commit is contained in:
parent
4bbd0c0693
commit
29072f40ab
2 changed files with 154 additions and 5 deletions
|
@ -0,0 +1,144 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
|
||||||
|
from llama_stack.models.llama.llama3.prompt_templates.base import (
|
||||||
|
PromptTemplate,
|
||||||
|
PromptTemplateGeneratorBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
    """Generates the Llama-4 system prompt that teaches the model to emit
    custom tool calls in Python-list syntax, e.g.
    ``[func_name1(param1=value1), func_name2(...)]``.

    The default prompt can be overridden per-request via the
    ``system_prompt`` argument to :meth:`gen`; the ``{{ function_description }}``
    placeholder is always filled with the JSON rendering of the tools.
    """

    # Default system prompt. It intentionally contains the Jinja placeholder
    # {{ function_description }}, which gen() substitutes with the rendered
    # function list. strip("\n") removes the leading/trailing blank lines
    # introduced by the triple-quoted literal.
    DEFAULT_PROMPT = textwrap.dedent(
        """
        You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:

        1. FUNCTION CALLS:
        - ONLY use functions that are EXPLICITLY listed in the function list below
        - If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
        - If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
        - If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
        - Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
        Examples:
        CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
        INCORRECT: get_weather(location="New York")
        INCORRECT: Let me check the weather: [get_weather(location="New York")]
        INCORRECT: [get_events(location="Singapore")] <- If function not in list

        2. RESPONSE RULES:
        - For pure function requests matching a listed function: ONLY output the function call(s)
        - For knowledge questions: ONLY output text
        - For missing parameters: ONLY request the specific missing parameters
        - For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
        - If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
        - NEVER combine text and function calls in the same response
        - NEVER suggest alternative functions when the requested service is unavailable
        - NEVER create or invent new functions not listed below

        3. STRICT BOUNDARIES:
        - ONLY use functions from the list below - no exceptions
        - NEVER use a function as an alternative to unavailable information
        - NEVER call functions not present in the function list
        - NEVER add explanatory text to function calls
        - NEVER respond with empty brackets
        - Use proper Python/JSON syntax for function calls
        - Check the function list carefully before responding

        4. TOOL RESPONSE HANDLING:
        - When receiving tool responses: provide concise, natural language responses
        - Don't repeat tool response verbatim
        - Don't add supplementary information


        {{ function_description }}
        """.strip("\n")
    )

    def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
        """Build the system-prompt template for the given tools.

        :param custom_tools: tool definitions to advertise to the model.
        :param system_prompt: optional caller-supplied prompt; it must contain
            a ``{{ function_description }}`` placeholder. Falls back to
            :attr:`DEFAULT_PROMPT` when ``None`` or empty.
        :returns: a :class:`PromptTemplate` whose single substitution is the
            pre-rendered function description string.
        """
        system_prompt = system_prompt or self.DEFAULT_PROMPT
        return PromptTemplate(
            system_prompt,
            {"function_description": self._gen_function_description(custom_tools)},
        )

    def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> str:
        """Render the tool definitions as a JSON-like function list.

        Returns the fully rendered string (not a template): the Jinja template
        below is evaluated eagerly so gen() can splice the result into the
        outer prompt as plain text.
        """
        template_str = textwrap.dedent(
            """
            Here is a list of functions in JSON format that you can invoke.

            [
            {% for t in tools -%}
            {# manually setting up JSON because jinja sorts keys in unexpected ways -#}
            {%- set tname = t.tool_name -%}
            {%- set tdesc = t.description -%}
            {%- set tparams = t.parameters -%}
            {%- set required_params = [] -%}
            {%- for name, param in tparams.items() if param.required == true -%}
                {%- set _ = required_params.append(name) -%}
            {%- endfor -%}
            {
                "name": "{{tname}}",
                "description": "{{tdesc}}",
                "parameters": {
                    "type": "dict",
                    "required": {{ required_params | tojson }},
                    "properties": {
                        {%- for name, param in tparams.items() %}
                        "{{name}}": {
                            "type": "{{param.param_type}}",
                            "description": "{{param.description}}"{% if param.default %},
                            "default": "{{param.default}}"{% endif %}
                        }{% if not loop.last %},{% endif %}
                        {%- endfor %}
                    }
                }
            }{% if not loop.last %},
            {% endif -%}
            {%- endfor %}
            ]

            You can answer general questions or invoke tools when necessary.
            In addition to tool calls, you should also augment your responses by using the tool outputs.

            """
        )
        # .render() yields a plain string; the template is consumed here.
        return PromptTemplate(
            template_str.strip("\n"),
            {"tools": [t.model_dump() for t in custom_tools]},
        ).render()

    def data_examples(self) -> List[List[ToolDefinition]]:
        """Example tool sets used to exercise/demonstrate the generator."""
        return [
            [
                ToolDefinition(
                    tool_name="get_weather",
                    description="Get weather info for places",
                    parameters={
                        "city": ToolParamDefinition(
                            param_type="string",
                            description="The name of the city to get the weather for",
                            required=True,
                        ),
                        "metric": ToolParamDefinition(
                            param_type="string",
                            description="The metric for weather. Options are: celsius, fahrenheit",
                            required=False,
                            default="celsius",
                        ),
                    },
                ),
            ]
        ]
|
|
@ -52,6 +52,9 @@ from llama_stack.models.llama.llama3.prompt_templates import (
|
||||||
SystemDefaultGenerator,
|
SystemDefaultGenerator,
|
||||||
)
|
)
|
||||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||||
|
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
||||||
|
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
|
||||||
|
)
|
||||||
from llama_stack.models.llama.sku_list import resolve_model
|
from llama_stack.models.llama.sku_list import resolve_model
|
||||||
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
||||||
from llama_stack.providers.utils.inference import supported_inference_models
|
from llama_stack.providers.utils.inference import supported_inference_models
|
||||||
|
@ -306,10 +309,11 @@ def chat_completion_request_to_messages(
|
||||||
elif model.model_family in (
|
elif model.model_family in (
|
||||||
ModelFamily.llama3_2,
|
ModelFamily.llama3_2,
|
||||||
ModelFamily.llama3_3,
|
ModelFamily.llama3_3,
|
||||||
ModelFamily.llama4,
|
|
||||||
):
|
):
|
||||||
# llama3.2, llama3.3 and llama4 models follow the same tool prompt format
|
# llama3.2, llama3.3 follow the same tool prompt format
|
||||||
messages = augment_messages_for_tools_llama_3_2(request)
|
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
|
||||||
|
elif model.model_family == ModelFamily.llama4:
|
||||||
|
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
|
||||||
else:
|
else:
|
||||||
messages = request.messages
|
messages = request.messages
|
||||||
|
|
||||||
|
@ -399,8 +403,9 @@ def augment_messages_for_tools_llama_3_1(
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def augment_messages_for_tools_llama_3_2(
|
def augment_messages_for_tools_llama(
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
|
custom_tool_prompt_generator,
|
||||||
) -> List[Message]:
|
) -> List[Message]:
|
||||||
existing_messages = request.messages
|
existing_messages = request.messages
|
||||||
existing_system_message = None
|
existing_system_message = None
|
||||||
|
@ -434,7 +439,7 @@ def augment_messages_for_tools_llama_3_2(
|
||||||
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
|
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
|
||||||
system_prompt = existing_system_message.content
|
system_prompt = existing_system_message.content
|
||||||
|
|
||||||
tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
|
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
|
||||||
|
|
||||||
sys_content += tool_template.render()
|
sys_content += tool_template.render()
|
||||||
sys_content += "\n"
|
sys_content += "\n"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue