new prompt

Eric Huang 2025-04-24 13:16:42 -07:00
parent 7ed137e963
commit 8e9217774a
2 changed files with 226 additions and 27 deletions

llama_stack/models/llama/llama4/prompt_templates/system_prompts.py

@@ -0,0 +1,151 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import textwrap
from datetime import datetime
from typing import Any, List, Optional

from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
from llama_stack.models.llama.llama3.prompt_templates.base import (
    PromptTemplate,
    PromptTemplateGeneratorBase,
)


class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
    DEFAULT_PROMPT = textwrap.dedent(
        """
        You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:

        1. FUNCTION CALLS:
        - ONLY use functions that are EXPLICITLY listed in the function list below
        - If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
        - If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
        - If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
        - Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
        Examples:
        CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
        INCORRECT: get_weather(location="New York")
        INCORRECT: Let me check the weather: [get_weather(location="New York")]
        INCORRECT: [get_events(location="Singapore")] <- If function not in list

        2. RESPONSE RULES:
        - For pure function requests matching a listed function: ONLY output the function call(s)
        - For knowledge questions: ONLY output text
        - For missing parameters: ONLY request the specific missing parameters
        - For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
        - If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
        - NEVER combine text and function calls in the same response
        - NEVER suggest alternative functions when the requested service is unavailable
        - NEVER create or invent new functions not listed below

        3. STRICT BOUNDARIES:
        - ONLY use functions from the list below - no exceptions
        - NEVER use a function as an alternative to unavailable information
        - NEVER call functions not present in the function list
        - NEVER add explanatory text to function calls
        - NEVER respond with empty brackets
        - Use proper Python/JSON syntax for function calls
        - Check the function list carefully before responding

        4. TOOL RESPONSE HANDLING:
        - When receiving tool responses: provide concise, natural language responses
        - Don't repeat tool response verbatim
        - Don't add supplementary information

        {{ function_description }}
        """.strip("\n")
    )

    def gen(
        self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None
    ) -> PromptTemplate:
        system_prompt = system_prompt or self.DEFAULT_PROMPT
        return PromptTemplate(
            system_prompt,
            {"function_description": self._gen_function_description(custom_tools)},
        )

    def _gen_function_description(
        self, custom_tools: List[ToolDefinition]
    ) -> PromptTemplate:
        template_str = textwrap.dedent(
            """
            Here is a list of functions in JSON format that you can invoke.

            [
            {% for t in tools -%}
            {# manually setting up JSON because jinja sorts keys in unexpected ways -#}
            {%- set tname = t.tool_name -%}
            {%- set tdesc = t.description -%}
            {%- set tparams = t.parameters -%}
            {%- set required_params = [] -%}
            {%- for name, param in tparams.items() if param.required == true -%}
                {%- set _ = required_params.append(name) -%}
            {%- endfor -%}
                {
                    "name": "{{tname}}",
                    "description": "{{tdesc}}",
                    "parameters": {
                        "type": "dict",
                        "required": {{ required_params | tojson }},
                        "properties": {
                            {%- for name, param in tparams.items() %}
                            "{{name}}": {
                                "type": "{{param.param_type}}",
                                "description": "{{param.description}}"{% if param.default %},
                                "default": "{{param.default}}"{% endif %}
                            }{% if not loop.last %},{% endif %}
                            {%- endfor %}
                        }
                    }
                }{% if not loop.last %},
            {% endif -%}
            {%- endfor %}
            ]

            You can answer general questions or invoke tools when necessary.
            In addition to tool calls, you should also augment your responses by using the tool outputs.
            """
        )
        return PromptTemplate(
            template_str.strip("\n"),
            {"tools": [t.model_dump() for t in custom_tools]},
        ).render()

    def data_examples(self) -> List[List[ToolDefinition]]:
        return [
            [
                ToolDefinition(
                    tool_name="get_weather",
                    description="Get weather info for places",
                    parameters={
                        "city": ToolParamDefinition(
                            param_type="string",
                            description="The name of the city to get the weather for",
                            required=True,
                        ),
                        "metric": ToolParamDefinition(
                            param_type="string",
                            description="The metric for weather. Options are: celsius, fahrenheit",
                            required=False,
                            default="celsius",
                        ),
                    },
                ),
            ]
        ]
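
The generator above follows the same contract as its Llama 3 counterpart: gen() fills the {{ function_description }} slot with the rendered JSON function list, and PromptTemplate.render() produces the final system-prompt string. As a usage sketch (not part of this commit, assuming the module path added here resolves in your environment), rendering the get_weather example from data_examples() looks like:

# Usage sketch: render the new Llama 4 tool-calling system prompt.
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
    PythonListCustomToolGenerator,
)

generator = PythonListCustomToolGenerator()
tools = generator.data_examples()[0]  # [ToolDefinition for get_weather]

# gen() returns a PromptTemplate; render() substitutes the JSON function
# list for {{ function_description }} and yields the system-prompt string.
print(generator.gen(tools).render())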

llama_stack/providers/utils/inference/prompt_adapter.py

@@ -12,7 +12,6 @@ import re
 from typing import List, Optional, Tuple, Union

 import httpx
-from PIL import Image as PIL_Image

 from llama_stack.apis.common.content_types import (
     ImageContentItem,
@@ -52,9 +51,13 @@ from llama_stack.models.llama.llama3.prompt_templates import (
     SystemDefaultGenerator,
 )
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
+from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
+    PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
+)
 from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
+from llama_stack.models.llama.sku_types import is_multimodal, ModelFamily
 from llama_stack.providers.utils.inference import supported_inference_models
+from PIL import Image as PIL_Image

 log = get_logger(name=__name__, category="inference")
@@ -128,7 +131,9 @@ async def interleaved_content_convert_to_raw(
             if image.url.uri.startswith("data"):
                 match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
                 if not match:
-                    raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
+                    raise ValueError(
+                        f"Invalid data URL format, {image.url.uri[:40]}..."
+                    )
                 _, image_data = match.groups()
                 data = base64.b64decode(image_data)
             elif image.url.uri.startswith("file://"):
@@ -208,13 +213,17 @@ async def convert_image_content_to_url(
     content, format = await localize_image_content(media)
     if include_format:
-        return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
+        return f"data:image/{format};base64," + base64.b64encode(content).decode(
+            "utf-8"
+        )
     else:
         return base64.b64encode(content).decode("utf-8")


 async def completion_request_to_prompt(request: CompletionRequest) -> str:
-    content = augment_content_with_response_format_prompt(request.response_format, request.content)
+    content = augment_content_with_response_format_prompt(
+        request.response_format, request.content
+    )
     request.content = content
     request = await convert_request_to_raw(request)
@@ -226,7 +235,9 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:
 async def completion_request_to_prompt_model_input_info(
     request: CompletionRequest,
 ) -> Tuple[str, int]:
-    content = augment_content_with_response_format_prompt(request.response_format, request.content)
+    content = augment_content_with_response_format_prompt(
+        request.response_format, request.content
+    )
     request.content = content
     request = await convert_request_to_raw(request)
@@ -247,7 +258,9 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content


-async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
+async def chat_completion_request_to_prompt(
+    request: ChatCompletionRequest, llama_model: str
+) -> str:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
@@ -255,7 +268,8 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
         request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
+        tool_prompt_format=request.tool_config.tool_prompt_format
+        or get_default_tool_prompt_format(llama_model),
     )
     return formatter.tokenizer.decode(model_input.tokens)
@@ -270,7 +284,8 @@ async def chat_completion_request_to_model_input_info(
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
         request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
+        tool_prompt_format=request.tool_config.tool_prompt_format
+        or get_default_tool_prompt_format(llama_model),
     )
     return (
         formatter.tokenizer.decode(model_input.tokens),
@@ -299,17 +314,23 @@ def chat_completion_request_to_messages(
         return request.messages

     if model.model_family == ModelFamily.llama3_1 or (
-        model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
+        model.model_family == ModelFamily.llama3_2
+        and is_multimodal(model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         messages = augment_messages_for_tools_llama_3_1(request)
     elif model.model_family in (
         ModelFamily.llama3_2,
         ModelFamily.llama3_3,
-        ModelFamily.llama4,
     ):
-        # llama3.2, llama3.3 and llama4 models follow the same tool prompt format
-        messages = augment_messages_for_tools_llama_3_2(request)
+        # llama3.2, llama3.3 follow the same tool prompt format
+        messages = augment_messages_for_tools_llama_3_2(
+            request, PythonListCustomToolGenerator
+        )
+    elif model.model_family == ModelFamily.llama4:
+        messages = augment_messages_for_tools_llama_3_2(
+            request, PythonListCustomToolGeneratorLlama4
+        )
     else:
         messages = request.messages
@@ -339,7 +360,9 @@ def augment_messages_for_tools_llama_3_1(
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)

-    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
+    assert (
+        existing_messages[0].role != Role.system.value
+    ), "Should only have 1 system message"

     messages = []
@@ -371,9 +394,13 @@ def augment_messages_for_tools_llama_3_1(
     if isinstance(existing_system_message.content, str):
         sys_content += _process(existing_system_message.content)
     elif isinstance(existing_system_message.content, list):
-        sys_content += "\n".join([_process(c) for c in existing_system_message.content])
+        sys_content += "\n".join(
+            [_process(c) for c in existing_system_message.content]
+        )

-    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
+    tool_choice_prompt = _get_tool_choice_prompt(
+        request.tool_config.tool_choice, request.tools
+    )
     if tool_choice_prompt:
         sys_content += "\n" + tool_choice_prompt
@@ -401,13 +428,16 @@ def augment_messages_for_tools_llama_3_1(
 def augment_messages_for_tools_llama_3_2(
     request: ChatCompletionRequest,
+    custom_tool_prompt_generator,
 ) -> List[Message]:
     existing_messages = request.messages
     existing_system_message = None

     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)

-    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
+    assert (
+        existing_messages[0].role != Role.system.value
+    ), "Should only have 1 system message"

     sys_content = ""
     custom_tools, builtin_tools = [], []
@@ -428,23 +458,34 @@ def augment_messages_for_tools_llama_3_2(
     if custom_tools:
         fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
         if fmt != ToolPromptFormat.python_list:
-            raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
+            raise ValueError(
+                f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}"
+            )

         system_prompt = None
-        if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
+        if (
+            existing_system_message
+            and request.tool_config.system_message_behavior
+            == SystemMessageBehavior.replace
+        ):
             system_prompt = existing_system_message.content

-        tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
+        tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)

         sys_content += tool_template.render()
         sys_content += "\n"

     if existing_system_message and (
-        request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
+        request.tool_config.system_message_behavior == SystemMessageBehavior.append
+        or not custom_tools
     ):
-        sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
+        sys_content += interleaved_content_as_str(
+            existing_system_message.content, sep="\n"
+        )

-    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
+    tool_choice_prompt = _get_tool_choice_prompt(
+        request.tool_config.tool_choice, request.tools
+    )
     if tool_choice_prompt:
         sys_content += "\n" + tool_choice_prompt
@@ -452,11 +493,15 @@ def augment_messages_for_tools_llama_3_2(
     return messages


-def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
+def _get_tool_choice_prompt(
+    tool_choice: ToolChoice | str, tools: List[ToolDefinition]
+) -> str:
     if tool_choice == ToolChoice.auto:
         return ""
     elif tool_choice == ToolChoice.required:
-        return "You MUST use one of the provided functions/tools to answer the user query."
+        return (
+            "You MUST use one of the provided functions/tools to answer the user query."
+        )
     elif tool_choice == ToolChoice.none:
         # tools are already not passed in
         return ""
@@ -468,11 +513,14 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
 def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
     llama_model = resolve_model(model)
     if llama_model is None:
-        log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
+        log.warning(
+            f"Could not resolve model {model}, defaulting to json tool prompt format"
+        )
         return ToolPromptFormat.json

     if llama_model.model_family == ModelFamily.llama3_1 or (
-        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
+        llama_model.model_family == ModelFamily.llama3_2
+        and is_multimodal(llama_model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         return ToolPromptFormat.json
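
The net effect of the dispatch change in chat_completion_request_to_messages is that augment_messages_for_tools_llama_3_2 no longer hard-codes PythonListCustomToolGenerator: the caller injects the generator class, so Llama 3.2/3.3 and Llama 4 share the augmentation path but get different system-prompt text. A minimal sketch of that selection logic, with a hypothetical helper name and an assumed import path for the Llama 3 generator:

# Hypothetical helper mirroring the new elif chain; not part of the commit.
from llama_stack.models.llama.llama3.prompt_templates import (
    PythonListCustomToolGenerator,  # import path assumed
)
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
    PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.sku_types import ModelFamily


def select_custom_tool_prompt_generator(family: ModelFamily):
    # Same augmentation code either way; only the injected generator class
    # (and therefore the system-prompt text) differs per model family.
    if family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
        return PythonListCustomToolGenerator
    if family == ModelFamily.llama4:
        return PythonListCustomToolGeneratorLlama4
    raise ValueError(f"no python_list generator for {family}")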