diff --git a/docs/openapi_generator/generate.py b/docs/openapi_generator/generate.py index 48109e5d8..dc0e1e9d3 100644 --- a/docs/openapi_generator/generate.py +++ b/docs/openapi_generator/generate.py @@ -16,12 +16,11 @@ from pathlib import Path import fire import ruamel.yaml as yaml -from llama_models import schema_utils +from llama_stack import schema_utils # We do some monkey-patching to ensure our definitions only use the minimal -# (json_schema_type, webmethod) definitions from the llama_models package. For -# generation though, we need the full definitions and implementations from the -# (json-strong-typing) package. +# (json_schema_type, webmethod) definitions. For generation though, we need +# the full definitions and implementations from the (json-strong-typing) package. from .strong_typing.schema import json_schema_type, register_schema diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py new file mode 100644 index 000000000..b99c90d75 --- /dev/null +++ b/llama_stack/models/llama/datatypes.py @@ -0,0 +1,276 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +from enum import Enum +from typing import Any, Dict, Literal, Optional, Union + +from llama_models.datatypes import BuiltinTool, ToolCall +from pydantic import BaseModel, ConfigDict, Field, field_validator +from typing_extensions import Annotated + +from llama_stack.schema_utils import json_schema_type, register_schema + +register_schema(ToolCall) + + +@json_schema_type +class ToolParamDefinition(BaseModel): + param_type: str + description: Optional[str] = None + required: Optional[bool] = True + default: Optional[Any] = None + + +@json_schema_type +class ToolDefinition(BaseModel): + tool_name: Union[BuiltinTool, str] + description: Optional[str] = None + parameters: Optional[Dict[str, ToolParamDefinition]] = None + + @field_validator("tool_name", mode="before") + @classmethod + def validate_field(cls, v): + if isinstance(v, str): + try: + return BuiltinTool(v) + except ValueError: + return v + return v + + +@json_schema_type +class GreedySamplingStrategy(BaseModel): + type: Literal["greedy"] = "greedy" + + +@json_schema_type +class TopPSamplingStrategy(BaseModel): + type: Literal["top_p"] = "top_p" + temperature: Optional[float] = Field(..., gt=0.0) + top_p: Optional[float] = 0.95 + + +@json_schema_type +class TopKSamplingStrategy(BaseModel): + type: Literal["top_k"] = "top_k" + top_k: int = Field(..., ge=1) + + +SamplingStrategy = register_schema( + Annotated[ + Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy], + Field(discriminator="type"), + ], + name="SamplingStrategy", +) + + +@json_schema_type +class SamplingParams(BaseModel): + strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy) + + max_tokens: Optional[int] = 0 + repetition_penalty: Optional[float] = 1.0 + + +class CheckpointQuantizationFormat(Enum): + # default format + bf16 = "bf16" + + # used for enabling fp8_rowwise inference, some weights are bf16 + fp8_mixed = "fp8-mixed" + + int8 = "int8" + + int4 = "int4" + + +class ModelFamily(Enum): + 
llama2 = "llama2" + llama3 = "llama3" + llama3_1 = "llama3_1" + llama3_2 = "llama3_2" + llama3_3 = "llama3_3" + safety = "safety" + + +class CoreModelId(Enum): + """Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)""" + + # Llama 2 family + llama2_7b = "Llama-2-7b" + llama2_13b = "Llama-2-13b" + llama2_70b = "Llama-2-70b" + llama2_7b_chat = "Llama-2-7b-chat" + llama2_13b_chat = "Llama-2-13b-chat" + llama2_70b_chat = "Llama-2-70b-chat" + + # Llama 3 family + llama3_8b = "Llama-3-8B" + llama3_70b = "Llama-3-70B" + llama3_8b_instruct = "Llama-3-8B-Instruct" + llama3_70b_instruct = "Llama-3-70B-Instruct" + + # Llama 3.1 family + llama3_1_8b = "Llama3.1-8B" + llama3_1_70b = "Llama3.1-70B" + llama3_1_405b = "Llama3.1-405B" + llama3_1_8b_instruct = "Llama3.1-8B-Instruct" + llama3_1_70b_instruct = "Llama3.1-70B-Instruct" + llama3_1_405b_instruct = "Llama3.1-405B-Instruct" + + # Llama 3.2 family + llama3_2_1b = "Llama3.2-1B" + llama3_2_3b = "Llama3.2-3B" + llama3_2_1b_instruct = "Llama3.2-1B-Instruct" + llama3_2_3b_instruct = "Llama3.2-3B-Instruct" + llama3_2_11b_vision = "Llama3.2-11B-Vision" + llama3_2_90b_vision = "Llama3.2-90B-Vision" + llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct" + llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct" + + # Llama 3.3 family + llama3_3_70b_instruct = "Llama3.3-70B-Instruct" + + # Safety models + llama_guard_3_8b = "Llama-Guard-3-8B" + llama_guard_2_8b = "Llama-Guard-2-8B" + llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision" + llama_guard_3_1b = "Llama-Guard-3-1B" + + +def is_multimodal(model_id) -> bool: + if model_id in [ + CoreModelId.llama3_2_11b_vision, + CoreModelId.llama3_2_90b_vision, + CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_90b_vision_instruct, + ]: + return True + else: + return False + + +def model_family(model_id) -> ModelFamily: + if model_id in [ + CoreModelId.llama2_7b, + CoreModelId.llama2_13b, + CoreModelId.llama2_70b, + CoreModelId.llama2_7b_chat, + CoreModelId.llama2_13b_chat, + CoreModelId.llama2_70b_chat, + ]: + return ModelFamily.llama2 + elif model_id in [ + CoreModelId.llama3_8b, + CoreModelId.llama3_70b, + CoreModelId.llama3_8b_instruct, + CoreModelId.llama3_70b_instruct, + ]: + return ModelFamily.llama3 + elif model_id in [ + CoreModelId.llama3_1_8b, + CoreModelId.llama3_1_70b, + CoreModelId.llama3_1_405b, + CoreModelId.llama3_1_8b_instruct, + CoreModelId.llama3_1_70b_instruct, + CoreModelId.llama3_1_405b_instruct, + ]: + return ModelFamily.llama3_1 + elif model_id in [ + CoreModelId.llama3_2_1b, + CoreModelId.llama3_2_3b, + CoreModelId.llama3_2_1b_instruct, + CoreModelId.llama3_2_3b_instruct, + CoreModelId.llama3_2_11b_vision, + CoreModelId.llama3_2_90b_vision, + CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_90b_vision_instruct, + ]: + return ModelFamily.llama3_2 + elif model_id in [ + CoreModelId.llama3_3_70b_instruct, + ]: + return ModelFamily.llama3_3 + elif model_id in [ + CoreModelId.llama_guard_3_8b, + CoreModelId.llama_guard_2_8b, + CoreModelId.llama_guard_3_11b_vision, + CoreModelId.llama_guard_3_1b, + ]: + return ModelFamily.safety + else: + raise ValueError(f"Unknown model family for {model_id}") + + +class Model(BaseModel): + core_model_id: CoreModelId + description: str + huggingface_repo: Optional[str] = None + recommended_sampling_params: Optional[SamplingParams] = None + arch_args: Dict[str, Any] + variant: str = "" + + quantization_format: CheckpointQuantizationFormat = 
CheckpointQuantizationFormat.bf16 + pth_file_count: int + metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) + + # silence pydantic until we remove the `model_` fields + model_config = ConfigDict(protected_namespaces=()) + + @property + def model_family(self) -> ModelFamily: + return model_family(self.core_model_id) + + # The SKU is uniquely identified by (model_id, variant) combo + def descriptor(self, shorten_default_variant: bool = True) -> str: + if not self.variant: + return self.core_model_id.value + return f"{self.core_model_id.value}:{self.variant}" + + @property + def is_instruct_model(self) -> bool: + return "instruct" in self.id.name + + # Featured models are shown in the non-exhaustive model list + @property + def is_featured(self) -> bool: + return self.model_family in [ + ModelFamily.llama3_1, + ModelFamily.llama3_2, + ModelFamily.llama3_3, + ModelFamily.safety, + ] + + @property + def max_seq_length(self) -> int: + if self.model_family == ModelFamily.llama2: + return 4096 + elif self.core_model_id == CoreModelId.llama_guard_2_8b: + return 4096 + elif self.model_family == ModelFamily.llama3: + return 8192 + elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]: + return 131072 + elif self.model_family == ModelFamily.llama3_2: + if self.quantization_format == CheckpointQuantizationFormat.int4: + return 8192 + return 131072 + elif self.core_model_id in [ + CoreModelId.llama_guard_3_8b, + CoreModelId.llama_guard_3_11b_vision, + CoreModelId.llama_guard_3_1b, + ]: + return 131072 + else: + raise ValueError(f"Unknown max_seq_len for {self.core_model_id}") diff --git a/llama_stack/models/llama/llama3/dog.jpg b/llama_stack/models/llama/llama3/dog.jpg new file mode 100644 index 000000000..f9a3a8057 Binary files /dev/null and b/llama_stack/models/llama/llama3/dog.jpg differ diff --git a/llama_stack/models/llama/llama3/interface.py b/llama_stack/models/llama/llama3/interface.py new file mode 100644 index 000000000..bc42228a5 --- /dev/null +++ b/llama_stack/models/llama/llama3/interface.py @@ -0,0 +1,257 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +from pathlib import Path +from typing import List, Optional + +from llama_models.datatypes import ( + BuiltinTool, + RawMessage, + StopReason, + ToolCall, + ToolPromptFormat, +) +from llama_models.llama3.api.chat_format import ChatFormat +from llama_models.llama3.api.tokenizer import Tokenizer +from termcolor import colored + +from llama_stack.models.llama.datatypes import ToolDefinition + +from . 
import template_data +from .prompt_templates import ( + BuiltinToolGenerator, + FunctionTagCustomToolGenerator, + JsonCustomToolGenerator, + SystemDefaultGenerator, + ToolResponseGenerator, +) + +THIS_DIR = Path(__file__).parent + + +class Template: + def __init__( + self, + role, + template_name, + data_provider=None, + notes=None, + ): + self.role = role + self.template_name = template_name + self.data_provider = data_provider or "" + self._notes = notes or "" + + @property + def notes(self): + default = "↵ represents newline" + notes = default + if self._notes: + notes += "\n" + notes += self._notes + return notes + + +TEMPLATES = [ + Template( + "user", + "user-default", + "user_default", + ), + Template( + "user", + "user-images", + "user_images", + ), + Template("user", "user-interleaved-images", "user_interleaved_images"), + Template( + "assistant", + "assistant-builtin-tool-call", + "assistant_builtin_tool_call", + "Notice <|python_tag|>", + ), + Template( + "assistant", + "assistant-custom-tool-call", + "assistant_custom_tool_call", + "Notice format", + ), + Template( + "assistant", + "assistant-default", + "assistant_default", + ), + Template( + "system", + "system-builtin-and-custom-tools", + "system_message_builtin_and_custom_tools", + ), + Template( + "system", + "system-builtin-tools-only", + "system_message_builtin_tools_only", + ), + Template( + "system", + "system-custom-tools-only", + "system_message_custom_tools_only", + ), + Template( + "system", + "system-default", + "system_default", + ), + Template( + "tool", + "tool-success", + "tool_success", + "Note ipython header and [stdout]", + ), + Template( + "tool", + "tool-failure", + "tool_failure", + "Note ipython header and [stderr]", + ), +] + + +class LLama31Interface: + def __init__(self, tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json): + self.tokenizer = Tokenizer.get_instance() + self.formatter = ChatFormat(self.tokenizer) + self.tool_prompt_format = tool_prompt_format + + def get_tokens(self, messages: List[RawMessage]) -> List[int]: + model_input = self.formatter.encode_dialog_prompt( + messages, + self.tool_prompt_format, + ) + return model_input.tokens + + def tool_response_messages(self, *args, **kwargs): + template = ToolResponseGenerator().gen(*args, **kwargs) + return [ + RawMessage( + role="tool", + content=template.render(), + ) + ] + + def system_messages( + self, + builtin_tools: List[BuiltinTool], + custom_tools: List[ToolDefinition], + instruction: Optional[str] = None, + ) -> List[RawMessage]: + messages = [] + + default_gen = SystemDefaultGenerator() + default_template = default_gen.gen() + + sys_content = "" + + tool_template = None + if builtin_tools or custom_tools: + tool_gen = BuiltinToolGenerator() + tool_template = tool_gen.gen(builtin_tools + custom_tools) + + sys_content += tool_template.render() + sys_content += "\n" + + sys_content += default_template.render() + + if instruction: + sys_content += "\n\n" + sys_content += instruction + + sys_content += "\n" + messages.append(RawMessage(role="system", content=sys_content)) + + if custom_tools: + if self.tool_prompt_format == ToolPromptFormat.json: + tool_gen = JsonCustomToolGenerator() + elif self.tool_prompt_format == ToolPromptFormat.function_tag: + tool_gen = FunctionTagCustomToolGenerator() + else: + raise ValueError(f"Non supported ToolPromptFormat {self.tool_prompt_format}") + + custom_template = tool_gen.gen(custom_tools) + messages.append(RawMessage(role="user", content=custom_template.render())) + + return messages + + 
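    # Illustrative note (added by this edit, not part of the original patch): for the
    # builtin-tools example data exercised by the tests below, system_messages() above
    # renders a system prompt body roughly of the form:
    #
    #   Environment: ipython
    #   Tools: brave_search, wolfram_alpha
    #
    #   Cutting Knowledge Date: December 2023
    #   Today Date: <current date>
    #
    #   <optional instruction>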
def assistant_response_messages( + self, + content: str, + stop_reason: StopReason, + tool_call: Optional[ToolCall] = None, + ) -> List[RawMessage]: + tool_calls = [] + if tool_call: + tool_calls.append(tool_call) + return [ + RawMessage( + role="assistant", + content=content, + tool_calls=tool_calls, + stop_reason=stop_reason, + ) + ] + + def user_message(self, content: str) -> List[RawMessage]: + return [RawMessage(role="user", content=content)] + + def display_message_as_tokens(self, message: RawMessage) -> None: + """Util to print tokenized string to shell""" + tokens = self.formatter.encode_message(message, self.tool_prompt_format) + on_colors = [ + "on_red", + "on_green", + "on_yellow", + "on_blue", + "on_magenta", + "on_cyan", + ] + for i, t in enumerate(tokens): + on_col = on_colors[i % len(on_colors)] + print(colored(self.tokenizer.decode([t]), "white", on_col), end="") + print("\n", end="") + + +def list_jinja_templates() -> List[Template]: + return TEMPLATES + + +def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat): + by_name = {t.template_name: t for t in TEMPLATES} + if name not in by_name: + raise ValueError(f"No template found for `{name}`") + + template = by_name[name] + interface = LLama31Interface(tool_prompt_format) + + data_func = getattr(template_data, template.data_provider) + if template.role == "system": + messages = interface.system_messages(**data_func()) + elif template.role == "tool": + messages = interface.tool_response_messages(**data_func()) + elif template.role == "assistant": + messages = interface.assistant_response_messages(**data_func()) + elif template.role == "user": + messages = interface.user_message(**data_func()) + + tokens = interface.get_tokens(messages) + special_tokens = list(interface.tokenizer.special_tokens.values()) + tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens] + return template, tokens diff --git a/llama_stack/models/llama/llama3/pasta.jpeg b/llama_stack/models/llama/llama3/pasta.jpeg new file mode 100644 index 000000000..e8299321c Binary files /dev/null and b/llama_stack/models/llama/llama3/pasta.jpeg differ diff --git a/llama_stack/models/llama/llama3/prompt_templates/__init__.py b/llama_stack/models/llama/llama3/prompt_templates/__init__.py new file mode 100644 index 000000000..4eed54d12 --- /dev/null +++ b/llama_stack/models/llama/llama3/prompt_templates/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
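A minimal usage sketch for the interface module above (illustrative only, not part of the patch). It assumes the `llama_models` package and its bundled tokenizer assets are installed, since `LLama31Interface` loads the Llama 3 tokenizer:

from llama_models.datatypes import ToolPromptFormat

from llama_stack.models.llama.llama3.interface import (
    list_jinja_templates,
    render_jinja_template,
)

# Enumerate the documented prompt templates (role + name).
for template in list_jinja_templates():
    print(template.role, template.template_name)

# Render one template; `tokens` is a list of (decoded_text, is_special_token) pairs.
template, tokens = render_jinja_template("system-builtin-tools-only", ToolPromptFormat.json)
print("".join(text for text, _ in tokens))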
+ +from .base import PromptTemplate, PromptTemplateGeneratorBase # noqa: F401 +from .system_prompts import ( # noqa: F401 + BuiltinToolGenerator, + FunctionTagCustomToolGenerator, + JsonCustomToolGenerator, + PythonListCustomToolGenerator, + SystemDefaultGenerator, +) +from .tool_response import ToolResponseGenerator # noqa: F401 diff --git a/llama_stack/models/llama/llama3/prompt_templates/base.py b/llama_stack/models/llama/llama3/prompt_templates/base.py new file mode 100644 index 000000000..bff2a21e1 --- /dev/null +++ b/llama_stack/models/llama/llama3/prompt_templates/base.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +from dataclasses import dataclass +from typing import Any, Dict, List + +from jinja2 import Template + + +@dataclass +class PromptTemplate: + template: str + data: Dict[str, Any] + + def render(self): + template = Template(self.template) + return template.render(self.data) + + +class PromptTemplateGeneratorBase: + """ + Base class for prompt template generators. + """ + + def gen(self, *args, **kwargs) -> PromptTemplate: + raise NotImplementedError() + + def data_examples(self) -> List[Any]: + raise NotImplementedError() diff --git a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py new file mode 100644 index 000000000..27b1a3502 --- /dev/null +++ b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py @@ -0,0 +1,311 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
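A quick illustration (not part of the patch) of the `PromptTemplate` helper defined in `base.py` above; `render()` simply feeds `data` into a `jinja2.Template`:

from llama_stack.models.llama.llama3.prompt_templates.base import PromptTemplate

pt = PromptTemplate(
    template="Hello {{ name }}, today is {{ day }}.",
    data={"name": "Llama", "day": "Monday"},
)
print(pt.render())  # Hello Llama, today is Monday.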
+ +import textwrap +from datetime import datetime +from typing import Any, List, Optional + +from llama_models.datatypes import ( + BuiltinTool, +) + +from llama_stack.models.llama.datatypes import ( + ToolDefinition, + ToolParamDefinition, +) + +from .base import PromptTemplate, PromptTemplateGeneratorBase + + +class SystemDefaultGenerator(PromptTemplateGeneratorBase): + def gen(self, *args, **kwargs) -> PromptTemplate: + template_str = textwrap.dedent( + """ + Cutting Knowledge Date: December 2023 + Today Date: {{ today }} + """ + ) + return PromptTemplate( + template_str.lstrip("\n"), + {"today": datetime.now().strftime("%d %B %Y")}, + ) + + def data_examples(self) -> List[Any]: + return [None] + + +class BuiltinToolGenerator(PromptTemplateGeneratorBase): + def _tool_breakdown(self, tools: List[ToolDefinition]): + builtin_tools, custom_tools = [], [] + for dfn in tools: + if isinstance(dfn.tool_name, BuiltinTool): + builtin_tools.append(dfn) + else: + custom_tools.append(dfn) + + return builtin_tools, custom_tools + + def gen(self, tools: List[ToolDefinition]) -> PromptTemplate: + builtin_tools, custom_tools = self._tool_breakdown(tools) + template_str = textwrap.dedent( + """ + {% if builtin_tools or custom_tools -%} + Environment: ipython + {% endif -%} + {% set builtin_tools = builtin_tools | reject('equalto', 'code_interpreter') | list -%} + {% if builtin_tools -%} + Tools: {{ builtin_tools | join(", ") | trim -}} + {% endif %} + """ + ) + return PromptTemplate( + template_str.lstrip("\n"), + { + "builtin_tools": [t.tool_name.value for t in builtin_tools], + "custom_tools": custom_tools, + }, + ) + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + # builtin tools + [ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition(tool_name=BuiltinTool.brave_search), + ToolDefinition(tool_name=BuiltinTool.wolfram_alpha), + ], + # only code interpretor + [ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ], + ] + + +class JsonCustomToolGenerator(PromptTemplateGeneratorBase): + def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + template_str = textwrap.dedent( + """ + Answer the user's question by making use of the following functions if needed. + If none of the function can be used, please say so. + Here is a list of functions in JSON format: + {% for t in custom_tools -%} + {# manually setting up JSON because jinja sorts keys in unexpected ways -#} + {%- set tname = t.tool_name -%} + {%- set tdesc = t.description -%} + {%- set tparams = t.parameters -%} + {%- set required_params = [] -%} + {%- for name, param in tparams.items() if param.required == true -%} + {%- set _ = required_params.append(name) -%} + {%- endfor -%} + { + "type": "function", + "function": { + "name": "{{tname}}", + "description": "{{tdesc}}", + "parameters": { + "type": "object", + "properties": [ + {%- for name, param in tparams.items() %} + { + "{{name}}": { + "type": "object", + "description": "{{param.description}}" + } + }{% if not loop.last %},{% endif %} + {%- endfor %} + ], + "required": {{ required_params | tojson }} + } + } + } + {% endfor %} + Return function calls in JSON format. 
+ """ + ) + + return PromptTemplate( + template_str.lstrip("\n"), + {"custom_tools": [t.model_dump() for t in custom_tools]}, + ) + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + [ + ToolDefinition( + tool_name="trending_songs", + description="Returns the trending songs on a Music site", + parameters={ + "n": ToolParamDefinition( + param_type="int", + description="The number of songs to return", + required=True, + ), + "genre": ToolParamDefinition( + param_type="str", + description="The genre of the songs to return", + required=False, + ), + }, + ), + ] + ] + + +class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase): + def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + template_str = textwrap.dedent( + """ + You have access to the following functions: + + {% for t in custom_tools %} + {#- manually setting up JSON because jinja sorts keys in unexpected ways -#} + {%- set tname = t.tool_name -%} + {%- set tdesc = t.description -%} + {%- set modified_params = t.parameters.copy() -%} + {%- for key, value in modified_params.items() -%} + {%- if 'default' in value -%} + {%- set _ = value.pop('default', None) -%} + {%- endif -%} + {%- endfor -%} + {%- set tparams = modified_params | tojson -%} + Use the function '{{ tname }}' to '{{ tdesc }}': + {"name": "{{tname}}", "description": "{{tdesc}}", "parameters": {{tparams}}} + + {% endfor -%} + Think very carefully before calling functions. + If you choose to call a function ONLY reply in the following format with no prefix or suffix: + + {"example_name": "example_value"} + + Reminder: + - If looking for real time information use relevant functions before falling back to brave_search + - Function calls MUST follow the specified format, start with + - Required parameters MUST be specified + - Only call one function at a time + - Put the entire function call reply on one line + """ + ) + return PromptTemplate( + template_str.lstrip("\n"), + {"custom_tools": [t.model_dump() for t in custom_tools]}, + ) + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + [ + ToolDefinition( + tool_name="trending_songs", + description="Returns the trending songs on a Music site", + parameters={ + "n": ToolParamDefinition( + param_type="int", + description="The number of songs to return", + required=True, + ), + "genre": ToolParamDefinition( + param_type="str", + description="The genre of the songs to return", + required=False, + ), + }, + ), + ] + ] + + +class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 + DEFAULT_PROMPT = textwrap.dedent( + """ + You are an expert in composing functions. You are given a question and a set of possible functions. + Based on the question, you will need to make one or more function/tool calls to achieve the purpose. + If none of the function can be used, point it out. If the given question lacks the parameters required by the function, + also point it out. You should only return the function call in tools call sections. 
+ + {{ function_description }} + """.strip("\n") + ) + + def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate: + system_prompt = system_prompt or self.DEFAULT_PROMPT + return PromptTemplate( + system_prompt, + {"function_description": self._gen_function_description(custom_tools)}, + ) + + def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + template_str = textwrap.dedent( + """ + If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] + You SHOULD NOT include any other text in the response. + + Here is a list of functions in JSON format that you can invoke. + + [ + {% for t in tools -%} + {# manually setting up JSON because jinja sorts keys in unexpected ways -#} + {%- set tname = t.tool_name -%} + {%- set tdesc = t.description -%} + {%- set tparams = t.parameters -%} + {%- set required_params = [] -%} + {%- for name, param in tparams.items() if param.required == true -%} + {%- set _ = required_params.append(name) -%} + {%- endfor -%} + { + "name": "{{tname}}", + "description": "{{tdesc}}", + "parameters": { + "type": "dict", + "required": {{ required_params | tojson }}, + "properties": { + {%- for name, param in tparams.items() %} + "{{name}}": { + "type": "{{param.param_type}}", + "description": "{{param.description}}"{% if param.default %}, + "default": "{{param.default}}"{% endif %} + }{% if not loop.last %},{% endif %} + {%- endfor %} + } + } + }{% if not loop.last %}, + {% endif -%} + {%- endfor %} + ] + """ + ) + return PromptTemplate( + template_str.strip("\n"), + {"tools": [t.model_dump() for t in custom_tools]}, + ).render() + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + [ + ToolDefinition( + tool_name="get_weather", + description="Get weather info for places", + parameters={ + "city": ToolParamDefinition( + param_type="string", + description="The name of the city to get the weather for", + required=True, + ), + "metric": ToolParamDefinition( + param_type="string", + description="The metric for weather. Options are: celsius, fahrenheit", + required=False, + default="celsius", + ), + }, + ), + ] + ] diff --git a/llama_stack/models/llama/llama3/prompt_templates/tool_response.py b/llama_stack/models/llama/llama3/prompt_templates/tool_response.py new file mode 100644 index 000000000..3df4dac14 --- /dev/null +++ b/llama_stack/models/llama/llama3/prompt_templates/tool_response.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
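An illustrative usage sketch (not part of the patch) for `PythonListCustomToolGenerator` above, which builds the Llama 3.2 style zero-shot tool-calling system prompt from `ToolDefinition` objects:

from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.models.llama.llama3.prompt_templates import PythonListCustomToolGenerator

tools = [
    ToolDefinition(
        tool_name="get_weather",
        description="Get weather info for places",
        parameters={
            "city": ToolParamDefinition(
                param_type="string",
                description="The name of the city to get the weather for",
                required=True,
            ),
        },
    ),
]

# Renders the default instruction plus a JSON listing of the available functions.
print(PythonListCustomToolGenerator().gen(tools).render())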
+ +import textwrap +from typing import Optional + +from .base import PromptTemplate, PromptTemplateGeneratorBase + + +class ToolResponseGenerator(PromptTemplateGeneratorBase): + def gen( + self, + status: str, + stdout: Optional[str] = None, + stderr: Optional[str] = None, + ): + assert status in [ + "success", + "failure", + ], f"status must be 'success' or 'failure'; Got: {status}" + template_str = textwrap.dedent( + """ + {% if status == "success" %}completed{% else %}failed{% endif %} + {%- if stdout %} + [stdout]{{ stdout }}[/stdout] + {%- endif -%} + {%- if stderr %} + [stderr]{{ stderr }}[/stderr] + {%- endif -%} + """ + ) + return PromptTemplate( + template_str.lstrip("\n"), + { + "status": status, + "stdout": stdout, + "stderr": stderr, + }, + ) + + def data_examples(self): + return [ + # success + { + "status": "success", + "stdout": '{"results":["something something"]}', + }, + # failure + { + "status": "failure", + "stderr": "brave_search encounter an error: could not communicate with api.brave.com", + }, + ] diff --git a/llama_stack/models/llama/llama3/template_data.py b/llama_stack/models/llama/llama3/template_data.py new file mode 100644 index 000000000..620816ffc --- /dev/null +++ b/llama_stack/models/llama/llama3/template_data.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +from llama_models.datatypes import ( + BuiltinTool, + StopReason, + ToolCall, +) + +from .prompt_templates import ( + BuiltinToolGenerator, + JsonCustomToolGenerator, + ToolResponseGenerator, +) + +INSTRUCTION = "You are a helpful assistant." 
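A short, illustrative example (not part of the patch) of what `ToolResponseGenerator` above renders for the success and failure cases:

from llama_stack.models.llama.llama3.prompt_templates import ToolResponseGenerator

gen = ToolResponseGenerator()
print(gen.gen(status="success", stdout='{"results": ["..."]}').render())
# -> completed
#    [stdout]{"results": ["..."]}[/stdout]
print(gen.gen(status="failure", stderr="could not reach api.brave.com").render())
# -> failed
#    [stderr]could not reach api.brave.com[/stderr]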
+ + +def system_message_builtin_tools_only(): + return { + "builtin_tools": BuiltinToolGenerator().data_examples()[0], + "custom_tools": [], + "instruction": INSTRUCTION, + } + + +def system_message_builtin_code_only(): + return { + "builtin_tools": BuiltinToolGenerator().data_examples()[1], + "custom_tools": [], + "instruction": "", + } + + +def system_message_custom_tools_only(): + return { + "builtin_tools": [], + "custom_tools": JsonCustomToolGenerator().data_examples()[0], + "instruction": INSTRUCTION, + } + + +def system_message_builtin_and_custom_tools(): + return { + "builtin_tools": BuiltinToolGenerator().data_examples()[0], + "custom_tools": JsonCustomToolGenerator().data_examples()[0], + "instruction": INSTRUCTION, + } + + +def system_default(): + return { + "builtin_tools": [], + "custom_tools": [], + "instruction": INSTRUCTION, + } + + +def tool_success(): + return ToolResponseGenerator().data_examples()[0] + + +def tool_failure(): + return ToolResponseGenerator().data_examples()[1] + + +def assistant_builtin_tool_call(): + return { + "content": "", + "tool_call": ToolCall( + call_id="uuid", + tool_name=BuiltinTool.brave_search, + arguments={ + "query": "Who won NBA in 2024?", + }, + ), + "stop_reason": StopReason.end_of_message, + } + + +def assistant_custom_tool_call(): + return { + "content": "", + "tool_call": ToolCall( + call_id="uuid", + tool_name="trending_songs", + arguments={"country": "US", "n": 10}, + ), + "stop_reason": StopReason.end_of_turn, + } + + +def assistant_default(): + return { + "content": "Hi, I am a helpful assistant. What can I help you with today?", + "tool_call": None, + "stop_reason": StopReason.end_of_turn, + } + + +def user_default(): + return {"content": "Please tell me how to plan a trip to New York"} + + +def user_images(): + return {"content": "<|image|><|image|>What do these images depict?"} + + +def user_interleaved_images(): + return {"content": "<|image|>Describe the image in one sentence.<|image|>Write a haiku about these images"} diff --git a/llama_stack/models/llama/llama3/test_system_prompts.py b/llama_stack/models/llama/llama3/test_system_prompts.py new file mode 100644 index 000000000..40fd93891 --- /dev/null +++ b/llama_stack/models/llama/llama3/test_system_prompts.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
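For illustration (not part of the patch), the `assistant_custom_tool_call()` data above is what `LLama31Interface.assistant_response_messages()` turns into a single assistant `RawMessage` carrying a tool call:

from llama_models.datatypes import RawMessage, StopReason, ToolCall

message = RawMessage(
    role="assistant",
    content="",
    stop_reason=StopReason.end_of_turn,
    tool_calls=[
        ToolCall(
            call_id="uuid",
            tool_name="trending_songs",
            arguments={"country": "US", "n": 10},
        )
    ],
)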
+ +import textwrap +import unittest +from datetime import datetime + +from .prompt_templates import ( + BuiltinToolGenerator, + FunctionTagCustomToolGenerator, + JsonCustomToolGenerator, + PythonListCustomToolGenerator, + SystemDefaultGenerator, +) + + +class PromptTemplateTests(unittest.TestCase): + def check_generator_output(self, generator, expected_text): + example = generator.data_examples()[0] + + pt = generator.gen(example) + text = pt.render() + # print(text) # debugging + assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}" + + def test_system_default(self): + generator = SystemDefaultGenerator() + today = datetime.now().strftime("%d %B %Y") + expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}" + self.check_generator_output(generator, expected_text) + + def test_system_builtin_only(self): + generator = BuiltinToolGenerator() + expected_text = textwrap.dedent( + """ + Environment: ipython + Tools: brave_search, wolfram_alpha + """ + ) + self.check_generator_output(generator, expected_text.strip("\n")) + + def test_system_custom_only(self): + self.maxDiff = None + generator = JsonCustomToolGenerator() + expected_text = textwrap.dedent( + """ + Answer the user's question by making use of the following functions if needed. + If none of the function can be used, please say so. + Here is a list of functions in JSON format: + { + "type": "function", + "function": { + "name": "trending_songs", + "description": "Returns the trending songs on a Music site", + "parameters": { + "type": "object", + "properties": [ + { + "n": { + "type": "object", + "description": "The number of songs to return" + } + }, + { + "genre": { + "type": "object", + "description": "The genre of the songs to return" + } + } + ], + "required": ["n"] + } + } + } + + Return function calls in JSON format. + """ + ) + self.check_generator_output(generator, expected_text.strip("\n")) + + def test_system_custom_function_tag(self): + self.maxDiff = None + generator = FunctionTagCustomToolGenerator() + expected_text = textwrap.dedent( + """ + You have access to the following functions: + + Use the function 'trending_songs' to 'Returns the trending songs on a Music site': + {"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}} + + Think very carefully before calling functions. + If you choose to call a function ONLY reply in the following format with no prefix or suffix: + + {"example_name": "example_value"} + + Reminder: + - If looking for real time information use relevant functions before falling back to brave_search + - Function calls MUST follow the specified format, start with + - Required parameters MUST be specified + - Only call one function at a time + - Put the entire function call reply on one line + """ + ) + self.check_generator_output(generator, expected_text.strip("\n")) + + def test_llama_3_2_system_zero_shot(self): + generator = PythonListCustomToolGenerator() + expected_text = textwrap.dedent( + """ + You are an expert in composing functions. You are given a question and a set of possible functions. + Based on the question, you will need to make one or more function/tool calls to achieve the purpose. + If none of the function can be used, point it out. 
If the given question lacks the parameters required by the function, + also point it out. You should only return the function call in tools call sections. + + If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] + You SHOULD NOT include any other text in the response. + + Here is a list of functions in JSON format that you can invoke. + + [ + { + "name": "get_weather", + "description": "Get weather info for places", + "parameters": { + "type": "dict", + "required": ["city"], + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get the weather for" + }, + "metric": { + "type": "string", + "description": "The metric for weather. Options are: celsius, fahrenheit", + "default": "celsius" + } + } + } + } + ] + """ + ) + self.check_generator_output(generator, expected_text.strip("\n")) + + def test_llama_3_2_provided_system_prompt(self): + generator = PythonListCustomToolGenerator() + expected_text = textwrap.dedent( + """ + Overriding message. + + If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] + You SHOULD NOT include any other text in the response. + + Here is a list of functions in JSON format that you can invoke. + + [ + { + "name": "get_weather", + "description": "Get weather info for places", + "parameters": { + "type": "dict", + "required": ["city"], + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get the weather for" + }, + "metric": { + "type": "string", + "description": "The metric for weather. Options are: celsius, fahrenheit", + "default": "celsius" + } + } + } + } + ]""" + ) + user_system_prompt = textwrap.dedent( + """ + Overriding message. + + {{ function_description }} + """ + ) + example = generator.data_examples()[0] + + pt = generator.gen(example, user_system_prompt) + text = pt.render() + assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}" diff --git a/llama_stack/models/llama/llama3_1/__init__.py b/llama_stack/models/llama/llama3_1/__init__.py new file mode 100644 index 000000000..38ee47d66 --- /dev/null +++ b/llama_stack/models/llama/llama3_1/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. diff --git a/llama_stack/models/llama/llama3_1/prompts.py b/llama_stack/models/llama/llama3_1/prompts.py new file mode 100644 index 000000000..edbce3bc0 --- /dev/null +++ b/llama_stack/models/llama/llama3_1/prompts.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import textwrap +from typing import List + +from llama_models.datatypes import ( + BuiltinTool, + RawMessage, + StopReason, + ToolCall, + ToolPromptFormat, +) + +from ..prompt_format import ( + # llama3_1_e2e_tool_call_dialog, + TextCompletionContent, + UseCase, + llama3_1_builtin_tool_call_dialog, + llama3_1_custom_tool_call_dialog, +) + + +def wolfram_alpha_response(): + return textwrap.dedent( + """ + { + "queryresult": { + "success": true, + "inputstring": "100th decimal of pi", + "pods": [ + { + "title": "Input interpretation", + "subpods": [ + { + "title": "", + "plaintext": "100th digit | \u03c0" + } + ] + }, + { + "title": "Nearby digits", + "subpods": [ + { + "title": "", + "plaintext": "...86208998628034825342117067982148086513282306647093..." + } + ] + }, + { + "title": "Result", + "primary": true, + "subpods": [ + { + "title": "", + "plaintext": "7" + } + ] + } + ] + } + } + """ + ) + + +def usecases() -> List[UseCase | str]: + return [ + textwrap.dedent( + """ + # Llama 3.1 - Prompt Formats + ## Tokens + Here is a list of special tokens that are supported by Llama 3.1: + - `<|begin_of_text|>`: Specifies the start of the prompt + - `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models. + - `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch. + - `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and tool] + - `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool. + - `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios: + - at the end of a direct interaction between the model and the user + - at the end of multiple interactions between the model and any available tools + This token signals to the executor that the model has finished generating a response. + - `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call. + """ + ), + textwrap.dedent( + """ + There are 4 different roles that are supported by Llama 3.1 + - `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively. + - `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model. + - `tool`: A new role introduced in Llama 3.1. This role is used to mark messages with the output of a tool call when sent back to the model from the executor. (The actual token used by the model for this role is "ipython".) + - `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts. 
+ """ + ), + UseCase( + title="Llama 3.1 Base Model", + description="Text completion for Llama 3.1 base model uses this format.", + dialogs=[TextCompletionContent(content="Color of sky is blue but sometimes can also be")], + notes="Note start special tag", + ), + "## Llama 3.1 Instruct Model", + UseCase( + title="User and assistant conversation", + description="Here is a regular multi-turn user assistant conversation and how its formatted.", + dialogs=[ + [ + RawMessage(role="system", content="You are a helpful assistant"), + RawMessage( + role="user", + content="Answer who are you in the form of jeopardy?", + ), + ] + ], + notes="", + ), + "## Tool Calling Formats", + textwrap.dedent( + """ + The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt: + - Brave Search: Tool call to perform web searches. + - Wolfram Alpha: Tool call to perform complex mathematical calculations. + - Code Interpreter: Enables the model to output python code. + """ + ), + UseCase( + title="Builtin Tool Calling", + description=textwrap.dedent( + """ + Here is an example of a conversation using brave search + """ + ), + dialogs=[llama3_1_builtin_tool_call_dialog()], + notes=textwrap.dedent( + """ + - Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model. + - The message body of the assistant response starts with a special tag <|python_tag|> + - As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|> . The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call. + - The model tool call response is of the form `tool.call(query="...")` wher tool is `brave_search` or `wolfram_alpha` + """ + ), + ), + UseCase( + title="Builtin Code Interpreter", + description="Here is an actual example of model responding with code", + dialogs=[ + [ + RawMessage(role="system", content="Environment: ipython"), + RawMessage( + role="user", + content="Write code to check if number is prime, use that to see if the number 7 is prime", + ), + ], + ], + notes=textwrap.dedent( + """ + - Model starts with <|python_tag|> and continues writing python code that it needs to be executed + - No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it. + """ + ), + ), + UseCase( + title="Built-in tools full interaction", + description="Here is a full interaction with the built-in tools including the tool response and the final assistant response.", + dialogs=[ + [ + RawMessage( + role="system", + content="Environment: ipython\nTools: brave_search, wolfram_alpha\n", + ), + RawMessage(role="user", content="What is the 100th decimal of pi?"), + RawMessage( + role="assistant", + content="", + stop_reason=StopReason.end_of_message, + tool_calls=[ + ToolCall( + call_id="tool_call_id", + tool_name=BuiltinTool.wolfram_alpha, + arguments={"query": "100th decimal of pi"}, + ) + ], + ), + RawMessage( + role="tool", + content=wolfram_alpha_response(), + ), + ], + ], + notes=textwrap.dedent( + """ + - Note the `<|python_tag|>` in the assistant response. + - Role is `tool` for the wolfram alpha response that is passed back to the model. 
+ - Final message from assistant has <|eot_id|> tag. + """ + ), + ), + "## Zero shot tool calling", + UseCase( + title="JSON based tool calling", + description=textwrap.dedent( + """ + Llama models can now output custom tool calls from a single message to allow easier tool calling. + The following prompts provide an example of how custom tools can be called from the output of the model. + It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor. + """ + ), + dialogs=[llama3_1_custom_tool_call_dialog()], + notes=textwrap.dedent( + """ + - JSON format for providing tools needs name, description and parameters + - Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt + - Instructions for tools added as a user message + - Only single tool calls are supported as of now + """ + ), + ), + # FIXME: This is not working yet as expected + # UseCase( + # title="E2E tool call example", + # description=textwrap.dedent( + # """ + # Here is an example showing the whole multi-step turn by taking custom tool outputs and passing back to the model. + # """ + # ), + # dialogs=[ + # llama3_1_e2e_tool_call_dialog( + # tool_prompt_format=ToolPromptFormat.function_tag + # ) + # ], + # notes="", + # ), + "## Example of a user defined tool calling", + UseCase( + title="`` based tool calling", + description=textwrap.dedent( + """ + Here is an example of how you could also write custom instructions for model to do zero shot tool calling. + In this example, we define a custom tool calling format using the `` tag. + """ + ), + dialogs=[llama3_1_custom_tool_call_dialog(ToolPromptFormat.function_tag)], + notes=textwrap.dedent( + """ + - In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>` + - Instructions for tools added as a user message + """ + ), + ), + ] diff --git a/llama_stack/models/llama/llama3_2/__init__.py b/llama_stack/models/llama/llama3_2/__init__.py new file mode 100644 index 000000000..38ee47d66 --- /dev/null +++ b/llama_stack/models/llama/llama3_2/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. diff --git a/llama_stack/models/llama/llama3_2/prompts_text.py b/llama_stack/models/llama/llama3_2/prompts_text.py new file mode 100644 index 000000000..29557f4be --- /dev/null +++ b/llama_stack/models/llama/llama3_2/prompts_text.py @@ -0,0 +1,235 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
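A hedged sketch (an assumption of this edit, not part of the patch or the library API): one simple way an executor could parse the pythonic `[func_name(param=value, ...)]` tool-call output that the `PythonListCustomToolGenerator` prompt above asks the model to emit. A production parser would need more validation and error handling.

import ast
from typing import Any, Dict, List, Tuple

def parse_python_list_tool_calls(text: str) -> List[Tuple[str, Dict[str, Any]]]:
    """Parse e.g. '[get_weather(city="SF", metric="celsius")]' into (name, kwargs) pairs."""
    tree = ast.parse(text.strip(), mode="eval")
    calls: List[Tuple[str, Dict[str, Any]]] = []
    for node in tree.body.elts:  # the model output is expected to be a list literal of calls
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
            calls.append((node.func.id, kwargs))
    return calls

print(parse_python_list_tool_calls('[get_weather(city="San Francisco", metric="celsius")]'))
# [('get_weather', {'city': 'San Francisco', 'metric': 'celsius'})]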
+import json +import textwrap + +from llama_models.datatypes import ( + RawMessage, + StopReason, + ToolCall, + ToolPromptFormat, +) + +from ..prompt_format import ( + TextCompletionContent, + UseCase, + llama3_1_builtin_code_interpreter_dialog, +) + + +def user_tool_call(): + content = textwrap.dedent( + """ + Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request? + Here is a list of functions in JSON format that you can invoke: + [ + { + "name": "get_user_info", + "description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.", + "parameters": { + "type": "dict", + "required": [ + "user_id" + ], + "properties": { + "user_id": { + "type": "integer", + "description": "The unique identifier of the user. It is used to fetch the specific user details from the database." + }, + "special": { + "type": "string", + "description": "Any special information or parameters that need to be considered while fetching user details.", + "default": "none" + } + } + } + } + ] + + Should you decide to return the function call(s),Put it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)] + + NO other text MUST be included. + """ + ) + return content.strip() + + +def system_tool_call(): + content = textwrap.dedent( + """ + You are an expert in composing functions. You are given a question and a set of possible functions. + Based on the question, you will need to make one or more function/tool calls to achieve the purpose. + If none of the function can be used, point it out. If the given question lacks the parameters required by the function, + also point it out. You should only return the function call in tools call sections. + + If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] + You SHOULD NOT include any other text in the response. + + Here is a list of functions in JSON format that you can invoke. + + [ + { + "name": "get_weather", + "description": "Get weather info for places", + "parameters": { + "type": "dict", + "required": [ + "city" + ], + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get the weather for" + }, + "metric": { + "type": "string", + "description": "The metric for weather. Options are: celsius, fahrenheit", + "default": "celsius" + } + } + } + } + ] + """ + ) + return content.strip() + + +def usecases(): + return [ + UseCase( + title="User and assistant conversation", + description="Here is a regular multi-turn user assistant conversation and how its formatted.", + dialogs=[ + [ + RawMessage(role="system", content="You are a helpful assistant"), + RawMessage(role="user", content="Who are you?"), + ] + ], + notes="This format is unchanged from Llama3.1", + ), + UseCase( + title="Zero shot function calling", + description=textwrap.dedent( + """ + For Llama3.2 1B and 3B instruct models, we are introducing a new format for zero shot function calling. + This new format is designed to be more flexible and powerful than the previous format. + All available functions can be provided in the system message. A key difference is in the format of how the assistant responds with function calls. + It is pythonic in the form of `[func1(params_name=params_value, params_name2=params_value2...), func2(params)]` instead of the `json` or `` tag that were defined in Llama3.1. 
+ Here is an example for the same, + """ + ), + dialogs=[ + # Zero shot tool calls as system message + [ + RawMessage(role="system", content=system_tool_call()), + RawMessage(role="user", content="What is the weather in SF and Seattle?"), + ], + ], + notes=textwrap.dedent( + """ + - The output supports multiple tool calls natively + - JSON format for defining the functions in the system prompt is similar to Llama3.1 + """ + ), + ), + UseCase( + title="Zero shot function calling with user message", + description=textwrap.dedent( + """ + While the default is to provide all function calls in a system message, in Llama3.2 text models you can also provide information for all the available tools in a user message. + """ + ), + dialogs=[ + # Zero shot tool call as user message + [ + RawMessage(role="user", content=user_tool_call()), + ], + ], + notes=textwrap.dedent( + """ + - The tool call format for the model is the same whether your function calls are provided in the system or user message. + - While builtin tool calls end with a <|eom_id|>, notice the <|eot_id|> for zero shot tool calls. + """ + ), + ), + UseCase( + title="Code Interpreter", + description=textwrap.dedent( + """ + Code Interpreter continues to work in 3.2 text models similar to Llama 3.1 model family. + Here is an example, + """ + ), + dialogs=[llama3_1_builtin_code_interpreter_dialog()], + notes=textwrap.dedent( + """ + - Note `Environment: ipython` in the system prompt. + - Note that the response starts with `<|python_tag|>` and ends with `<|eom_id|>` + """ + ), + ), + UseCase( + title="Zero shot function calling E2E format", + description=textwrap.dedent( + """ + Here is an example of the e2e cycle of tool calls with the model in a muti-step way. + """ + ), + dialogs=[ + [ + RawMessage(role="system", content=system_tool_call()), + RawMessage(role="user", content="What is the weather in SF?"), + RawMessage( + role="assistant", + content="", + stop_reason=StopReason.end_of_turn, + tool_calls=[ + ToolCall( + call_id="cc", + tool_name="get_weather", + arguments={ + "city": "San Francisco", + "metric": "celsius", + }, + ) + ], + ), + RawMessage( + role="tool", + content=json.dumps("25 C"), + ), + ], + ], + notes=textwrap.dedent( + """ + - The output of the function call is provided back to the model as a tool response ( in json format ). + - Notice `<|start_header_id|>ipython<|end_header_id|>` as the header message preceding the tool response. + - The model finally summarizes the information from the tool response and returns the result to the user. + """ + ), + tool_prompt_format=ToolPromptFormat.python_list, + ), + UseCase( + title="Prompt format for base models", + description=textwrap.dedent( + """ + For base models (Llama3.2-1B and Llama3.2-3B), the prompt format for a simple completion is as follows + """ + ), + dialogs=[ + TextCompletionContent(content="The color of the sky is blue but sometimes it can also be"), + ], + notes="Same as Llama3.1", + ), + ] diff --git a/llama_stack/models/llama/llama3_2/prompts_vision.py b/llama_stack/models/llama/llama3_2/prompts_vision.py new file mode 100644 index 000000000..c3cfe5e7b --- /dev/null +++ b/llama_stack/models/llama/llama3_2/prompts_vision.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import textwrap +from pathlib import Path + +from llama_models.datatypes import ( + RawMediaItem, + RawMessage, + RawTextItem, +) + +from ..prompt_format import ( + TextCompletionContent, + UseCase, + llama3_1_builtin_tool_call_dialog, + # llama3_1_builtin_tool_call_with_image_dialog, + llama3_2_user_assistant_conversation, +) + + +def usecases(): + this_dir = Path(__file__).parent.parent.resolve() + with open(this_dir / "scripts/resources/dog.jpg", "rb") as f: + img = f.read() + + return [ + llama3_2_user_assistant_conversation(), + UseCase( + title="User and assistant conversation with Images", + description="This example shows how to pass and image to the model as part of the messages.", + dialogs=[ + [ + RawMessage( + role="user", + content=[ + RawMediaItem(data=img), + RawTextItem(text="Describe this image in two sentences"), + ], + ) + ], + ], + notes=textwrap.dedent( + """ + - The `<|image|>` tag is used to indicate presence of the image + - The model isn't an early fusion model so doesn't actually translate an image into several tokens. Instead the cross-attention layers take input "on the side" from a vision encoder + ![Image](mm-model.png) + - Its important to postion the <|image|> tag appropriately in the prompt. Image will only attend to the subsequent text tokens + - The <|image|> tag is part of the user message body, implying that it should only come after the header `<|start_header_id|>{role}<|end_header_id|>` in the message body + - We recommend using a single image in one prompt + """ + ), + ), + UseCase( + title="Builtin and Zero Shot Tool Calling", + description=textwrap.dedent( + """ + Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only. + Use `Environment: ipython` to enable tools. + Add `Tools: {{tool_name1}},{{tool_name2}}` for each of the builtin tools. + The same builtin tools as Llama3.1 are available, + - code_interpreter (for executing python code) + - brave_search (to search the web) + - wolfram_alpha (for querying wolfram alpha for mathematical questions) + """, + ), + dialogs=[llama3_1_builtin_tool_call_dialog()], + notes=textwrap.dedent( + """ + - Note the `<|python_tag|>` before `brave_search` function call. + - The `<|eom_id|>` tag is used to indicate the end of the message. + - Similar to Llama3.1, code_interpreter is not explicitly mentioned but is enabled via `Environment: ipython`. + - Tool Calling does NOT work with images in the prompt as of now. + """ + ), + ), + # UseCase( + # title="Tool Calling for vision models", + # description=textwrap.dedent( + # """ + # While Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only, + # they are not able to do tool calling when prompt contains image inputs (along with text). + # The recommended way would be to separate out the image understanding from the tool calling in successive prompts. + # Here is an example of how that could be done, + # """, + # ), + # dialogs=[llama3_1_builtin_tool_call_with_image_dialog()], + # notes=textwrap.dedent( + # """ + # - Instead of a single prompt (image understanding + tool call), we split into two prompts to achieve the same result. 
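To illustrate the "separate out the image understanding from the tool calling" recommendation in the commented-out use case above, here is a hedged sketch (not part of this diff) of how a caller might build the two successive dialogs. It reuses the RawMessage/RawMediaItem/RawTextItem types imported in this file; the prompt wording and the helper name are illustrative only.

```python
from llama_models.datatypes import RawMediaItem, RawMessage, RawTextItem


def split_image_then_tool_dialogs(img: bytes, breed_description: str):
    # Prompt 1: image understanding only -- no tools enabled, since tool calling
    # does not currently work when the prompt contains an image.
    image_dialog = [
        RawMessage(
            role="user",
            content=[RawMediaItem(data=img), RawTextItem(text="What dog breed is this?")],
        )
    ]
    # Prompt 2: a text-only follow-up where the builtin tools can be enabled as usual.
    tool_dialog = [
        RawMessage(role="system", content="Environment: ipython\nTools: brave_search\n"),
        RawMessage(
            role="user",
            content=f"Search the web for food recommendations for a {breed_description}",
        ),
    ]
    return image_dialog, tool_dialog
```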
+ # """ + # ), + # ), + UseCase( + title="Prompt format for base models", + description=textwrap.dedent( + """ + For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), the prompt format for a simple completion is as follows + """ + ), + dialogs=[ + TextCompletionContent(content="The color of the sky is blue but sometimes it can also be"), + ], + notes="- Same as Llama3.1", + ), + UseCase( + title="Prompt format for base models with Image", + description=textwrap.dedent( + """ + For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), here is an example of how the text completion format looks with an image, + """ + ), + dialogs=[ + TextCompletionContent( + content=[ + RawMediaItem(data=img), + RawTextItem(text="If I had to write a haiku for this one"), + ] + ), + ], + notes="- Note the placement of the special tags <|begin_of_text|> and <|image|>", + ), + ] diff --git a/llama_stack/models/llama/llama3_3/prompts.py b/llama_stack/models/llama/llama3_3/prompts.py new file mode 100644 index 000000000..14fd86853 --- /dev/null +++ b/llama_stack/models/llama/llama3_3/prompts.py @@ -0,0 +1,258 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import textwrap +from typing import List + +from llama_models.datatypes import ( + BuiltinTool, + RawMessage, + StopReason, + ToolCall, + ToolPromptFormat, +) + +from ..prompt_format import ( + # llama3_1_e2e_tool_call_dialog, + TextCompletionContent, + UseCase, + llama3_1_builtin_tool_call_dialog, + llama3_1_custom_tool_call_dialog, +) + + +def wolfram_alpha_response(): + return textwrap.dedent( + """ + { + "queryresult": { + "success": true, + "inputstring": "100th decimal of pi", + "pods": [ + { + "title": "Input interpretation", + "subpods": [ + { + "title": "", + "plaintext": "100th digit | \u03c0" + } + ] + }, + { + "title": "Nearby digits", + "subpods": [ + { + "title": "", + "plaintext": "...86208998628034825342117067982148086513282306647093..." + } + ] + }, + { + "title": "Result", + "primary": true, + "subpods": [ + { + "title": "", + "plaintext": "7" + } + ] + } + ] + } + } + """ + ) + + +def usecases() -> List[UseCase | str]: + return [ + textwrap.dedent( + """ + # Llama 3.1 - Prompt Formats + ## Tokens + Here is a list of special tokens that are supported by Llama 3.1: + - `<|begin_of_text|>`: Specifies the start of the prompt + - `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models. + - `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch. + - `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and tool] + - `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. 
This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool. + - `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios: + - at the end of a direct interaction between the model and the user + - at the end of multiple interactions between the model and any available tools + This token signals to the executor that the model has finished generating a response. + - `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call. + """ + ), + textwrap.dedent( + """ + There are 4 different roles that are supported by Llama 3.1 + - `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively. + - `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model. + - `tool`: A new role introduced in Llama 3.1. This role is used to mark messages with the output of a tool call when sent back to the model from the executor. (The actual token used by the model for this role is "ipython".) + - `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts. + """ + ), + UseCase( + title="Llama 3.1 Base Model", + description="Text completion for Llama 3.1 base model uses this format.", + dialogs=[TextCompletionContent(content="Color of sky is blue but sometimes can also be")], + notes="Note start special tag", + ), + "## Llama 3.1 Instruct Model", + UseCase( + title="User and assistant conversation", + description="Here is a regular multi-turn user assistant conversation and how its formatted.", + dialogs=[ + [ + RawMessage(role="system", content="You are a helpful assistant"), + RawMessage( + role="user", + content="Answer who are you in the form of jeopardy?", + ), + ] + ], + notes="", + ), + "## Tool Calling Formats", + textwrap.dedent( + """ + The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt: + - Brave Search: Tool call to perform web searches. + - Wolfram Alpha: Tool call to perform complex mathematical calculations. + - Code Interpreter: Enables the model to output python code. + """ + ), + UseCase( + title="Builtin Tool Calling", + description=textwrap.dedent( + """ + Here is an example of a conversation using brave search + """ + ), + dialogs=[llama3_1_builtin_tool_call_dialog()], + notes=textwrap.dedent( + """ + - Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model. + - The message body of the assistant response starts with a special tag <|python_tag|> + - As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|> . The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call. 
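The <|eom_id|> / <|eot_id|> distinction described above is what drives the executor's control flow. Here is a hedged sketch (not part of this diff) of such a loop; `chat` and `execute_tool` are hypothetical callables supplied by the caller, while RawMessage and StopReason are the types already imported in this file.

```python
from llama_models.datatypes import RawMessage, StopReason


def run_tool_loop(chat, execute_tool, dialog: list[RawMessage]) -> RawMessage:
    """Keep executing tool calls while the model stops with <|eom_id|> (end_of_message);
    return the final answer once it stops with <|eot_id|> (end_of_turn)."""
    while True:
        message = chat(dialog)  # hypothetical: returns the model's RawMessage
        dialog.append(message)
        if message.stop_reason == StopReason.end_of_turn or not message.tool_calls:
            return message  # the model has finished its turn
        # <|eom_id|>: the model expects the tool output back as a `tool` (ipython) message.
        for call in message.tool_calls:
            result = execute_tool(call.tool_name, call.arguments)  # hypothetical executor
            dialog.append(RawMessage(role="tool", content=result))
```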
+ - The model tool call response is of the form `tool.call(query="...")` wher tool is `brave_search` or `wolfram_alpha` + """ + ), + ), + UseCase( + title="Builtin Code Interpreter", + description="Here is an actual example of model responding with code", + dialogs=[ + [ + RawMessage(role="system", content="Environment: ipython"), + RawMessage( + role="user", + content="Write code to check if number is prime, use that to see if the number 7 is prime", + ), + ], + ], + notes=textwrap.dedent( + """ + - Model starts with <|python_tag|> and continues writing python code that it needs to be executed + - No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it. + """ + ), + ), + UseCase( + title="Built-in tools full interaction", + description="Here is a full interaction with the built-in tools including the tool response and the final assistant response.", + dialogs=[ + [ + RawMessage( + role="system", + content="Environment: ipython\nTools: brave_search, wolfram_alpha\n", + ), + RawMessage(role="user", content="What is the 100th decimal of pi?"), + RawMessage( + content="", + stop_reason=StopReason.end_of_message, + tool_calls=[ + ToolCall( + call_id="tool_call_id", + tool_name=BuiltinTool.wolfram_alpha, + arguments={"query": "100th decimal of pi"}, + ) + ], + ), + RawMessage( + role="tool", + content=wolfram_alpha_response(), + ), + ], + ], + notes=textwrap.dedent( + """ + - Note the `<|python_tag|>` in the assistant response. + - Role is `tool` for the wolfram alpha response that is passed back to the model. + - Final message from assistant has <|eot_id|> tag. + """ + ), + ), + "## Zero shot tool calling", + UseCase( + title="JSON based tool calling", + description=textwrap.dedent( + """ + Llama models can now output custom tool calls from a single message to allow easier tool calling. + The following prompts provide an example of how custom tools can be called from the output of the model. + It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor. + """ + ), + dialogs=[llama3_1_custom_tool_call_dialog()], + notes=textwrap.dedent( + """ + - JSON format for providing tools needs name, description and parameters + - Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt + - Instructions for tools added as a user message + - Only single tool calls are supported as of now + """ + ), + ), + # FIXME: This is not working yet as expected + # UseCase( + # title="E2E tool call example", + # description=textwrap.dedent( + # """ + # Here is an example showing the whole multi-step turn by taking custom tool outputs and passing back to the model. + # """ + # ), + # dialogs=[ + # llama3_1_e2e_tool_call_dialog( + # tool_prompt_format=ToolPromptFormat.function_tag + # ) + # ], + # notes="", + # ), + "## Example of a user defined tool calling", + UseCase( + title="`` based tool calling", + description=textwrap.dedent( + """ + Here is an example of how you could also write custom instructions for model to do zero shot tool calling. + In this example, we define a custom tool calling format using the `` tag. 
+ """ + ), + dialogs=[llama3_1_custom_tool_call_dialog(ToolPromptFormat.function_tag)], + notes=textwrap.dedent( + """ + - In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>` + - Instructions for tools added as a user message + """ + ), + ), + ] diff --git a/llama_stack/models/llama/prompt_format.py b/llama_stack/models/llama/prompt_format.py new file mode 100644 index 000000000..f42620d57 --- /dev/null +++ b/llama_stack/models/llama/prompt_format.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import json +import textwrap +from pathlib import Path +from typing import List + +from llama_models.datatypes import ( + RawContent, + RawMediaItem, + RawMessage, + RawTextItem, + StopReason, + ToolCall, + ToolPromptFormat, +) +from pydantic import BaseModel, Field + +from .llama3.interface import LLama31Interface +from .llama3.template_data import ( + system_message_builtin_code_only, + system_message_builtin_tools_only, + system_message_custom_tools_only, +) + + +class TextCompletionContent(BaseModel): + content: RawContent = "" + + +class UseCase(BaseModel): + title: str = "" + description: str = "" + dialogs: List[List[RawMessage] | TextCompletionContent | str] = Field(default_factory=list) + notes: str = "" + tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json + + def md_format(self): + section = textwrap.dedent( + """ + ## {title} + + {description} + + {dialogs_text} + {notes} + + """ + ) + return section.lstrip() + + def dialogs_to_text(self, generator) -> str: + def _code_block(text): + return f"```\n{text}\n```" + + text = "" + for dialog in self.dialogs: + if isinstance(dialog, str): + text += dialog + text += "\n\n" + continue + + elif isinstance(dialog, TextCompletionContent): + input_tokens, output_tokens = generator.text_completion_raw( + dialog.content, + max_gen_len=64, + temperature=0.1, + top_p=0.95, + ) + else: + input_tokens, output_tokens = generator.chat_completion_raw( + dialog, + max_gen_len=512, + temperature=0.0, + top_p=0.95, + tool_prompt_format=self.tool_prompt_format, + ) + text += "##### Input Prompt Format\n" + + # FIXME: This is added to undo the hack in chat_formatter where + # vision tokens are replaced with 128256. 
+ input_tokens = [generator.formatter.vision_token if t == 128256 else t for t in input_tokens] + + text += _code_block(generator.tokenizer.decode(input_tokens)) + # TODO: Figure out if "↵" needs to be added for newlines or end or some indication + text += "\n\n" + text += "##### Model Response Format\n" + text += _code_block(generator.tokenizer.decode(output_tokens)) + text += "\n\n" + + return text + + def to_text(self, generator): + section = self.md_format() + dialogs_text = self.dialogs_to_text(generator) + notes = f"##### Notes\n{self.notes}" if self.notes else "" + section = section.format( + title=self.title, + description=self.description, + dialogs_text=dialogs_text, + notes=notes, + ) + return section + + +def llama3_1_builtin_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json): + interface = LLama31Interface(tool_prompt_format) + + messages = interface.system_messages(**system_message_builtin_tools_only()) + messages += interface.user_message(content="Search the web for the latest price of 1oz gold?") + + return messages + + +def llama3_1_builtin_code_interpreter_dialog(tool_prompt_format=ToolPromptFormat.json): + interface = LLama31Interface(tool_prompt_format) + + messages = interface.system_messages(**system_message_builtin_code_only()) + messages += interface.user_message( + content="Write code to check if number is prime. Use it to verify if number 7 is prime" + ) + + return messages + + +def llama3_1_builtin_tool_call_with_image_dialog( + tool_prompt_format=ToolPromptFormat.json, +): + this_dir = Path(__file__).parent + with open(this_dir / "llama3/dog.jpg", "rb") as f: + img = f.read() + + interface = LLama31Interface(tool_prompt_format) + + messages = interface.system_messages(**system_message_builtin_tools_only()) + messages += interface.user_message(content=[RawMediaItem(data=img), RawTextItem(text="What is this dog breed?")]) + messages += interface.assistant_response_messages( + "Based on the description of the dog in the image, it appears to be a small breed dog, possibly a terrier mix", + StopReason.end_of_turn, + ) + messages += interface.user_message("Search the web for some food recommendations for the indentified breed") + return messages + + +def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json): + interface = LLama31Interface(tool_prompt_format) + + messages = interface.system_messages(**system_message_custom_tools_only()) + messages += interface.user_message(content="Use tools to get latest trending songs") + return messages + + +def llama3_1_e2e_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json): + tool_response = json.dumps(["great song1", "awesome song2", "cool song3"]) + interface = LLama31Interface(tool_prompt_format) + + messages = interface.system_messages(**system_message_custom_tools_only()) + messages += interface.user_message(content="Use tools to get latest trending songs") + messages.append( + RawMessage( + role="assistant", + content="", + stop_reason=StopReason.end_of_message, + tool_calls=[ + ToolCall( + call_id="call_id", + tool_name="trending_songs", + arguments={"n": "10", "genre": "latest"}, + ) + ], + ), + ) + messages.append( + RawMessage( + role="assistant", + content=tool_response, + ) + ) + return messages + + +def llama3_2_user_assistant_conversation(): + return UseCase( + title="User and assistant conversation", + description="Here is a regular multi-turn user assistant conversation and how its formatted.", + dialogs=[ + [ + RawMessage(role="system", content="You are a helpful assistant"), + 
RawMessage(role="user", content="Who are you?"), + ] + ], + notes="This format is unchanged from Llama3.1", + ) diff --git a/llama_stack/models/llama/sku_list.py b/llama_stack/models/llama/sku_list.py new file mode 100644 index 000000000..6f4a5a885 --- /dev/null +++ b/llama_stack/models/llama/sku_list.py @@ -0,0 +1,1000 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +from dataclasses import dataclass +from functools import lru_cache +from typing import List, Optional + +from .datatypes import ( + CheckpointQuantizationFormat, + CoreModelId, + Model, + SamplingParams, + TopPSamplingStrategy, +) + +LLAMA2_VOCAB_SIZE = 32000 +LLAMA3_VOCAB_SIZE = 128256 + + +def resolve_model(descriptor: str) -> Optional[Model]: + for m in all_registered_models(): + if descriptor in (m.descriptor(), m.huggingface_repo): + return m + return None + + +def all_registered_models() -> List[Model]: + return ( + llama2_family() + llama3_family() + llama3_1_family() + llama3_2_family() + llama3_3_family() + safety_models() + ) + + +def recommended_sampling_params() -> SamplingParams: + return SamplingParams( + strategy=TopPSamplingStrategy( + temperature=1.0, + top_p=0.9, + ) + ) + + +def llama2_family() -> List[Model]: + return [ + *llama2_base_models(), + *llama2_instruct_models(), + ] + + +def llama3_family() -> List[Model]: + return [ + *llama3_base_models(), + *llama3_instruct_models(), + ] + + +def llama3_1_family() -> List[Model]: + return [ + *llama3_1_base_models(), + *llama3_1_instruct_models(), + ] + + +def llama3_2_family() -> List[Model]: + return [ + *llama3_2_base_models(), + *llama3_2_instruct_models(), + ] + + +def llama3_3_family() -> List[Model]: + return [ + *llama3_3_instruct_models(), + ] + + +def llama2_base_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama2_7b, + description="Llama 2 7b model", + huggingface_repo="meta-llama/Llama-2-7b", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama2_13b, + description="Llama 2 13b model", + huggingface_repo="meta-llama/Llama-2-13b", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 5120, + "n_layers": 40, + "n_heads": 40, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama2_70b, + description="Llama 2 70b model", + huggingface_repo="meta-llama/Llama-2-70b", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + 
"use_scaled_rope": False, + }, + pth_file_count=8, + ), + ] + + +def llama3_base_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_8b, + description="Llama 3 8b model", + huggingface_repo="meta-llama/Llama-3-8B", + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_70b, + description="Llama 3 70b model", + huggingface_repo="meta-llama/Llama-3-70B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=8, + ), + ] + + +def llama3_1_base_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_1_8b, + description="Llama 3.1 8b model", + huggingface_repo="meta-llama/Llama-3.1-8B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_1_70b, + description="Llama 3.1 70b model", + huggingface_repo="meta-llama/Llama-3.1-70B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b, + variant="bf16-mp8", + description="Llama 3.1 405b model (BF16 weights)", + huggingface_repo="meta-llama/Llama-3.1-405B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b, + description="Llama 3.1 405b model (FP8 quantized)", + huggingface_repo="meta-llama/Llama-3.1-405B-FP8", + quantization_format=CheckpointQuantizationFormat.fp8_mixed, + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b, + variant="bf16-mp16", + description="Llama 3.1 405b model (BF16 weights for mp16)", + huggingface_repo="meta-llama/Llama-3.1-405B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 16, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + 
pth_file_count=16, + ), + ] + + +def llama3_2_base_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_2_1b, + description="Llama 3.2 1b model", + huggingface_repo="meta-llama/Llama-3.2-1B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 2048, + "n_layers": 16, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.5, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_3b, + description="Llama 3.2 3b model", + huggingface_repo="meta-llama/Llama-3.2-3B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 3072, + "n_layers": 28, + "n_heads": 24, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.0, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_11b_vision, + description="Llama 3.2 11b vision model", + huggingface_repo="meta-llama/Llama-3.2-11B-Vision", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + "vision_chunk_size": 448, + "vision_max_num_chunks": 4, + "vision_num_cross_attention_layers": 8, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_90b_vision, + description="Llama 3.2 90b vision model", + huggingface_repo="meta-llama/Llama-3.2-90B-Vision", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + "vision_chunk_size": 560, + "vision_max_num_chunks": 4, + "vision_num_cross_attention_layers": 20, + }, + pth_file_count=8, + ), + ] + + +def llama2_instruct_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama2_7b_chat, + description="Llama 2 7b chat model", + huggingface_repo="meta-llama/Llama-2-7b-chat", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama2_13b_chat, + description="Llama 2 13b chat model", + huggingface_repo="meta-llama/Llama-2-13b-chat", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 5120, + "n_layers": 40, + "n_heads": 40, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama2_70b_chat, + description="Llama 2 70b chat model", + huggingface_repo="meta-llama/Llama-2-70b-chat", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + 
"ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=8, + ), + ] + + +def llama3_instruct_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_8b_instruct, + description="Llama 3 8b instruct model", + huggingface_repo="meta-llama/Llama-3-8B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_70b_instruct, + description="Llama 3 70b instruct model", + huggingface_repo="meta-llama/Llama-3-70B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=8, + ), + ] + + +def llama3_1_instruct_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_1_8b_instruct, + description="Llama 3.1 8b instruct model", + huggingface_repo="meta-llama/Llama-3.1-8B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_1_70b_instruct, + description="Llama 3.1 70b instruct model", + huggingface_repo="meta-llama/Llama-3.1-70B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b_instruct, + variant="bf16-mp8", + description="Llama 3.1 405b instruct model (BF16 weights)", + huggingface_repo="meta-llama/Llama-3.1-405B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b_instruct, + description="Llama 3.1 405b instruct model (FP8 quantized)", + huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8", + quantization_format=CheckpointQuantizationFormat.fp8_mixed, + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b_instruct, + variant="bf16-mp16", + description="Llama 3.1 405b instruct model (BF16 weights for mp16)", + 
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 16, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=16, + ), + ] + + +def arch_args_1b() -> dict: + return { + "dim": 2048, + "n_layers": 16, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.5, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + } + + +def arch_args_3b() -> dict: + return { + "dim": 3072, + "n_layers": 28, + "n_heads": 24, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.0, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + } + + +def llama3_2_quantized_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_2_1b_instruct, + variant="int4-qlora-eo8", + quantization_format=CheckpointQuantizationFormat.int4, + description="Llama 3.2 1b INT4 quantized LoRA", + huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + **arch_args_1b(), + "quantization_args": { + "group_size": 256, + }, + "lora_args": { + "rank": 16, + "scale": 2.0, + }, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_1b_instruct, + variant="int4-spinquant-eo8", + quantization_format=CheckpointQuantizationFormat.int4, + description="Llama 3.2 1b INT4 quantized SpinQuant", + huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + **arch_args_1b(), + "quantization_args": { + "group_size": 256, + }, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_3b_instruct, + variant="int4-qlora-eo8", + quantization_format=CheckpointQuantizationFormat.int4, + description="Llama 3.2 3b INT4 quantized LoRA", + huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + **arch_args_3b(), + "quantization_args": { + "group_size": 256, + }, + "lora_args": { + "rank": 16, + "scale": 2.0, + }, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_3b_instruct, + variant="int4-spinquant-eo8", + quantization_format=CheckpointQuantizationFormat.int4, + description="Llama 3.2 3b INT4 quantized SpinQuant", + huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + **arch_args_3b(), + "quantization_args": { + "group_size": 256, + }, + }, + pth_file_count=1, + ), + ] + + +def llama3_2_instruct_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_2_1b_instruct, + description="Llama 3.2 1b instruct model", + huggingface_repo="meta-llama/Llama-3.2-1B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args=arch_args_1b(), + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_3b_instruct, + description="Llama 3.2 3b instruct model", + huggingface_repo="meta-llama/Llama-3.2-3B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args=arch_args_3b(), + pth_file_count=1, + ), + *llama3_2_quantized_models(), + 
Model( + core_model_id=CoreModelId.llama3_2_11b_vision_instruct, + description="Llama 3.2 11b vision instruct model", + huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + "vision_chunk_size": 560, + "vision_max_num_chunks": 4, + "vision_num_cross_attention_layers": 8, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_90b_vision_instruct, + description="Llama 3.2 90b vision instruct model", + huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + "vision_chunk_size": 560, + "vision_max_num_chunks": 4, + "vision_num_cross_attention_layers": 20, + }, + pth_file_count=8, + ), + ] + + +def llama3_3_instruct_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_3_70b_instruct, + description="Llama 3.3 70b instruct", + huggingface_repo="meta-llama/Llama-3.3-70B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + ] + + +@lru_cache +def safety_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama_guard_3_11b_vision, + description="Llama Guard v3 11b vision system safety model", + huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + "vision_chunk_size": 560, + "vision_max_num_chunks": 4, + "vision_num_cross_attention_layers": 8, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama_guard_3_1b, + variant="int4", + description="Llama Guard v3 1b 'int4' quantized system safety model", + huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4", + quantization_format=CheckpointQuantizationFormat.int4, + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 2048, + "n_layers": 12, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "rope_freq_base": 500000.0, + "norm_eps": 1e-05, + "hidden_dim": 6400, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama_guard_3_1b, + description="Llama Guard v3 1b system safety model", + huggingface_repo="meta-llama/Llama-Guard-3-1B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 2048, + "n_layers": 16, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.5, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + 
core_model_id=CoreModelId.llama_guard_3_8b, + description="Llama Guard v3 8b system safety model", + huggingface_repo="meta-llama/Llama-Guard-3-8B", + arch_args={ + "dim": 4096, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "n_heads": 32, + "n_kv_heads": 8, + "n_layers": 32, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + "vocab_size": LLAMA3_VOCAB_SIZE, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama_guard_3_8b, + variant="int8", + description="Llama Guard v3 8b system safety model", + huggingface_repo="meta-llama/Llama-Guard-3-8B-INT8", + quantization_format=CheckpointQuantizationFormat.int8, + arch_args={ + "dim": 4096, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "n_heads": 32, + "n_kv_heads": 8, + "n_layers": 32, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + "vocab_size": LLAMA3_VOCAB_SIZE, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama_guard_2_8b, + description="Llama Guard v2 8b system safety model", + huggingface_repo="meta-llama/Llama-Guard-2-8B", + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + ] + + +@dataclass +class LlamaDownloadInfo: + folder: str + files: List[str] + pth_size: int + + +def llama_meta_net_info(model: Model) -> LlamaDownloadInfo: + """Information needed to download model from llamameta.net""" + + pth_count = model.pth_file_count + if model.core_model_id == CoreModelId.llama3_1_405b: + if pth_count == 16: + folder = "Llama-3.1-405B-MP16" + elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: + folder = "Llama-3.1-405B" + else: + folder = "Llama-3.1-405B-MP8" + elif model.core_model_id == CoreModelId.llama3_1_405b_instruct: + if pth_count == 16: + folder = "Llama-3.1-405B-Instruct-MP16" + elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: + folder = "Llama-3.1-405B-Instruct" + else: + folder = "Llama-3.1-405B-Instruct-MP8" + elif model.core_model_id == CoreModelId.llama_guard_3_8b: + if model.quantization_format == CheckpointQuantizationFormat.int8: + folder = "Llama-Guard-3-8B-INT8-HF" + else: + folder = "Llama-Guard-3-8B" + elif model.core_model_id == CoreModelId.llama_guard_2_8b: + folder = "llama-guard-2" + else: + folder = model.huggingface_repo.split("/")[-1] + if "Llama-2" in folder: + folder = folder.lower() + + files = ["checklist.chk"] + if ( + model.core_model_id == CoreModelId.llama_guard_3_8b + and model.quantization_format == CheckpointQuantizationFormat.int8 + ): + files.extend( + [ + "generation_config.json", + "model-00001-of-00002.safetensors", + "model-00002-of-00002.safetensors", + "special_tokens_map.json", + "tokenizer.json", + "tokenizer_config.json", + "model.safetensors.index.json", + ] + ) + elif ( + model.core_model_id == CoreModelId.llama_guard_3_1b + and model.quantization_format == CheckpointQuantizationFormat.int4 + ): + files.extend( + [ + "llama_guard_3_1b_pruned_xnnpack.pte", + "example-prompt.txt", + "params.json", + "tokenizer.model", + ] + ) + else: + files.extend( + [ + "tokenizer.model", + "params.json", + ] + ) + if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: + files.extend([f"fp8_scales_{i}.pt" for i in range(pth_count)]) + files.extend([f"consolidated.{i:02d}.pth" for i in range(pth_count)]) + + return 
LlamaDownloadInfo( + folder=folder, + files=files, + pth_size=llama_meta_pth_size(model), + ) + + +# Sadness because Cloudfront rejects our HEAD requests to find Content-Length +def llama_meta_pth_size(model: Model) -> int: + if model.core_model_id not in ( + CoreModelId.llama3_1_405b, + CoreModelId.llama3_1_405b_instruct, + ): + return 0 + + if model.pth_file_count == 16: + return 51268302389 + elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: + return 60903742309 + else: + return 101470976045 diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py new file mode 100644 index 000000000..138f12c8d --- /dev/null +++ b/llama_stack/schema_utils.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union + +# Borrowed from https://github.com/hunyadi/strong_typing/blob/master/strong_typing/core.py + + +class JsonObject: + "Placeholder type for an unrestricted JSON object." + + +class JsonArray: + "Placeholder type for an unrestricted JSON array." + + +# a JSON type with possible `null` values +JsonType = Union[ + None, + bool, + int, + float, + str, + Dict[str, "JsonType"], + List["JsonType"], +] + +# a JSON type that cannot contain `null` values +StrictJsonType = Union[ + bool, + int, + float, + str, + Dict[str, "StrictJsonType"], + List["StrictJsonType"], +] + +# a meta-type that captures the object type in a JSON schema +Schema = Dict[str, JsonType] + + +T = TypeVar("T") + + +def register_schema( + data_type: T, + schema: Optional[Schema] = None, + name: Optional[str] = None, + examples: Optional[List[JsonType]] = None, +) -> T: + """ + Associates a type with a JSON schema definition. + + :param data_type: The type to associate with a JSON schema. + :param schema: The schema to associate the type with. Derived automatically if omitted. + :param name: The name used for looking up the type. Determined automatically if omitted. + :returns: The input type. 
+ """ + return data_type + + +def json_schema_type( + cls: Optional[Type[T]] = None, + *, + schema: Optional[Schema] = None, + examples: Optional[List[JsonType]] = None, +) -> Union[Type[T], Callable[[Type[T]], Type[T]]]: + """Decorator to add user-defined schema definition to a class.""" + + def wrap(cls: Type[T]) -> Type[T]: + return register_schema(cls, schema, examples=examples) + + # see if decorator is used as @json_schema_type or @json_schema_type() + if cls is None: + # called with parentheses + return wrap + else: + # called as @json_schema_type without parentheses + return wrap(cls) + + +register_schema(JsonObject, name="JsonObject") +register_schema(JsonArray, name="JsonArray") +register_schema(JsonType, name="JsonType") +register_schema(StrictJsonType, name="StrictJsonType") + + +@dataclass +class WebMethod: + route: Optional[str] = None + public: bool = False + request_examples: Optional[List[Any]] = None + response_examples: Optional[List[Any]] = None + method: Optional[str] = None + + +def webmethod( + route: Optional[str] = None, + method: Optional[str] = None, + public: Optional[bool] = False, + request_examples: Optional[List[Any]] = None, + response_examples: Optional[List[Any]] = None, +) -> Callable[[T], T]: + """ + Decorator that supplies additional metadata to an endpoint operation function. + + :param route: The URL path pattern associated with this operation which path parameters are substituted into. + :param public: True if the operation can be invoked without prior authentication. + :param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON. + :param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON. + """ + + def wrap(cls: T) -> T: + cls.__webmethod__ = WebMethod( + route=route, + method=method, + public=public or False, + request_examples=request_examples, + response_examples=response_examples, + ) + return cls + + return wrap diff --git a/llama_stack/scripts/generate_prompt_format.py b/llama_stack/scripts/generate_prompt_format.py new file mode 100644 index 000000000..c529b0a5f --- /dev/null +++ b/llama_stack/scripts/generate_prompt_format.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
+ +import importlib +from pathlib import Path +from typing import Optional + +import fire + +# from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_models.llama3.reference_impl.generation import Llama + +THIS_DIR = Path(__file__).parent.resolve() + + +def run_main( + ckpt_dir: str, + module_name: str, + output_path: str, + model_parallel_size: Optional[int] = None, +): + module = importlib.import_module(module_name) + assert hasattr(module, "usecases"), f"Module {module_name} missing usecases function" + tokenizer_path = str(THIS_DIR.parent / "llama3/api/tokenizer.model") + generator = Llama.build( + ckpt_dir=ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=512, + max_batch_size=1, + model_parallel_size=model_parallel_size, + ) + + use_cases = module.usecases() + text = "" + for u in use_cases: + if isinstance(u, str): + use_case_text = f"\n{u}\n" + else: + use_case_text = u.to_text(generator) + + text += use_case_text + print(use_case_text) + + text += "Thank You!\n" + + with open(output_path, "w") as f: + f.write(text) + + +def main(): + fire.Fire(run_main) + + +if __name__ == "__main__": + main()
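Finally, a hedged sketch of how this generator script might be invoked once the diff lands. The checkpoint path and output filename are made up, the module path is inferred from the file layout in this diff, and depending on the reference implementation the script may need to run under a distributed launcher for `Llama.build` to initialize.

```python
# Hypothetical invocation of the fire-generated CLI defined above:
#
#   python -m llama_stack.scripts.generate_prompt_format \
#       ~/.llama/checkpoints/Llama3.2-3B-Instruct \
#       llama_stack.models.llama.llama3_2.prompts \
#       llama3_2_text_prompt_format.md
#
# The same call made programmatically:
from llama_stack.scripts.generate_prompt_format import run_main

run_main(
    ckpt_dir="~/.llama/checkpoints/Llama3.2-3B-Instruct",  # made-up local checkpoint path
    module_name="llama_stack.models.llama.llama3_2.prompts",  # module exposing usecases()
    output_path="llama3_2_text_prompt_format.md",  # made-up output file
)
```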