mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)

commit 8e9217774a (parent 7ed137e963)
new prompt

2 changed files with 226 additions and 27 deletions
@@ -0,0 +1,151 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# top-level folder for each specific model found within the models/ directory at
+# the top-level of this source tree.
+
+import textwrap
+from datetime import datetime
+from typing import Any, List, Optional
+
+from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
+from llama_stack.models.llama.llama3.prompt_templates.base import (
+    PromptTemplate,
+    PromptTemplateGeneratorBase,
+)
+
+
+class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
+    DEFAULT_PROMPT = textwrap.dedent(
+        """
+        You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:
+
+        1. FUNCTION CALLS:
+        - ONLY use functions that are EXPLICITLY listed in the function list below
+        - If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
+        - If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
+        - If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
+        - Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
+        Examples:
+        CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
+        INCORRECT: get_weather(location="New York")
+        INCORRECT: Let me check the weather: [get_weather(location="New York")]
+        INCORRECT: [get_events(location="Singapore")] <- If function not in list
+
+        2. RESPONSE RULES:
+        - For pure function requests matching a listed function: ONLY output the function call(s)
+        - For knowledge questions: ONLY output text
+        - For missing parameters: ONLY request the specific missing parameters
+        - For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
+        - If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
+        - NEVER combine text and function calls in the same response
+        - NEVER suggest alternative functions when the requested service is unavailable
+        - NEVER create or invent new functions not listed below
+
+        3. STRICT BOUNDARIES:
+        - ONLY use functions from the list below - no exceptions
+        - NEVER use a function as an alternative to unavailable information
+        - NEVER call functions not present in the function list
+        - NEVER add explanatory text to function calls
+        - NEVER respond with empty brackets
+        - Use proper Python/JSON syntax for function calls
+        - Check the function list carefully before responding
+
+        4. TOOL RESPONSE HANDLING:
+        - When receiving tool responses: provide concise, natural language responses
+        - Don't repeat tool response verbatim
+        - Don't add supplementary information
+
+
+        {{ function_description }}
+        """.strip(
+            "\n"
+        )
+    )
+
+    def gen(
+        self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None
+    ) -> PromptTemplate:
+        system_prompt = system_prompt or self.DEFAULT_PROMPT
+        return PromptTemplate(
+            system_prompt,
+            {"function_description": self._gen_function_description(custom_tools)},
+        )
+
+    def _gen_function_description(
+        self, custom_tools: List[ToolDefinition]
+    ) -> PromptTemplate:
+        template_str = textwrap.dedent(
+            """
+            Here is a list of functions in JSON format that you can invoke.
+
+            [
+            {% for t in tools -%}
+            {# manually setting up JSON because jinja sorts keys in unexpected ways -#}
+            {%- set tname = t.tool_name -%}
+            {%- set tdesc = t.description -%}
+            {%- set tparams = t.parameters -%}
+            {%- set required_params = [] -%}
+            {%- for name, param in tparams.items() if param.required == true -%}
+                {%- set _ = required_params.append(name) -%}
+            {%- endfor -%}
+            {
+                "name": "{{tname}}",
+                "description": "{{tdesc}}",
+                "parameters": {
+                    "type": "dict",
+                    "required": {{ required_params | tojson }},
+                    "properties": {
+                        {%- for name, param in tparams.items() %}
+                        "{{name}}": {
+                            "type": "{{param.param_type}}",
+                            "description": "{{param.description}}"{% if param.default %},
+                            "default": "{{param.default}}"{% endif %}
+                        }{% if not loop.last %},{% endif %}
+                        {%- endfor %}
+                    }
+                }
+            }{% if not loop.last %},
+            {% endif -%}
+            {%- endfor %}
+            ]
+
+            You can answer general questions or invoke tools when necessary.
+            In addition to tool calls, you should also augment your responses by using the tool outputs.
+
+            """
+        )
+        return PromptTemplate(
+            template_str.strip("\n"),
+            {"tools": [t.model_dump() for t in custom_tools]},
+        ).render()
+
+    def data_examples(self) -> List[List[ToolDefinition]]:
+        return [
+            [
+                ToolDefinition(
+                    tool_name="get_weather",
+                    description="Get weather info for places",
+                    parameters={
+                        "city": ToolParamDefinition(
+                            param_type="string",
+                            description="The name of the city to get the weather for",
+                            required=True,
+                        ),
+                        "metric": ToolParamDefinition(
+                            param_type="string",
+                            description="The metric for weather. Options are: celsius, fahrenheit",
+                            required=False,
+                            default="celsius",
+                        ),
+                    },
+                ),
+            ]
+        ]
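The hunk above adds a new prompt-template module (imported later in this commit as llama_stack.models.llama.llama4.prompt_templates.system_prompts). A minimal usage sketch follows; it reuses the tool definition from the file's own data_examples() and is illustrative only, not part of the commit:

from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
    PythonListCustomToolGenerator,
)

# One custom tool, mirroring data_examples() above.
tools = [
    ToolDefinition(
        tool_name="get_weather",
        description="Get weather info for places",
        parameters={
            "city": ToolParamDefinition(
                param_type="string",
                description="The name of the city to get the weather for",
                required=True,
            ),
        },
    )
]

# gen() returns a PromptTemplate; render() yields the final system prompt text,
# with {{ function_description }} replaced by the JSON function list built by
# _gen_function_description().
print(PythonListCustomToolGenerator().gen(tools).render())

The remaining hunks update the shared inference prompt utilities (interleaved_content_convert_to_raw, chat_completion_request_to_prompt, and related helpers) to route Llama 4 requests through this generator.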
@@ -12,7 +12,6 @@ import re
 from typing import List, Optional, Tuple, Union

 import httpx
-from PIL import Image as PIL_Image

 from llama_stack.apis.common.content_types import (
     ImageContentItem,
@@ -52,9 +51,13 @@ from llama_stack.models.llama.llama3.prompt_templates import (
     SystemDefaultGenerator,
 )
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
+from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
+    PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
+)
 from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
+from llama_stack.models.llama.sku_types import is_multimodal, ModelFamily
 from llama_stack.providers.utils.inference import supported_inference_models
+from PIL import Image as PIL_Image

 log = get_logger(name=__name__, category="inference")

@@ -128,7 +131,9 @@ async def interleaved_content_convert_to_raw(
             if image.url.uri.startswith("data"):
                 match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
                 if not match:
-                    raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
+                    raise ValueError(
+                        f"Invalid data URL format, {image.url.uri[:40]}..."
+                    )
                 _, image_data = match.groups()
                 data = base64.b64decode(image_data)
             elif image.url.uri.startswith("file://"):
@@ -208,13 +213,17 @@ async def convert_image_content_to_url(

     content, format = await localize_image_content(media)
     if include_format:
-        return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
+        return f"data:image/{format};base64," + base64.b64encode(content).decode(
+            "utf-8"
+        )
     else:
         return base64.b64encode(content).decode("utf-8")


 async def completion_request_to_prompt(request: CompletionRequest) -> str:
-    content = augment_content_with_response_format_prompt(request.response_format, request.content)
+    content = augment_content_with_response_format_prompt(
+        request.response_format, request.content
+    )
     request.content = content
     request = await convert_request_to_raw(request)

@@ -226,7 +235,9 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:
 async def completion_request_to_prompt_model_input_info(
     request: CompletionRequest,
 ) -> Tuple[str, int]:
-    content = augment_content_with_response_format_prompt(request.response_format, request.content)
+    content = augment_content_with_response_format_prompt(
+        request.response_format, request.content
+    )
     request.content = content
     request = await convert_request_to_raw(request)

@@ -247,7 +258,9 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content


-async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
+async def chat_completion_request_to_prompt(
+    request: ChatCompletionRequest, llama_model: str
+) -> str:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
@@ -255,7 +268,8 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
         request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
+        tool_prompt_format=request.tool_config.tool_prompt_format
+        or get_default_tool_prompt_format(llama_model),
     )
     return formatter.tokenizer.decode(model_input.tokens)

@@ -270,7 +284,8 @@ async def chat_completion_request_to_model_input_info(
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
         request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
+        tool_prompt_format=request.tool_config.tool_prompt_format
+        or get_default_tool_prompt_format(llama_model),
     )
     return (
         formatter.tokenizer.decode(model_input.tokens),
@@ -299,17 +314,23 @@ def chat_completion_request_to_messages(
         return request.messages

     if model.model_family == ModelFamily.llama3_1 or (
-        model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
+        model.model_family == ModelFamily.llama3_2
+        and is_multimodal(model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         messages = augment_messages_for_tools_llama_3_1(request)
     elif model.model_family in (
         ModelFamily.llama3_2,
         ModelFamily.llama3_3,
-        ModelFamily.llama4,
     ):
-        # llama3.2, llama3.3 and llama4 models follow the same tool prompt format
-        messages = augment_messages_for_tools_llama_3_2(request)
+        # llama3.2, llama3.3 follow the same tool prompt format
+        messages = augment_messages_for_tools_llama_3_2(
+            request, PythonListCustomToolGenerator
+        )
+    elif model.model_family == ModelFamily.llama4:
+        messages = augment_messages_for_tools_llama_3_2(
+            request, PythonListCustomToolGeneratorLlama4
+        )
     else:
         messages = request.messages

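Condensed, the added lines above change the dispatch so the tool-prompt generator is passed explicitly: llama3.2/llama3.3 keep the existing llama3 generator, while llama4 now uses the generator added by this commit. A sketch of the resulting branch (post-change state only, reconstructed from the added lines):

elif model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
    # llama3.2 and llama3.3 keep the llama3 python_list prompt generator
    messages = augment_messages_for_tools_llama_3_2(request, PythonListCustomToolGenerator)
elif model.model_family == ModelFamily.llama4:
    # llama4 uses the new generator introduced in this commit
    messages = augment_messages_for_tools_llama_3_2(request, PythonListCustomToolGeneratorLlama4)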
@@ -339,7 +360,9 @@ def augment_messages_for_tools_llama_3_1(
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)

-    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
+    assert (
+        existing_messages[0].role != Role.system.value
+    ), "Should only have 1 system message"

     messages = []

@@ -371,9 +394,13 @@ def augment_messages_for_tools_llama_3_1(
         if isinstance(existing_system_message.content, str):
             sys_content += _process(existing_system_message.content)
         elif isinstance(existing_system_message.content, list):
-            sys_content += "\n".join([_process(c) for c in existing_system_message.content])
+            sys_content += "\n".join(
+                [_process(c) for c in existing_system_message.content]
+            )

-    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
+    tool_choice_prompt = _get_tool_choice_prompt(
+        request.tool_config.tool_choice, request.tools
+    )
     if tool_choice_prompt:
         sys_content += "\n" + tool_choice_prompt

@@ -401,13 +428,16 @@ def augment_messages_for_tools_llama_3_1(

 def augment_messages_for_tools_llama_3_2(
     request: ChatCompletionRequest,
+    custom_tool_prompt_generator,
 ) -> List[Message]:
     existing_messages = request.messages
     existing_system_message = None
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)

-    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
+    assert (
+        existing_messages[0].role != Role.system.value
+    ), "Should only have 1 system message"

     sys_content = ""
     custom_tools, builtin_tools = [], []
@@ -428,23 +458,34 @@ def augment_messages_for_tools_llama_3_2(
     if custom_tools:
         fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
         if fmt != ToolPromptFormat.python_list:
-            raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
+            raise ValueError(
+                f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}"
+            )

         system_prompt = None
-        if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
+        if (
+            existing_system_message
+            and request.tool_config.system_message_behavior
+            == SystemMessageBehavior.replace
+        ):
             system_prompt = existing_system_message.content

-        tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
+        tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)

         sys_content += tool_template.render()
         sys_content += "\n"

     if existing_system_message and (
-        request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
+        request.tool_config.system_message_behavior == SystemMessageBehavior.append
+        or not custom_tools
     ):
-        sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
+        sys_content += interleaved_content_as_str(
+            existing_system_message.content, sep="\n"
+        )

-    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
+    tool_choice_prompt = _get_tool_choice_prompt(
+        request.tool_config.tool_choice, request.tools
+    )
     if tool_choice_prompt:
         sys_content += "\n" + tool_choice_prompt

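Making the generator a parameter (custom_tool_prompt_generator) leaves the SystemMessageBehavior handling unchanged: with replace, the caller's system message is passed to gen() as system_prompt, overriding DEFAULT_PROMPT; with append, it is appended after the rendered tool template. A small sketch of the replace path, reusing the tools list from the sketch after the first hunk; the prompt text is hypothetical and the placeholder is kept on the assumption that PromptTemplate substitutes it the same way it does in DEFAULT_PROMPT:

# Hypothetical override: under SystemMessageBehavior.replace, text like this (taken
# from the request's existing system message) is passed as system_prompt to gen().
custom_prompt = "You are a terse weather bot.\n\n{{ function_description }}"

rendered = PythonListCustomToolGeneratorLlama4().gen(
    tools,  # same tool list as in the earlier sketch
    system_prompt=custom_prompt,
).render()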
@@ -452,11 +493,15 @@ def augment_messages_for_tools_llama_3_2(
     return messages


-def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
+def _get_tool_choice_prompt(
+    tool_choice: ToolChoice | str, tools: List[ToolDefinition]
+) -> str:
     if tool_choice == ToolChoice.auto:
         return ""
     elif tool_choice == ToolChoice.required:
-        return "You MUST use one of the provided functions/tools to answer the user query."
+        return (
+            "You MUST use one of the provided functions/tools to answer the user query."
+        )
     elif tool_choice == ToolChoice.none:
         # tools are already not passed in
         return ""
@@ -468,11 +513,14 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
 def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
     llama_model = resolve_model(model)
     if llama_model is None:
-        log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
+        log.warning(
+            f"Could not resolve model {model}, defaulting to json tool prompt format"
+        )
         return ToolPromptFormat.json

     if llama_model.model_family == ModelFamily.llama3_1 or (
-        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
+        llama_model.model_family == ModelFamily.llama3_2
+        and is_multimodal(llama_model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         return ToolPromptFormat.json