mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
chore: move all Llama Stack types from llama-models to llama-stack (#1098)
llama-models should have extremely minimal cruft. Its sole purpose should be didactic -- show the simplest implementation of the llama models and document the prompt formats, etc. This PR is the complement to https://github.com/meta-llama/llama-models/pull/279 ## Test Plan Ensure all `llama` CLI `model` sub-commands work: ```bash llama model list llama model download --model-id ... llama model prompt-format -m ... ``` Ran tests: ```bash cd tests/client-sdk LLAMA_STACK_CONFIG=fireworks pytest -s -v inference/ LLAMA_STACK_CONFIG=fireworks pytest -s -v vector_io/ LLAMA_STACK_CONFIG=fireworks pytest -s -v agents/ ``` Create a fresh venv `uv venv && source .venv/bin/activate` and run `llama stack build --template fireworks --image-type venv` followed by `llama stack run together --image-type venv` <-- the server runs Also checked that the OpenAPI generator can run and there is no change in the generated files as a result. ```bash cd docs/openapi_generator sh run_openapi_generator.sh ```
This commit is contained in:
parent
c0ee512980
commit
314ee09ae3
138 changed files with 8491 additions and 465 deletions
BIN
llama_stack/models/llama/llama3/dog.jpg
Normal file
BIN
llama_stack/models/llama/llama3/dog.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 39 KiB |
257
llama_stack/models/llama/llama3/interface.py
Normal file
257
llama_stack/models/llama/llama3/interface.py
Normal file
|
@ -0,0 +1,257 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_models.datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition
|
||||
|
||||
from . import template_data
|
||||
from .prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
FunctionTagCustomToolGenerator,
|
||||
JsonCustomToolGenerator,
|
||||
SystemDefaultGenerator,
|
||||
ToolResponseGenerator,
|
||||
)
|
||||
|
||||
THIS_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
class Template:
|
||||
def __init__(
|
||||
self,
|
||||
role,
|
||||
template_name,
|
||||
data_provider=None,
|
||||
notes=None,
|
||||
):
|
||||
self.role = role
|
||||
self.template_name = template_name
|
||||
self.data_provider = data_provider or ""
|
||||
self._notes = notes or ""
|
||||
|
||||
@property
|
||||
def notes(self):
|
||||
default = "↵ represents newline"
|
||||
notes = default
|
||||
if self._notes:
|
||||
notes += "\n"
|
||||
notes += self._notes
|
||||
return notes
|
||||
|
||||
|
||||
TEMPLATES = [
|
||||
Template(
|
||||
"user",
|
||||
"user-default",
|
||||
"user_default",
|
||||
),
|
||||
Template(
|
||||
"user",
|
||||
"user-images",
|
||||
"user_images",
|
||||
),
|
||||
Template("user", "user-interleaved-images", "user_interleaved_images"),
|
||||
Template(
|
||||
"assistant",
|
||||
"assistant-builtin-tool-call",
|
||||
"assistant_builtin_tool_call",
|
||||
"Notice <|python_tag|>",
|
||||
),
|
||||
Template(
|
||||
"assistant",
|
||||
"assistant-custom-tool-call",
|
||||
"assistant_custom_tool_call",
|
||||
"Notice <function=...> format",
|
||||
),
|
||||
Template(
|
||||
"assistant",
|
||||
"assistant-default",
|
||||
"assistant_default",
|
||||
),
|
||||
Template(
|
||||
"system",
|
||||
"system-builtin-and-custom-tools",
|
||||
"system_message_builtin_and_custom_tools",
|
||||
),
|
||||
Template(
|
||||
"system",
|
||||
"system-builtin-tools-only",
|
||||
"system_message_builtin_tools_only",
|
||||
),
|
||||
Template(
|
||||
"system",
|
||||
"system-custom-tools-only",
|
||||
"system_message_custom_tools_only",
|
||||
),
|
||||
Template(
|
||||
"system",
|
||||
"system-default",
|
||||
"system_default",
|
||||
),
|
||||
Template(
|
||||
"tool",
|
||||
"tool-success",
|
||||
"tool_success",
|
||||
"Note ipython header and [stdout]",
|
||||
),
|
||||
Template(
|
||||
"tool",
|
||||
"tool-failure",
|
||||
"tool_failure",
|
||||
"Note ipython header and [stderr]",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class LLama31Interface:
|
||||
def __init__(self, tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json):
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
self.tool_prompt_format = tool_prompt_format
|
||||
|
||||
def get_tokens(self, messages: List[RawMessage]) -> List[int]:
|
||||
model_input = self.formatter.encode_dialog_prompt(
|
||||
messages,
|
||||
self.tool_prompt_format,
|
||||
)
|
||||
return model_input.tokens
|
||||
|
||||
def tool_response_messages(self, *args, **kwargs):
|
||||
template = ToolResponseGenerator().gen(*args, **kwargs)
|
||||
return [
|
||||
RawMessage(
|
||||
role="tool",
|
||||
content=template.render(),
|
||||
)
|
||||
]
|
||||
|
||||
def system_messages(
|
||||
self,
|
||||
builtin_tools: List[BuiltinTool],
|
||||
custom_tools: List[ToolDefinition],
|
||||
instruction: Optional[str] = None,
|
||||
) -> List[RawMessage]:
|
||||
messages = []
|
||||
|
||||
default_gen = SystemDefaultGenerator()
|
||||
default_template = default_gen.gen()
|
||||
|
||||
sys_content = ""
|
||||
|
||||
tool_template = None
|
||||
if builtin_tools or custom_tools:
|
||||
tool_gen = BuiltinToolGenerator()
|
||||
tool_template = tool_gen.gen(builtin_tools + custom_tools)
|
||||
|
||||
sys_content += tool_template.render()
|
||||
sys_content += "\n"
|
||||
|
||||
sys_content += default_template.render()
|
||||
|
||||
if instruction:
|
||||
sys_content += "\n\n"
|
||||
sys_content += instruction
|
||||
|
||||
sys_content += "\n"
|
||||
messages.append(RawMessage(role="system", content=sys_content))
|
||||
|
||||
if custom_tools:
|
||||
if self.tool_prompt_format == ToolPromptFormat.json:
|
||||
tool_gen = JsonCustomToolGenerator()
|
||||
elif self.tool_prompt_format == ToolPromptFormat.function_tag:
|
||||
tool_gen = FunctionTagCustomToolGenerator()
|
||||
else:
|
||||
raise ValueError(f"Non supported ToolPromptFormat {self.tool_prompt_format}")
|
||||
|
||||
custom_template = tool_gen.gen(custom_tools)
|
||||
messages.append(RawMessage(role="user", content=custom_template.render()))
|
||||
|
||||
return messages
|
||||
|
||||
def assistant_response_messages(
|
||||
self,
|
||||
content: str,
|
||||
stop_reason: StopReason,
|
||||
tool_call: Optional[ToolCall] = None,
|
||||
) -> List[RawMessage]:
|
||||
tool_calls = []
|
||||
if tool_call:
|
||||
tool_calls.append(tool_call)
|
||||
return [
|
||||
RawMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
]
|
||||
|
||||
def user_message(self, content: str) -> List[RawMessage]:
|
||||
return [RawMessage(role="user", content=content)]
|
||||
|
||||
def display_message_as_tokens(self, message: RawMessage) -> None:
|
||||
"""Util to print tokenized string to shell"""
|
||||
tokens = self.formatter.encode_message(message, self.tool_prompt_format)
|
||||
on_colors = [
|
||||
"on_red",
|
||||
"on_green",
|
||||
"on_yellow",
|
||||
"on_blue",
|
||||
"on_magenta",
|
||||
"on_cyan",
|
||||
]
|
||||
for i, t in enumerate(tokens):
|
||||
on_col = on_colors[i % len(on_colors)]
|
||||
print(colored(self.tokenizer.decode([t]), "white", on_col), end="")
|
||||
print("\n", end="")
|
||||
|
||||
|
||||
def list_jinja_templates() -> List[Template]:
|
||||
return TEMPLATES
|
||||
|
||||
|
||||
def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat):
|
||||
by_name = {t.template_name: t for t in TEMPLATES}
|
||||
if name not in by_name:
|
||||
raise ValueError(f"No template found for `{name}`")
|
||||
|
||||
template = by_name[name]
|
||||
interface = LLama31Interface(tool_prompt_format)
|
||||
|
||||
data_func = getattr(template_data, template.data_provider)
|
||||
if template.role == "system":
|
||||
messages = interface.system_messages(**data_func())
|
||||
elif template.role == "tool":
|
||||
messages = interface.tool_response_messages(**data_func())
|
||||
elif template.role == "assistant":
|
||||
messages = interface.assistant_response_messages(**data_func())
|
||||
elif template.role == "user":
|
||||
messages = interface.user_message(**data_func())
|
||||
|
||||
tokens = interface.get_tokens(messages)
|
||||
special_tokens = list(interface.tokenizer.special_tokens.values())
|
||||
tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
|
||||
return template, tokens
|
BIN
llama_stack/models/llama/llama3/pasta.jpeg
Normal file
BIN
llama_stack/models/llama/llama3/pasta.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 438 KiB |
22
llama_stack/models/llama/llama3/prompt_templates/__init__.py
Normal file
22
llama_stack/models/llama/llama3/prompt_templates/__init__.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from .base import PromptTemplate, PromptTemplateGeneratorBase # noqa: F401
|
||||
from .system_prompts import ( # noqa: F401
|
||||
BuiltinToolGenerator,
|
||||
FunctionTagCustomToolGenerator,
|
||||
JsonCustomToolGenerator,
|
||||
PythonListCustomToolGenerator,
|
||||
SystemDefaultGenerator,
|
||||
)
|
||||
from .tool_response import ToolResponseGenerator # noqa: F401
|
39
llama_stack/models/llama/llama3/prompt_templates/base.py
Normal file
39
llama_stack/models/llama/llama3/prompt_templates/base.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptTemplate:
|
||||
template: str
|
||||
data: Dict[str, Any]
|
||||
|
||||
def render(self):
|
||||
template = Template(self.template)
|
||||
return template.render(self.data)
|
||||
|
||||
|
||||
class PromptTemplateGeneratorBase:
|
||||
"""
|
||||
Base class for prompt template generators.
|
||||
"""
|
||||
|
||||
def gen(self, *args, **kwargs) -> PromptTemplate:
|
||||
raise NotImplementedError()
|
||||
|
||||
def data_examples(self) -> List[Any]:
|
||||
raise NotImplementedError()
|
|
@ -0,0 +1,311 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_models.datatypes import (
|
||||
BuiltinTool,
|
||||
)
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
)
|
||||
|
||||
from .base import PromptTemplate, PromptTemplateGeneratorBase
|
||||
|
||||
|
||||
class SystemDefaultGenerator(PromptTemplateGeneratorBase):
|
||||
def gen(self, *args, **kwargs) -> PromptTemplate:
|
||||
template_str = textwrap.dedent(
|
||||
"""
|
||||
Cutting Knowledge Date: December 2023
|
||||
Today Date: {{ today }}
|
||||
"""
|
||||
)
|
||||
return PromptTemplate(
|
||||
template_str.lstrip("\n"),
|
||||
{"today": datetime.now().strftime("%d %B %Y")},
|
||||
)
|
||||
|
||||
def data_examples(self) -> List[Any]:
|
||||
return [None]
|
||||
|
||||
|
||||
class BuiltinToolGenerator(PromptTemplateGeneratorBase):
|
||||
def _tool_breakdown(self, tools: List[ToolDefinition]):
|
||||
builtin_tools, custom_tools = [], []
|
||||
for dfn in tools:
|
||||
if isinstance(dfn.tool_name, BuiltinTool):
|
||||
builtin_tools.append(dfn)
|
||||
else:
|
||||
custom_tools.append(dfn)
|
||||
|
||||
return builtin_tools, custom_tools
|
||||
|
||||
def gen(self, tools: List[ToolDefinition]) -> PromptTemplate:
|
||||
builtin_tools, custom_tools = self._tool_breakdown(tools)
|
||||
template_str = textwrap.dedent(
|
||||
"""
|
||||
{% if builtin_tools or custom_tools -%}
|
||||
Environment: ipython
|
||||
{% endif -%}
|
||||
{% set builtin_tools = builtin_tools | reject('equalto', 'code_interpreter') | list -%}
|
||||
{% if builtin_tools -%}
|
||||
Tools: {{ builtin_tools | join(", ") | trim -}}
|
||||
{% endif %}
|
||||
"""
|
||||
)
|
||||
return PromptTemplate(
|
||||
template_str.lstrip("\n"),
|
||||
{
|
||||
"builtin_tools": [t.tool_name.value for t in builtin_tools],
|
||||
"custom_tools": custom_tools,
|
||||
},
|
||||
)
|
||||
|
||||
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||
return [
|
||||
# builtin tools
|
||||
[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
ToolDefinition(tool_name=BuiltinTool.wolfram_alpha),
|
||||
],
|
||||
# only code interpretor
|
||||
[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
||||
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||
template_str = textwrap.dedent(
|
||||
"""
|
||||
Answer the user's question by making use of the following functions if needed.
|
||||
If none of the function can be used, please say so.
|
||||
Here is a list of functions in JSON format:
|
||||
{% for t in custom_tools -%}
|
||||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set tparams = t.parameters -%}
|
||||
{%- set required_params = [] -%}
|
||||
{%- for name, param in tparams.items() if param.required == true -%}
|
||||
{%- set _ = required_params.append(name) -%}
|
||||
{%- endfor -%}
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "{{tname}}",
|
||||
"description": "{{tdesc}}",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": [
|
||||
{%- for name, param in tparams.items() %}
|
||||
{
|
||||
"{{name}}": {
|
||||
"type": "object",
|
||||
"description": "{{param.description}}"
|
||||
}
|
||||
}{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
],
|
||||
"required": {{ required_params | tojson }}
|
||||
}
|
||||
}
|
||||
}
|
||||
{% endfor %}
|
||||
Return function calls in JSON format.
|
||||
"""
|
||||
)
|
||||
|
||||
return PromptTemplate(
|
||||
template_str.lstrip("\n"),
|
||||
{"custom_tools": [t.model_dump() for t in custom_tools]},
|
||||
)
|
||||
|
||||
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||
return [
|
||||
[
|
||||
ToolDefinition(
|
||||
tool_name="trending_songs",
|
||||
description="Returns the trending songs on a Music site",
|
||||
parameters={
|
||||
"n": ToolParamDefinition(
|
||||
param_type="int",
|
||||
description="The number of songs to return",
|
||||
required=True,
|
||||
),
|
||||
"genre": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="The genre of the songs to return",
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
),
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
|
||||
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||
template_str = textwrap.dedent(
|
||||
"""
|
||||
You have access to the following functions:
|
||||
|
||||
{% for t in custom_tools %}
|
||||
{#- manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set modified_params = t.parameters.copy() -%}
|
||||
{%- for key, value in modified_params.items() -%}
|
||||
{%- if 'default' in value -%}
|
||||
{%- set _ = value.pop('default', None) -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set tparams = modified_params | tojson -%}
|
||||
Use the function '{{ tname }}' to '{{ tdesc }}':
|
||||
{"name": "{{tname}}", "description": "{{tdesc}}", "parameters": {{tparams}}}
|
||||
|
||||
{% endfor -%}
|
||||
Think very carefully before calling functions.
|
||||
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||
|
||||
<function=example_function_name>{"example_name": "example_value"}</function>
|
||||
|
||||
Reminder:
|
||||
- If looking for real time information use relevant functions before falling back to brave_search
|
||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line
|
||||
"""
|
||||
)
|
||||
return PromptTemplate(
|
||||
template_str.lstrip("\n"),
|
||||
{"custom_tools": [t.model_dump() for t in custom_tools]},
|
||||
)
|
||||
|
||||
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||
return [
|
||||
[
|
||||
ToolDefinition(
|
||||
tool_name="trending_songs",
|
||||
description="Returns the trending songs on a Music site",
|
||||
parameters={
|
||||
"n": ToolParamDefinition(
|
||||
param_type="int",
|
||||
description="The number of songs to return",
|
||||
required=True,
|
||||
),
|
||||
"genre": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="The genre of the songs to return",
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
),
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
||||
DEFAULT_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||
also point it out. You should only return the function call in tools call sections.
|
||||
|
||||
{{ function_description }}
|
||||
""".strip("\n")
|
||||
)
|
||||
|
||||
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
|
||||
system_prompt = system_prompt or self.DEFAULT_PROMPT
|
||||
return PromptTemplate(
|
||||
system_prompt,
|
||||
{"function_description": self._gen_function_description(custom_tools)},
|
||||
)
|
||||
|
||||
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||
template_str = textwrap.dedent(
|
||||
"""
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{% for t in tools -%}
|
||||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set tparams = t.parameters -%}
|
||||
{%- set required_params = [] -%}
|
||||
{%- for name, param in tparams.items() if param.required == true -%}
|
||||
{%- set _ = required_params.append(name) -%}
|
||||
{%- endfor -%}
|
||||
{
|
||||
"name": "{{tname}}",
|
||||
"description": "{{tdesc}}",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": {{ required_params | tojson }},
|
||||
"properties": {
|
||||
{%- for name, param in tparams.items() %}
|
||||
"{{name}}": {
|
||||
"type": "{{param.param_type}}",
|
||||
"description": "{{param.description}}"{% if param.default %},
|
||||
"default": "{{param.default}}"{% endif %}
|
||||
}{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
}
|
||||
}
|
||||
}{% if not loop.last %},
|
||||
{% endif -%}
|
||||
{%- endfor %}
|
||||
]
|
||||
"""
|
||||
)
|
||||
return PromptTemplate(
|
||||
template_str.strip("\n"),
|
||||
{"tools": [t.model_dump() for t in custom_tools]},
|
||||
).render()
|
||||
|
||||
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||
return [
|
||||
[
|
||||
ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get weather info for places",
|
||||
parameters={
|
||||
"city": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The name of the city to get the weather for",
|
||||
required=True,
|
||||
),
|
||||
"metric": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The metric for weather. Options are: celsius, fahrenheit",
|
||||
required=False,
|
||||
default="celsius",
|
||||
),
|
||||
},
|
||||
),
|
||||
]
|
||||
]
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
from typing import Optional
|
||||
|
||||
from .base import PromptTemplate, PromptTemplateGeneratorBase
|
||||
|
||||
|
||||
class ToolResponseGenerator(PromptTemplateGeneratorBase):
|
||||
def gen(
|
||||
self,
|
||||
status: str,
|
||||
stdout: Optional[str] = None,
|
||||
stderr: Optional[str] = None,
|
||||
):
|
||||
assert status in [
|
||||
"success",
|
||||
"failure",
|
||||
], f"status must be 'success' or 'failure'; Got: {status}"
|
||||
template_str = textwrap.dedent(
|
||||
"""
|
||||
{% if status == "success" %}completed{% else %}failed{% endif %}
|
||||
{%- if stdout %}
|
||||
[stdout]{{ stdout }}[/stdout]
|
||||
{%- endif -%}
|
||||
{%- if stderr %}
|
||||
[stderr]{{ stderr }}[/stderr]
|
||||
{%- endif -%}
|
||||
"""
|
||||
)
|
||||
return PromptTemplate(
|
||||
template_str.lstrip("\n"),
|
||||
{
|
||||
"status": status,
|
||||
"stdout": stdout,
|
||||
"stderr": stderr,
|
||||
},
|
||||
)
|
||||
|
||||
def data_examples(self):
|
||||
return [
|
||||
# success
|
||||
{
|
||||
"status": "success",
|
||||
"stdout": '{"results":["something something"]}',
|
||||
},
|
||||
# failure
|
||||
{
|
||||
"status": "failure",
|
||||
"stderr": "brave_search encounter an error: could not communicate with api.brave.com",
|
||||
},
|
||||
]
|
120
llama_stack/models/llama/llama3/template_data.py
Normal file
120
llama_stack/models/llama/llama3/template_data.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from llama_models.datatypes import (
|
||||
BuiltinTool,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
)
|
||||
|
||||
from .prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
JsonCustomToolGenerator,
|
||||
ToolResponseGenerator,
|
||||
)
|
||||
|
||||
INSTRUCTION = "You are a helpful assistant."
|
||||
|
||||
|
||||
def system_message_builtin_tools_only():
|
||||
return {
|
||||
"builtin_tools": BuiltinToolGenerator().data_examples()[0],
|
||||
"custom_tools": [],
|
||||
"instruction": INSTRUCTION,
|
||||
}
|
||||
|
||||
|
||||
def system_message_builtin_code_only():
|
||||
return {
|
||||
"builtin_tools": BuiltinToolGenerator().data_examples()[1],
|
||||
"custom_tools": [],
|
||||
"instruction": "",
|
||||
}
|
||||
|
||||
|
||||
def system_message_custom_tools_only():
|
||||
return {
|
||||
"builtin_tools": [],
|
||||
"custom_tools": JsonCustomToolGenerator().data_examples()[0],
|
||||
"instruction": INSTRUCTION,
|
||||
}
|
||||
|
||||
|
||||
def system_message_builtin_and_custom_tools():
|
||||
return {
|
||||
"builtin_tools": BuiltinToolGenerator().data_examples()[0],
|
||||
"custom_tools": JsonCustomToolGenerator().data_examples()[0],
|
||||
"instruction": INSTRUCTION,
|
||||
}
|
||||
|
||||
|
||||
def system_default():
|
||||
return {
|
||||
"builtin_tools": [],
|
||||
"custom_tools": [],
|
||||
"instruction": INSTRUCTION,
|
||||
}
|
||||
|
||||
|
||||
def tool_success():
|
||||
return ToolResponseGenerator().data_examples()[0]
|
||||
|
||||
|
||||
def tool_failure():
|
||||
return ToolResponseGenerator().data_examples()[1]
|
||||
|
||||
|
||||
def assistant_builtin_tool_call():
|
||||
return {
|
||||
"content": "",
|
||||
"tool_call": ToolCall(
|
||||
call_id="uuid",
|
||||
tool_name=BuiltinTool.brave_search,
|
||||
arguments={
|
||||
"query": "Who won NBA in 2024?",
|
||||
},
|
||||
),
|
||||
"stop_reason": StopReason.end_of_message,
|
||||
}
|
||||
|
||||
|
||||
def assistant_custom_tool_call():
|
||||
return {
|
||||
"content": "",
|
||||
"tool_call": ToolCall(
|
||||
call_id="uuid",
|
||||
tool_name="trending_songs",
|
||||
arguments={"country": "US", "n": 10},
|
||||
),
|
||||
"stop_reason": StopReason.end_of_turn,
|
||||
}
|
||||
|
||||
|
||||
def assistant_default():
|
||||
return {
|
||||
"content": "Hi, I am a helpful assistant. What can I help you with today?",
|
||||
"tool_call": None,
|
||||
"stop_reason": StopReason.end_of_turn,
|
||||
}
|
||||
|
||||
|
||||
def user_default():
|
||||
return {"content": "Please tell me how to plan a trip to New York"}
|
||||
|
||||
|
||||
def user_images():
|
||||
return {"content": "<|image|><|image|>What do these images depict?"}
|
||||
|
||||
|
||||
def user_interleaved_images():
|
||||
return {"content": "<|image|>Describe the image in one sentence.<|image|>Write a haiku about these images"}
|
199
llama_stack/models/llama/llama3/test_system_prompts.py
Normal file
199
llama_stack/models/llama/llama3/test_system_prompts.py
Normal file
|
@ -0,0 +1,199 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
|
||||
from .prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
FunctionTagCustomToolGenerator,
|
||||
JsonCustomToolGenerator,
|
||||
PythonListCustomToolGenerator,
|
||||
SystemDefaultGenerator,
|
||||
)
|
||||
|
||||
|
||||
class PromptTemplateTests(unittest.TestCase):
|
||||
def check_generator_output(self, generator, expected_text):
|
||||
example = generator.data_examples()[0]
|
||||
|
||||
pt = generator.gen(example)
|
||||
text = pt.render()
|
||||
# print(text) # debugging
|
||||
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
|
||||
|
||||
def test_system_default(self):
|
||||
generator = SystemDefaultGenerator()
|
||||
today = datetime.now().strftime("%d %B %Y")
|
||||
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
|
||||
self.check_generator_output(generator, expected_text)
|
||||
|
||||
def test_system_builtin_only(self):
|
||||
generator = BuiltinToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Environment: ipython
|
||||
Tools: brave_search, wolfram_alpha
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_system_custom_only(self):
|
||||
self.maxDiff = None
|
||||
generator = JsonCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Answer the user's question by making use of the following functions if needed.
|
||||
If none of the function can be used, please say so.
|
||||
Here is a list of functions in JSON format:
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "trending_songs",
|
||||
"description": "Returns the trending songs on a Music site",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": [
|
||||
{
|
||||
"n": {
|
||||
"type": "object",
|
||||
"description": "The number of songs to return"
|
||||
}
|
||||
},
|
||||
{
|
||||
"genre": {
|
||||
"type": "object",
|
||||
"description": "The genre of the songs to return"
|
||||
}
|
||||
}
|
||||
],
|
||||
"required": ["n"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Return function calls in JSON format.
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_system_custom_function_tag(self):
|
||||
self.maxDiff = None
|
||||
generator = FunctionTagCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
You have access to the following functions:
|
||||
|
||||
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
|
||||
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
|
||||
|
||||
Think very carefully before calling functions.
|
||||
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||
|
||||
<function=example_function_name>{"example_name": "example_value"}</function>
|
||||
|
||||
Reminder:
|
||||
- If looking for real time information use relevant functions before falling back to brave_search
|
||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_llama_3_2_system_zero_shot(self):
|
||||
generator = PythonListCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||
also point it out. You should only return the function call in tools call sections.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info for places",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": ["city"],
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_llama_3_2_provided_system_prompt(self):
|
||||
generator = PythonListCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Overriding message.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info for places",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": ["city"],
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]"""
|
||||
)
|
||||
user_system_prompt = textwrap.dedent(
|
||||
"""
|
||||
Overriding message.
|
||||
|
||||
{{ function_description }}
|
||||
"""
|
||||
)
|
||||
example = generator.data_examples()[0]
|
||||
|
||||
pt = generator.gen(example, user_system_prompt)
|
||||
text = pt.render()
|
||||
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
|
Loading…
Add table
Add a link
Reference in a new issue