forked from phoenix-oss/llama-stack-mirror
feat: new system prompt for llama4 (#2031)
Tests: LLAMA_STACK_CONFIG=http://localhost:5002 pytest -s -v tests/integration/inference --safety-shield meta-llama/Llama-Guard-3-8B --vision-model meta-llama/Llama-4-Scout-17B-16E-Instruct --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct Co-authored-by: Eric Huang <erichuang@fb.com>
This commit is contained in:
parent
4bbd0c0693
commit
29072f40ab
2 changed files with 154 additions and 5 deletions
|
@ -0,0 +1,144 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.models.llama.llama3.prompt_templates.base import (
|
||||
PromptTemplate,
|
||||
PromptTemplateGeneratorBase,
|
||||
)
|
||||
|
||||
|
||||
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
||||
DEFAULT_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:
|
||||
|
||||
1. FUNCTION CALLS:
|
||||
- ONLY use functions that are EXPLICITLY listed in the function list below
|
||||
- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
|
||||
- If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
|
||||
- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
|
||||
- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
|
||||
Examples:
|
||||
CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
|
||||
INCORRECT: get_weather(location="New York")
|
||||
INCORRECT: Let me check the weather: [get_weather(location="New York")]
|
||||
INCORRECT: [get_events(location="Singapore")] <- If function not in list
|
||||
|
||||
2. RESPONSE RULES:
|
||||
- For pure function requests matching a listed function: ONLY output the function call(s)
|
||||
- For knowledge questions: ONLY output text
|
||||
- For missing parameters: ONLY request the specific missing parameters
|
||||
- For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
|
||||
- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
|
||||
- NEVER combine text and function calls in the same response
|
||||
- NEVER suggest alternative functions when the requested service is unavailable
|
||||
- NEVER create or invent new functions not listed below
|
||||
|
||||
3. STRICT BOUNDARIES:
|
||||
- ONLY use functions from the list below - no exceptions
|
||||
- NEVER use a function as an alternative to unavailable information
|
||||
- NEVER call functions not present in the function list
|
||||
- NEVER add explanatory text to function calls
|
||||
- NEVER respond with empty brackets
|
||||
- Use proper Python/JSON syntax for function calls
|
||||
- Check the function list carefully before responding
|
||||
|
||||
4. TOOL RESPONSE HANDLING:
|
||||
- When receiving tool responses: provide concise, natural language responses
|
||||
- Don't repeat tool response verbatim
|
||||
- Don't add supplementary information
|
||||
|
||||
|
||||
{{ function_description }}
|
||||
""".strip("\n")
|
||||
)
|
||||
|
||||
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
|
||||
system_prompt = system_prompt or self.DEFAULT_PROMPT
|
||||
return PromptTemplate(
|
||||
system_prompt,
|
||||
{"function_description": self._gen_function_description(custom_tools)},
|
||||
)
|
||||
|
||||
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||
template_str = textwrap.dedent(
|
||||
"""
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{% for t in tools -%}
|
||||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set tparams = t.parameters -%}
|
||||
{%- set required_params = [] -%}
|
||||
{%- for name, param in tparams.items() if param.required == true -%}
|
||||
{%- set _ = required_params.append(name) -%}
|
||||
{%- endfor -%}
|
||||
{
|
||||
"name": "{{tname}}",
|
||||
"description": "{{tdesc}}",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": {{ required_params | tojson }},
|
||||
"properties": {
|
||||
{%- for name, param in tparams.items() %}
|
||||
"{{name}}": {
|
||||
"type": "{{param.param_type}}",
|
||||
"description": "{{param.description}}"{% if param.default %},
|
||||
"default": "{{param.default}}"{% endif %}
|
||||
}{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
}
|
||||
}
|
||||
}{% if not loop.last %},
|
||||
{% endif -%}
|
||||
{%- endfor %}
|
||||
]
|
||||
|
||||
You can answer general questions or invoke tools when necessary.
|
||||
In addition to tool calls, you should also augment your responses by using the tool outputs.
|
||||
|
||||
"""
|
||||
)
|
||||
return PromptTemplate(
|
||||
template_str.strip("\n"),
|
||||
{"tools": [t.model_dump() for t in custom_tools]},
|
||||
).render()
|
||||
|
||||
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||
return [
|
||||
[
|
||||
ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get weather info for places",
|
||||
parameters={
|
||||
"city": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The name of the city to get the weather for",
|
||||
required=True,
|
||||
),
|
||||
"metric": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The metric for weather. Options are: celsius, fahrenheit",
|
||||
required=False,
|
||||
default="celsius",
|
||||
),
|
||||
},
|
||||
),
|
||||
]
|
||||
]
|
|
@ -52,6 +52,9 @@ from llama_stack.models.llama.llama3.prompt_templates import (
|
|||
SystemDefaultGenerator,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
||||
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
|
||||
)
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
@ -306,10 +309,11 @@ def chat_completion_request_to_messages(
|
|||
elif model.model_family in (
|
||||
ModelFamily.llama3_2,
|
||||
ModelFamily.llama3_3,
|
||||
ModelFamily.llama4,
|
||||
):
|
||||
# llama3.2, llama3.3 and llama4 models follow the same tool prompt format
|
||||
messages = augment_messages_for_tools_llama_3_2(request)
|
||||
# llama3.2, llama3.3 follow the same tool prompt format
|
||||
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
|
||||
elif model.model_family == ModelFamily.llama4:
|
||||
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
|
||||
else:
|
||||
messages = request.messages
|
||||
|
||||
|
@ -399,8 +403,9 @@ def augment_messages_for_tools_llama_3_1(
|
|||
return messages
|
||||
|
||||
|
||||
def augment_messages_for_tools_llama_3_2(
|
||||
def augment_messages_for_tools_llama(
|
||||
request: ChatCompletionRequest,
|
||||
custom_tool_prompt_generator,
|
||||
) -> List[Message]:
|
||||
existing_messages = request.messages
|
||||
existing_system_message = None
|
||||
|
@ -434,7 +439,7 @@ def augment_messages_for_tools_llama_3_2(
|
|||
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
|
||||
system_prompt = existing_system_message.content
|
||||
|
||||
tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
|
||||
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
|
||||
|
||||
sys_content += tool_template.render()
|
||||
sys_content += "\n"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue