diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d3e382e91..c00ea3040 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,14 +19,14 @@ repos:
 #  - id: no-commit-to-branch
 #    args: ['--branch=main']
 
-# - repo: https://github.com/Lucas-C/pre-commit-hooks
-#   rev: v1.5.4
-#   hooks:
-#     - id: insert-license
-#       files: \.py$|\.sh$
-#       args:
-#         - --license-filepath
-#         - docs/license_header.txt
+- repo: https://github.com/Lucas-C/pre-commit-hooks
+  rev: v1.5.4
+  hooks:
+    - id: insert-license
+      files: \.py$|\.sh$
+      args:
+        - --license-filepath
+        - docs/license_header.txt
 
 - repo: https://github.com/pycqa/flake8
   rev: 34cbf8ef3950f43d09b85e2e45c15ae5717dc37b
diff --git a/README.md b/README.md
index edf2dd054..cbba37408 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
 This repo contains the API specifications for various parts of the Llama Stack.
-The Stack consists of toolchain-apis and agentic-apis. 
+The Stack consists of toolchain-apis and agentic-apis.
 
-The tool chain apis that are covered -- 
+The tool chain apis that are covered --
 - inference / batch inference
 - post training
 - reward model scoring
@@ -10,7 +10,7 @@ The tool chain apis that are covered --
 
 ## Running FP8
 
-You need `fbgemm-gpu` package which requires torch >= 2.4.0 (currently only in nightly, but releasing shortly...). 
+You need `fbgemm-gpu` package which requires torch >= 2.4.0 (currently only in nightly, but releasing shortly...).
 
 ```bash
 ENV=fp8_env
@@ -21,19 +21,19 @@
 pip3 install -r fp8_requirements.txt
 ```
 
-### Generate OpenAPI specs 
+### Generate OpenAPI specs
 
-Set up virtual environment 
+Set up virtual environment
 
 ```
-python3 -m venv ~/.venv/toolchain/ 
+python3 -m venv ~/.venv/toolchain/
 source ~/.venv/toolchain/bin/activate
-with-proxy pip3 install -r requirements.txt 
+with-proxy pip3 install -r requirements.txt
 ```
 
-Run the generate.sh script 
+Run the generate.sh script
 
 ```
 cd source && sh generate.sh
diff --git a/docs/license_header.txt b/docs/license_header.txt
new file mode 100644
index 000000000..cfe551f51
--- /dev/null
+++ b/docs/license_header.txt
@@ -0,0 +1,5 @@
+Copyright (c) Meta Platforms, Inc. and affiliates.
+All rights reserved.
+
+This source code is licensed under the terms described in the LICENSE file in
+the root directory of this source tree.
diff --git a/llama_toolchain/__init__.py b/llama_toolchain/__init__.py
index e69de29bb..f51f20815 100644
--- a/llama_toolchain/__init__.py
+++ b/llama_toolchain/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/llama_toolchain/cli/__init__.py b/llama_toolchain/cli/__init__.py
index e69de29bb..f51f20815 100644
--- a/llama_toolchain/cli/__init__.py
+++ b/llama_toolchain/cli/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
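The README's FP8 note above (torch >= 2.4.0 plus `fbgemm-gpu`) can be sanity-checked before launching anything. A minimal pre-flight sketch, assuming only that `torch` and the third-party `packaging` helper are installed; `fp8_ready` is a hypothetical name, not part of the toolchain:

```python
import importlib.util

import torch
from packaging.version import Version  # assumption: packaging is available


def fp8_ready() -> bool:
    # fbgemm-gpu's FP8 kernels need torch >= 2.4.0 per the README
    if Version(torch.__version__.split("+")[0]) < Version("2.4.0"):
        return False
    # cheap presence check; the toolchain itself imports
    # fbgemm_gpu.experimental.gen_ai (see fp8_impls.py below)
    return importlib.util.find_spec("fbgemm_gpu") is not None


if __name__ == "__main__":
    print(f"FP8 prerequisites satisfied: {fp8_ready()}")
```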
diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py
index ba5262c59..0ac64b592 100644
--- a/llama_toolchain/cli/download.py
+++ b/llama_toolchain/cli/download.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 import os
 import textwrap
diff --git a/llama_toolchain/cli/inference/__init__.py b/llama_toolchain/cli/inference/__init__.py
index e69de29bb..f51f20815 100644
--- a/llama_toolchain/cli/inference/__init__.py
+++ b/llama_toolchain/cli/inference/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/llama_toolchain/cli/inference/configure.py b/llama_toolchain/cli/inference/configure.py
index 5d2596315..bf69c0c2f 100644
--- a/llama_toolchain/cli/inference/configure.py
+++ b/llama_toolchain/cli/inference/configure.py
@@ -1,10 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 import os
-import pkg_resources
 import textwrap
 from pathlib import Path
 
+import pkg_resources
+
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.utils import DEFAULT_DUMP_DIR
@@ -36,21 +49,22 @@ class InferenceConfigure(Subcommand):
         pass
 
     def read_user_inputs(self):
-        checkpoint_dir = input("Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): ")
-        model_parallel_size = input("Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): ")
-        assert model_parallel_size.isdigit() and int(model_parallel_size) in {1, 8}, "model parallel size must be 1 or 8"
+        checkpoint_dir = input(
+            "Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
+        )
+        model_parallel_size = input(
+            "Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
+        )
+        assert model_parallel_size.isdigit() and int(model_parallel_size) in {
+            1,
+            8,
+        }, "model parallel size must be 1 or 8"
 
         return checkpoint_dir, model_parallel_size
 
-    def write_output_yaml(
-        self,
-        checkpoint_dir,
-        model_parallel_size,
-        yaml_output_path
-    ):
+    def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
         default_conf_path = pkg_resources.resource_filename(
-            'llama_toolchain',
-            'data/default_inference_config.yaml'
+            "llama_toolchain", "data/default_inference_config.yaml"
         )
         with open(default_conf_path, "r") as f:
             yaml_content = f.read()
@@ -60,7 +74,7 @@ class InferenceConfigure(Subcommand):
             model_parallel_size=model_parallel_size,
         )
 
-        with open(yaml_output_path, 'w') as yaml_file:
+        with open(yaml_output_path, "w") as yaml_file:
             yaml_file.write(yaml_content.strip())
 
         print(f"YAML configuration has been written to {yaml_output_path}")
@@ -69,8 +83,9 @@ class InferenceConfigure(Subcommand):
         checkpoint_dir, model_parallel_size = self.read_user_inputs()
         checkpoint_dir = os.path.expanduser(checkpoint_dir)
 
-        assert Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir(), \
-            f"{checkpoint_dir} does not exist or it not a directory"
+        assert (
+            Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
+        ), f"{checkpoint_dir} does not exist or it not a directory"
 
         os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
         yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml"
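For context on what `write_output_yaml` above is doing: the default config template ships inside the installed package, so it is resolved with `pkg_resources` and filled in with plain `str.format`. A standalone sketch of the same flow, assuming the package and data file are installed under the names the diff uses; the checkpoint path is hypothetical:

```python
import pkg_resources

# resolve the template that ships with the llama_toolchain package
template_path = pkg_resources.resource_filename(
    "llama_toolchain", "data/default_inference_config.yaml"
)

with open(template_path, "r") as f:
    template = f.read()

# the template contains {checkpoint_dir} / {model_parallel_size} placeholders
rendered = template.format(
    checkpoint_dir="/home/user/.llama/checkpoints/Meta-Llama-3-8B",  # hypothetical
    model_parallel_size="1",
)
print(rendered.strip())
```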
diff --git a/llama_toolchain/cli/inference/inference.py b/llama_toolchain/cli/inference/inference.py
index b83cd5da8..ef5e765a2 100644
--- a/llama_toolchain/cli/inference/inference.py
+++ b/llama_toolchain/cli/inference/inference.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 import textwrap
diff --git a/llama_toolchain/cli/inference/start.py b/llama_toolchain/cli/inference/start.py
index f14d4b7be..bd584aa57 100644
--- a/llama_toolchain/cli/inference/start.py
+++ b/llama_toolchain/cli/inference/start.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 import textwrap
@@ -40,10 +52,7 @@ class InferenceStart(Subcommand):
             default=False,
         )
         self.parser.add_argument(
-            "--config",
-            type=str,
-            help="Path to config file",
-            default="inference"
+            "--config", type=str, help="Path to config file", default="inference"
         )
 
     def _run_inference_start_cmd(self, args: argparse.Namespace) -> None:
diff --git a/llama_toolchain/cli/llama.py b/llama_toolchain/cli/llama.py
index 52a98cfe1..51b7492b8 100644
--- a/llama_toolchain/cli/llama.py
+++ b/llama_toolchain/cli/llama.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 
 from llama_toolchain.cli.download import Download
@@ -11,7 +23,7 @@ class LlamaCLIParser:
     def __init__(self):
         self.parser = argparse.ArgumentParser(
             prog="llama",
-            description="Welcome to the LLama toolchain cli",
+            description="Welcome to the LLama cli",
             add_help=True,
         )
@@ -27,9 +39,8 @@ class LlamaCLIParser:
 
         # Import sub-commands from agentic_system if they exist
         try:
-            from llama_agentic_system.cli.subcommand_modules import (
-                SUBCOMMAND_MODULES,
-            )
+            from llama_agentic_system.cli.subcommand_modules import SUBCOMMAND_MODULES
+
             for module in SUBCOMMAND_MODULES:
                 module.create(subparsers)
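`LlamaCLIParser` above hands each subcommand its own sub-parser, and every subcommand derives from the `Subcommand` base class that appears in `subcommand.py` further down. A stripped-down sketch of the same registration pattern; `Ping` is a hypothetical subcommand used only for illustration:

```python
import argparse


class Subcommand:
    """Stand-in for llama_toolchain.cli.subcommand.Subcommand."""


class Ping(Subcommand):  # hypothetical subcommand, for illustration only
    def __init__(self, subparsers):
        parser = subparsers.add_parser("ping", help="check that the CLI responds")
        parser.set_defaults(func=self._run)

    @staticmethod
    def _run(args: argparse.Namespace) -> None:
        print("pong")


def main() -> None:
    parser = argparse.ArgumentParser(prog="llama", add_help=True)
    parser.set_defaults(func=lambda args: parser.print_help())
    subparsers = parser.add_subparsers(title="subcommands")
    Ping(subparsers)  # each subcommand registers itself, as LlamaCLIParser does
    args = parser.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()
```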
diff --git a/llama_toolchain/cli/model/__init__.py b/llama_toolchain/cli/model/__init__.py
index e69de29bb..f51f20815 100644
--- a/llama_toolchain/cli/model/__init__.py
+++ b/llama_toolchain/cli/model/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/llama_toolchain/cli/model/model.py b/llama_toolchain/cli/model/model.py
index 6f9e2a2b3..d6efa8127 100644
--- a/llama_toolchain/cli/model/model.py
+++ b/llama_toolchain/cli/model/model.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 import textwrap
diff --git a/llama_toolchain/cli/model/template.py b/llama_toolchain/cli/model/template.py
index 90f109184..78868d750 100644
--- a/llama_toolchain/cli/model/template.py
+++ b/llama_toolchain/cli/model/template.py
@@ -1,8 +1,24 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 import textwrap
 
+from llama_models.llama3_1.api.interface import (
+    list_jinja_templates,
+    render_jinja_template,
+)
+
 from llama_toolchain.cli.subcommand import Subcommand
-from llama_models.llama3_1.api.interface import render_jinja_template, list_jinja_templates
 
 
 class ModelTemplate(Subcommand):
diff --git a/llama_toolchain/cli/subcommand.py b/llama_toolchain/cli/subcommand.py
index 10bb6667d..cdce0a6c7 100644
--- a/llama_toolchain/cli/subcommand.py
+++ b/llama_toolchain/cli/subcommand.py
@@ -1,3 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
+
 class Subcommand:
     """All llama cli subcommands must inherit from this class"""
diff --git a/llama_toolchain/common/__init__.py b/llama_toolchain/common/__init__.py
index e69de29bb..f51f20815 100644
--- a/llama_toolchain/common/__init__.py
+++ b/llama_toolchain/common/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/llama_toolchain/common/deployment_types.py b/llama_toolchain/common/deployment_types.py
index 5abd7d991..1dc15fb5a 100644
--- a/llama_toolchain/common/deployment_types.py
+++ b/llama_toolchain/common/deployment_types.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from enum import Enum
 from typing import Dict, Optional
diff --git a/llama_toolchain/common/training_types.py b/llama_toolchain/common/training_types.py
index c500bc91c..6633ea4e9 100644
--- a/llama_toolchain/common/training_types.py
+++ b/llama_toolchain/common/training_types.py
@@ -1,5 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from llama_models.llama3_1.api.datatypes import URL
 
-from pydantic import BaseModel 
+from pydantic import BaseModel
 
 
 class Checkpoint(BaseModel):
diff --git a/llama_toolchain/data/default_inference_config.yaml b/llama_toolchain/data/default_inference_config.yaml
index 253e0e143..29e40c9b8 100644
--- a/llama_toolchain/data/default_inference_config.yaml
+++ b/llama_toolchain/data/default_inference_config.yaml
@@ -1,9 +1,14 @@
 inference_config:
-  impl_type: "inline"
-  inline_config:
-    checkpoint_type: "pytorch"
-    checkpoint_dir: {checkpoint_dir}/
-    tokenizer_path: {checkpoint_dir}/tokenizer.model
-    model_parallel_size: {model_parallel_size}
+  impl_config:
+    impl_type: "inline"
+    checkpoint_config:
+      checkpoint:
+        checkpoint_type: "pytorch"
+        checkpoint_dir: {checkpoint_dir}/
+        tokenizer_path: {checkpoint_dir}/tokenizer.model
+        model_parallel_size: {model_parallel_size}
+        quantization_format: bf16
+    quantization: null
+    torch_seed: null
   max_seq_len: 2048
   max_batch_size: 1
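The restructured template above nests everything under `impl_config`, matching the `InferenceConfig` shape used in `config.py` later in this diff. A quick round-trip sketch, assuming `pyyaml` is installed; the indentation and paths here are illustrative assumptions, since only the key structure is certain from the patch:

```python
import yaml

rendered = """
inference_config:
  impl_config:
    impl_type: "inline"
    checkpoint_config:
      checkpoint:
        checkpoint_type: "pytorch"
        checkpoint_dir: /tmp/Meta-Llama-3-8B/
        tokenizer_path: /tmp/Meta-Llama-3-8B/tokenizer.model
        model_parallel_size: 1
        quantization_format: bf16
    quantization: null
    torch_seed: null
  max_seq_len: 2048
  max_batch_size: 1
"""

cfg = yaml.safe_load(rendered)
impl = cfg["inference_config"]["impl_config"]
assert impl["impl_type"] == "inline"
assert impl["checkpoint_config"]["checkpoint"]["model_parallel_size"] == 1
```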
diff --git a/llama_toolchain/dataset/api/__init__.py b/llama_toolchain/dataset/api/__init__.py
index 38413ff60..2f98dc6f7 100644
--- a/llama_toolchain/dataset/api/__init__.py
+++ b/llama_toolchain/dataset/api/__init__.py
@@ -1,2 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .datatypes import *  # noqa: F401 F403
 from .endpoints import *  # noqa: F401 F403
diff --git a/llama_toolchain/dataset/api/datatypes.py b/llama_toolchain/dataset/api/datatypes.py
index 260e68acb..dd22ed5f6 100644
--- a/llama_toolchain/dataset/api/datatypes.py
+++ b/llama_toolchain/dataset/api/datatypes.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from enum import Enum
 from typing import Any, Dict, Optional
diff --git a/llama_toolchain/dataset/api/endpoints.py b/llama_toolchain/dataset/api/endpoints.py
index 023f91259..42fe66b05 100644
--- a/llama_toolchain/dataset/api/endpoints.py
+++ b/llama_toolchain/dataset/api/endpoints.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import Protocol
 
 from pydantic import BaseModel
diff --git a/llama_toolchain/evaluations/api/__init__.py b/llama_toolchain/evaluations/api/__init__.py
index 38413ff60..2f98dc6f7 100644
--- a/llama_toolchain/evaluations/api/__init__.py
+++ b/llama_toolchain/evaluations/api/__init__.py
@@ -1,2 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .datatypes import *  # noqa: F401 F403
 from .endpoints import *  # noqa: F401 F403
diff --git a/llama_toolchain/evaluations/api/datatypes.py b/llama_toolchain/evaluations/api/datatypes.py
index 692664846..ee296174b 100644
--- a/llama_toolchain/evaluations/api/datatypes.py
+++ b/llama_toolchain/evaluations/api/datatypes.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from enum import Enum
 
 from pydantic import BaseModel
diff --git a/llama_toolchain/evaluations/api/endpoints.py b/llama_toolchain/evaluations/api/endpoints.py
index b9b592313..2ab2c5ee6 100644
--- a/llama_toolchain/evaluations/api/endpoints.py
+++ b/llama_toolchain/evaluations/api/endpoints.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import List, Protocol
 
 from pydantic import BaseModel
diff --git a/llama_toolchain/inference/__init__.py b/llama_toolchain/inference/__init__.py
index e69de29bb..f51f20815 100644
--- a/llama_toolchain/inference/__init__.py
+++ b/llama_toolchain/inference/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/llama_toolchain/inference/api/__init__.py b/llama_toolchain/inference/api/__init__.py
index 38413ff60..2f98dc6f7 100644
--- a/llama_toolchain/inference/api/__init__.py
+++ b/llama_toolchain/inference/api/__init__.py
@@ -1,2 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .datatypes import *  # noqa: F401 F403
 from .endpoints import *  # noqa: F401 F403
diff --git a/llama_toolchain/inference/api/config.py b/llama_toolchain/inference/api/config.py
index 8c2f160f5..fa6cd95f4 100644
--- a/llama_toolchain/inference/api/config.py
+++ b/llama_toolchain/inference/api/config.py
@@ -1,14 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from dataclasses import dataclass
 from enum import Enum
 from typing import Literal, Optional, Union
 
+from hydra_zen import builds
 from hydra.core.config_store import ConfigStore
+from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
 from .datatypes import QuantizationConfig
-from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
 
 
 class ImplType(Enum):
@@ -72,78 +85,7 @@ class InferenceConfig(BaseModel):
     ]
 
 
-# Hydra does not like unions of containers and
-# Pydantic does not like Literals
-# Adding a simple dataclass with custom coversion
-# to config classes
-
-
-@dataclass
-class InlineImplHydraConfig:
-    checkpoint_type: str  # "pytorch" / "HF"
-    # pytorch checkpoint required args
-    checkpoint_dir: str
-    tokenizer_path: str
-    model_parallel_size: int
-    max_seq_len: int
-    max_batch_size: int = 1
-    quantization: Optional[QuantizationConfig] = None
-    # TODO: huggingface checkpoint required args
-
-    def convert_to_inline_impl_config(self):
-        if self.checkpoint_type == "pytorch":
-            return InlineImplConfig(
-                checkpoint_config=ModelCheckpointConfig(
-                    checkpoint=PytorchCheckpoint(
-                        checkpoint_type=CheckpointType.pytorch.value,
-                        checkpoint_dir=self.checkpoint_dir,
-                        tokenizer_path=self.tokenizer_path,
-                        model_parallel_size=self.model_parallel_size,
-                    )
-                ),
-                quantization=self.quantization,
-                max_seq_len=self.max_seq_len,
-                max_batch_size=self.max_batch_size,
-            )
-        else:
-            raise NotImplementedError("HF Checkpoint not supported yet")
-
-
-@dataclass
-class RemoteImplHydraConfig:
-    url: str
-
-    def convert_to_remote_impl_config(self):
-        return RemoteImplConfig(
-            url=self.url,
-        )
-
-
-@dataclass
-class InferenceHydraConfig:
-    impl_type: str
-    inline_config: Optional[InlineImplHydraConfig] = None
-    remote_config: Optional[RemoteImplHydraConfig] = None
-
-    def __post_init__(self):
-        assert self.impl_type in ["inline", "remote"]
-        if self.impl_type == "inline":
-            assert self.inline_config is not None
-        if self.impl_type == "remote":
-            assert self.remote_config is not None
-
-    def convert_to_inference_config(self):
-        if self.impl_type == "inline":
-            inline_config = InlineImplHydraConfig(**self.inline_config)
-            return InferenceConfig(
-                impl_config=inline_config.convert_to_inline_impl_config()
-            )
-        elif self.impl_type == "remote":
-            remote_config = RemoteImplHydraConfig(**self.remote_config)
-            return InferenceConfig(
-                impl_config=remote_config.convert_to_remote_impl_config()
-            )
-
+InferenceHydraConfig = builds(InferenceConfig)
 
 cs = ConfigStore.instance()
 cs.store(name="inference_config", node=InferenceHydraConfig)
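`builds(InferenceConfig)` above replaces all the hand-written Hydra dataclasses: hydra-zen generates a structured config for the target type, and `instantiate` (used in `server.py` later in this diff) turns a config node back into the real object. A minimal sketch of that round trip with a toy target, assuming `hydra-zen` is installed; `ServerConfig` is illustrative, not a toolchain class:

```python
from dataclasses import dataclass

from hydra_zen import builds, instantiate


@dataclass
class ServerConfig:  # hypothetical target type
    host: str = "localhost"
    port: int = 5000


# auto-generate a structured config for the target, like builds(InferenceConfig)
ServerConfigConf = builds(ServerConfig, populate_full_signature=True)

# instantiate() constructs the target object from the config node
cfg = ServerConfigConf(port=8080)
server_config = instantiate(cfg)
assert server_config == ServerConfig(host="localhost", port=8080)
```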
diff --git a/llama_toolchain/inference/api/datatypes.py b/llama_toolchain/inference/api/datatypes.py
index 3141a108e..c0e7f96f8 100644
--- a/llama_toolchain/inference/api/datatypes.py
+++ b/llama_toolchain/inference/api/datatypes.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from enum import Enum
 from typing import List, Literal, Optional, Union
@@ -26,9 +38,7 @@ class Fp8QuantizationConfig(BaseModel):
 
 @json_schema_type
 class Bf16QuantizationConfig(BaseModel):
-    type: Literal[QuantizationType.bf16.value] = (
-        QuantizationType.bf16.value
-    )
+    type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
 
 
 QuantizationConfig = Annotated[
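The truncated `QuantizationConfig = Annotated[` context above is the tail of a tagged-union pattern: each config model carries a `Literal` `type` field, and the union is discriminated on it. A self-contained sketch of that pattern, assuming pydantic v1-style `parse_obj_as` and `Field(discriminator=...)` support (v1.9+); the class names here are illustrative stand-ins:

```python
from typing import Literal, Union

from pydantic import BaseModel, Field, parse_obj_as
from typing_extensions import Annotated


class Fp8Config(BaseModel):  # stand-in for Fp8QuantizationConfig
    type: Literal["fp8"] = "fp8"


class Bf16Config(BaseModel):  # stand-in for Bf16QuantizationConfig
    type: Literal["bf16"] = "bf16"


QuantConfig = Annotated[Union[Fp8Config, Bf16Config], Field(discriminator="type")]

# the "type" value selects the right model during validation
cfg = parse_obj_as(QuantConfig, {"type": "bf16"})
assert isinstance(cfg, Bf16Config)
```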
diff --git a/llama_toolchain/inference/api/endpoints.py b/llama_toolchain/inference/api/endpoints.py
index c148b0bff..41c0c1589 100644
--- a/llama_toolchain/inference/api/endpoints.py
+++ b/llama_toolchain/inference/api/endpoints.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .datatypes import *  # noqa: F403
 
 from typing import Optional, Protocol
diff --git a/llama_toolchain/inference/api_instance.py b/llama_toolchain/inference/api_instance.py
index d39d642be..1650f5097 100644
--- a/llama_toolchain/inference/api_instance.py
+++ b/llama_toolchain/inference/api_instance.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .api.config import ImplType, InferenceConfig
diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py
index c798ed6fe..7566a9657 100644
--- a/llama_toolchain/inference/client.py
+++ b/llama_toolchain/inference/client.py
@@ -1,7 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import asyncio
 import json
+from termcolor import cprint
 from typing import AsyncGenerator
+from urllib.request import getproxies
+
 import fire
 import httpx
@@ -9,12 +24,16 @@ from .api import (
     ChatCompletionRequest,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
-    InstructModel,
     Inference,
+    InstructModel,
     UserMessage,
 )
 from .event_logger import EventLogger
 
+print(getproxies())
+# import sys
+# sys.exit(0)
+
 
 class InferenceClient(Inference):
     def __init__(self, base_url: str):
@@ -53,6 +72,7 @@ async def run_main(host: str, port: int):
     client = InferenceClient(f"http://{host}:{port}")
 
     message = UserMessage(content="hello world, help me out here")
+    cprint(f"User>{message.content}", "green")
    req = ChatCompletionRequest(
         model=InstructModel.llama3_70b_chat,
         messages=[message],
diff --git a/llama_toolchain/inference/event_logger.py b/llama_toolchain/inference/event_logger.py
index 71d472ee1..1f09f87bf 100644
--- a/llama_toolchain/inference/event_logger.py
+++ b/llama_toolchain/inference/event_logger.py
@@ -1,8 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
 
-from termcolor import cprint
-from llama_toolchain.inference.api import (
-    ChatCompletionResponseEventType,
-)
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
+from termcolor import cprint
+
+from llama_toolchain.inference.api import ChatCompletionResponseEventType
 
 
 class LogEvent:
@@ -30,4 +40,3 @@ class EventLogger:
                 yield LogEvent(event.delta, color="yellow", end="")
             elif event.event_type == ChatCompletionResponseEventType.complete:
                 yield LogEvent("")
-
diff --git a/llama_toolchain/inference/generation.py b/llama_toolchain/inference/generation.py
index 968c0e4d7..f689b873a 100644
--- a/llama_toolchain/inference/generation.py
+++ b/llama_toolchain/inference/generation.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
@@ -48,7 +60,10 @@ class Llama:
         if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
             raise NotImplementedError("HuggingFace checkpoints not supported yet")
 
-        if config.quantization and config.quantization.type == QuantizationType.fp8.value:
+        if (
+            config.quantization
+            and config.quantization.type == QuantizationType.fp8.value
+        ):
             from .quantization.loader import is_fbgemm_available
 
             if not is_fbgemm_available():
@@ -99,17 +114,31 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
 
-        # load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
-        torch.set_default_tensor_type(torch.BFloat16Tensor)
+        fp8 = (
+            config.quantization
+            and config.quantization.type == QuantizationType.fp8.value
+        )
+
+        if fp8:
+            # load on CPU in bf16 so that fp8 conversion does not find an
+            # unexpected (fp32, e.g.) datatype
+            torch.set_default_tensor_type(torch.BFloat16Tensor)
 
         model = Transformer(model_args)
-        model.load_state_dict(state_dict, strict=False)
+
+        if fp8:
+            # load on CPU first since if we are doing fp8, we probably don't
+            # have enough memory on GPU for bf16
+            model.load_state_dict(state_dict, strict=False)
 
         if torch.cuda.is_bf16_supported():
             torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
         else:
             torch.set_default_tensor_type(torch.cuda.HalfTensor)
 
+        if not fp8:
+            model.load_state_dict(state_dict, strict=False)
+
         if config.quantization:
             from .quantization.loader import convert_to_quantized_model
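The reordering above matters because `torch.set_default_tensor_type` controls how freshly constructed modules allocate parameters: the FP8 path constructs and loads on CPU in bf16 first, while the plain path only loads after switching the default to a CUDA dtype. A compressed sketch of the same two-phase idea with a toy module (illustrative only, not the toolchain's loader; like the diff, it assumes a CUDA-capable machine):

```python
import torch
from torch import nn


def build_and_load(state_dict: dict, fp8: bool) -> nn.Module:
    if fp8:
        # construct + load on CPU in bf16 so a later FP8 conversion
        # never sees an unexpected dtype
        torch.set_default_tensor_type(torch.BFloat16Tensor)
    model = nn.Linear(16, 16, bias=False)  # stand-in for Transformer(model_args)
    if fp8:
        model.load_state_dict(state_dict, strict=False)
    # switch default allocation to the GPU dtype, mirroring the diff
    if torch.cuda.is_bf16_supported():
        torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
    else:
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
    if not fp8:
        model.load_state_dict(state_dict, strict=False)
    return model
```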
diff --git a/llama_toolchain/inference/inference.py b/llama_toolchain/inference/inference.py
index 48d15cea1..7f6c64062 100644
--- a/llama_toolchain/inference/inference.py
+++ b/llama_toolchain/inference/inference.py
@@ -1,16 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import AsyncGenerator
 
 from llama_models.llama3_1.api.datatypes import StopReason
 
-from .api.config import (
-    CheckpointQuantizationFormat,
-    CheckpointType,
-    InlineImplConfig,
-)
+from .api.config import InlineImplConfig
 from .api.datatypes import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
-    QuantizationConfig,
     ToolCallDelta,
     ToolCallParseStatus,
 )
diff --git a/llama_toolchain/inference/model_parallel.py b/llama_toolchain/inference/model_parallel.py
index 2d9737a9c..3fceaf787 100644
--- a/llama_toolchain/inference/model_parallel.py
+++ b/llama_toolchain/inference/model_parallel.py
@@ -1,3 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
+from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
 from typing import Generator, List, Optional
@@ -86,7 +99,7 @@ class LlamaModelParallelGenerator:
         logprobs: bool = False,
     ) -> Generator:
         req_obj = InferenceArgs(
-            messages=messages,
+            messages=deepcopy(messages),
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
diff --git a/llama_toolchain/inference/parallel_utils.py b/llama_toolchain/inference/parallel_utils.py
index daa061792..cc1726c60 100644
--- a/llama_toolchain/inference/parallel_utils.py
+++ b/llama_toolchain/inference/parallel_utils.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import multiprocessing
 import os
 import pickle
diff --git a/llama_toolchain/inference/quantization/fp8_impls.py b/llama_toolchain/inference/quantization/fp8_impls.py
index 9cac8bea0..fcdd0ec5d 100644
--- a/llama_toolchain/inference/quantization/fp8_impls.py
+++ b/llama_toolchain/inference/quantization/fp8_impls.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
@@ -8,7 +20,7 @@ try:
     import fbgemm_gpu.experimental.gen_ai  # noqa: F401
 
     print("Using efficient FP8 operators in FBGEMM.")
-except (ImportError, ModuleNotFoundError):
+except ImportError:
     print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
     raise
 
@@ -57,8 +69,8 @@ def ffn_swiglu(
             x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
         )
 
-    (B, T, D) = x.shape
-    (HD_L, D_) = w1.shape
+    (B, T, D) = x.shape  # noqa: N806
+    (HD_L, D_) = w1.shape  # noqa: N806
     assert D_ == D
 
     assert isinstance(w1, Tensor)
@@ -153,8 +165,8 @@ def ffn_swiglu_fp8_dynamic(
     num_tokens: Optional[Tensor] = None,
     is_memory_bounded: bool = False,
 ) -> Tensor:
-    (B, T, D) = x.shape
-    HD_L = w1.shape[0]
+    (B, T, D) = x.shape  # noqa: N806
+    HD_L = w1.shape[0]  # noqa: N806
     assert HD_L == w3.shape[0]
     x1 = fc_fp8_dynamic(
         x.view(B * T, D),
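The `except (ImportError, ModuleNotFoundError)` → `except ImportError` simplification in `fp8_impls.py` above (and in `loader.py` below) is safe because `ModuleNotFoundError` is a subclass of `ImportError`, so the single clause already catches both. A two-line check:

```python
# ModuleNotFoundError (Python 3.6+) subclasses ImportError,
# so catching ImportError alone also covers missing modules
assert issubclass(ModuleNotFoundError, ImportError)

try:
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401
except ImportError:
    print("fbgemm_gpu not installed; FP8 path unavailable")
```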
diff --git a/llama_toolchain/inference/quantization/loader.py b/llama_toolchain/inference/quantization/loader.py
index f1eccf79e..12af81abc 100644
--- a/llama_toolchain/inference/quantization/loader.py
+++ b/llama_toolchain/inference/quantization/loader.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
@@ -9,13 +21,13 @@
 import torch
 
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 from llama_models.llama3_1.api.model import Transformer, TransformerBlock
-from termcolor import cprint
-
 from llama_toolchain.inference.api.config import (
     CheckpointQuantizationFormat,
     InlineImplConfig,
 )
 from llama_toolchain.inference.api.datatypes import QuantizationType
+
+from termcolor import cprint
 from torch import Tensor
@@ -24,7 +36,7 @@ def is_fbgemm_available() -> bool:
         import fbgemm_gpu.experimental.gen_ai  # noqa: F401
 
         return True
-    except (ImportError, ModuleNotFoundError):
+    except ImportError:
         return False
diff --git a/llama_toolchain/inference/quantization/scripts/quantize_checkpoint.py b/llama_toolchain/inference/quantization/scripts/quantize_checkpoint.py
index 6fe66e13f..3a2f5ec25 100644
--- a/llama_toolchain/inference/quantization/scripts/quantize_checkpoint.py
+++ b/llama_toolchain/inference/quantization/scripts/quantize_checkpoint.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
diff --git a/llama_toolchain/inference/quantization/scripts/run_quantize_checkpoint.sh b/llama_toolchain/inference/quantization/scripts/run_quantize_checkpoint.sh
index a61180907..39491efba 100755
--- a/llama_toolchain/inference/quantization/scripts/run_quantize_checkpoint.sh
+++ b/llama_toolchain/inference/quantization/scripts/run_quantize_checkpoint.sh
@@ -1,5 +1,17 @@
 #!/bin/bash
 
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 set -euo pipefail
 set -x
diff --git a/llama_toolchain/inference/quantization/test_fp8.py b/llama_toolchain/inference/quantization/test_fp8.py
index 27b95f65c..21041a5b2 100644
--- a/llama_toolchain/inference/quantization/test_fp8.py
+++ b/llama_toolchain/inference/quantization/test_fp8.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
@@ -5,7 +17,7 @@
 import unittest
 
 import torch
-from fp8_impls import ffn_swiglu_fp8_dynamic, quantize_fp8, FfnQuantizeMode
+from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
 from hypothesis import given, settings, strategies as st
 from torch import Tensor
 
@@ -26,29 +38,25 @@ class FP8Tests(unittest.TestCase):
     )
     def test_fp8_ffn(
         self,
-        D: int,
+        D: int,  # noqa
         HD_L: int,
         B: int,
         T: int,
         UB: float,
     ) -> None:
         x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
-        w1 = (
-            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
-        )
-        w3 = (
-            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
-        )
+        w1 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
+        w3 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
         w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
 
-        x_q = quantize_fp8(x, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
-        w1_q = quantize_fp8(w1, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
-        w3_q = quantize_fp8(w3, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
-        w2_q = quantize_fp8(w2, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
+        x_q = quantize_fp8(x, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
+        w1_q = quantize_fp8(w1, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
+        w3_q = quantize_fp8(w3, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
+        w2_q = quantize_fp8(w2, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
 
         def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
-            (B, T, D) = x.shape
-            (HD_L, D_) = w1.shape
+            (B, T, D) = x.shape  # noqa: N806
+            (HD_L, D_) = w1.shape  # noqa: N806
             assert D_ == D
 
             x1 = x.view(B * T, D) @ w1.T
diff --git a/llama_toolchain/inference/server.py b/llama_toolchain/inference/server.py
index fcef995f1..7986aa401 100644
--- a/llama_toolchain/inference/server.py
+++ b/llama_toolchain/inference/server.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import asyncio
 import signal
@@ -8,6 +20,7 @@
 from dotenv import load_dotenv
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import StreamingResponse
+from hydra_zen import instantiate
 from omegaconf import OmegaConf
 
 from llama_toolchain.utils import get_default_config_dir, parse_config
@@ -43,11 +56,8 @@ async def startup():
     global InferenceApiInstance
 
     config = get_config()
-    hydra_config = InferenceHydraConfig(
-        **OmegaConf.to_container(config["inference_config"], resolve=True)
-    )
-    inference_config = hydra_config.convert_to_inference_config()
 
+    inference_config = instantiate(config["inference_config"])
     InferenceApiInstance = await get_inference_api_instance(
         inference_config,
     )
diff --git a/llama_toolchain/memory/__init__.py b/llama_toolchain/memory/__init__.py
index e69de29bb..f51f20815 100644
--- a/llama_toolchain/memory/__init__.py
+++ b/llama_toolchain/memory/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/llama_toolchain/memory/api/__init__.py b/llama_toolchain/memory/api/__init__.py
index 38413ff60..2f98dc6f7 100644
--- a/llama_toolchain/memory/api/__init__.py
+++ b/llama_toolchain/memory/api/__init__.py
@@ -1,2 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .datatypes import *  # noqa: F401 F403
 from .endpoints import *  # noqa: F401 F403
diff --git a/llama_toolchain/memory/api/datatypes.py b/llama_toolchain/memory/api/datatypes.py
index 0969203f6..7be8af0aa 100644
--- a/llama_toolchain/memory/api/datatypes.py
+++ b/llama_toolchain/memory/api/datatypes.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import Any, Dict
 
 from pydantic import BaseModel
diff --git a/llama_toolchain/memory/api/endpoints.py b/llama_toolchain/memory/api/endpoints.py
index 441c1d777..c1d678030 100644
--- a/llama_toolchain/memory/api/endpoints.py
+++ b/llama_toolchain/memory/api/endpoints.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import List, Protocol
 
 from pyopenapi import webmethod
diff --git a/llama_toolchain/models/api/endpoints.py b/llama_toolchain/models/api/endpoints.py
index 432dc391e..853274b04 100644
--- a/llama_toolchain/models/api/endpoints.py
+++ b/llama_toolchain/models/api/endpoints.py
@@ -1,9 +1,20 @@
-from typing import Protocol
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
 
-from pyopenapi import webmethod
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
 
-from pydantic import BaseModel
+from typing import Protocol
+
+from pydantic import BaseModel  # noqa: F401
+
+from pyopenapi import webmethod  # noqa: F401
 
 
-class Models(Protocol):
-    ...
+class Models(Protocol): ...
diff --git a/llama_toolchain/post_training/api/__init__.py b/llama_toolchain/post_training/api/__init__.py
index 38413ff60..2f98dc6f7 100644
--- a/llama_toolchain/post_training/api/__init__.py
+++ b/llama_toolchain/post_training/api/__init__.py
@@ -1,2 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .datatypes import *  # noqa: F401 F403
 from .endpoints import *  # noqa: F401 F403
diff --git a/llama_toolchain/post_training/api/datatypes.py b/llama_toolchain/post_training/api/datatypes.py
index 50b491c73..076997625 100644
--- a/llama_toolchain/post_training/api/datatypes.py
+++ b/llama_toolchain/post_training/api/datatypes.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from enum import Enum
 from typing import List
diff --git a/llama_toolchain/post_training/api/endpoints.py b/llama_toolchain/post_training/api/endpoints.py
index 3ec17e01f..fee937e12 100644
--- a/llama_toolchain/post_training/api/endpoints.py
+++ b/llama_toolchain/post_training/api/endpoints.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from datetime import datetime
 from typing import Any, Dict, List, Optional, Protocol
diff --git a/llama_toolchain/reward_scoring/api/__init__.py b/llama_toolchain/reward_scoring/api/__init__.py
index 38413ff60..2f98dc6f7 100644
--- a/llama_toolchain/reward_scoring/api/__init__.py
+++ b/llama_toolchain/reward_scoring/api/__init__.py
@@ -1,2 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .datatypes import *  # noqa: F401 F403
 from .endpoints import *  # noqa: F401 F403
diff --git a/llama_toolchain/reward_scoring/api/datatypes.py b/llama_toolchain/reward_scoring/api/datatypes.py
index f53d22861..acd5f4c97 100644
--- a/llama_toolchain/reward_scoring/api/datatypes.py
+++ b/llama_toolchain/reward_scoring/api/datatypes.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import List
 
 from pydantic import BaseModel
diff --git a/llama_toolchain/reward_scoring/api/endpoints.py b/llama_toolchain/reward_scoring/api/endpoints.py
index 72de43498..adbad32a8 100644
--- a/llama_toolchain/reward_scoring/api/endpoints.py
+++ b/llama_toolchain/reward_scoring/api/endpoints.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import List, Protocol, Union
 
 from .datatypes import *  # noqa: F403
diff --git a/llama_toolchain/safety/__init__.py b/llama_toolchain/safety/__init__.py
index e69de29bb..f51f20815 100644
--- a/llama_toolchain/safety/__init__.py
+++ b/llama_toolchain/safety/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/llama_toolchain/safety/api/__init__.py b/llama_toolchain/safety/api/__init__.py
index e69de29bb..f51f20815 100644
--- a/llama_toolchain/safety/api/__init__.py
+++ b/llama_toolchain/safety/api/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/llama_toolchain/safety/api/config.py b/llama_toolchain/safety/api/config.py
index 4858bc53c..49657f1e5 100644
--- a/llama_toolchain/safety/api/config.py
+++ b/llama_toolchain/safety/api/config.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import List, Optional
 
 from pydantic import BaseModel
diff --git a/llama_toolchain/safety/api/datatypes.py b/llama_toolchain/safety/api/datatypes.py
index 27c52337c..0e7f6c7e7 100644
--- a/llama_toolchain/safety/api/datatypes.py
+++ b/llama_toolchain/safety/api/datatypes.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from enum import Enum
 from typing import Dict, Optional, Union
diff --git a/llama_toolchain/safety/shields/__init__.py b/llama_toolchain/safety/shields/__init__.py
index 4dae3b690..875de28d0 100644
--- a/llama_toolchain/safety/shields/__init__.py
+++ b/llama_toolchain/safety/shields/__init__.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 # supress warnings and spew of logs from hugging face
 import transformers
diff --git a/llama_toolchain/safety/shields/base.py b/llama_toolchain/safety/shields/base.py
index f3fb49bc7..a458b6867 100644
--- a/llama_toolchain/safety/shields/base.py
+++ b/llama_toolchain/safety/shields/base.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from abc import ABC, abstractmethod
 from typing import List, Union
diff --git a/llama_toolchain/safety/shields/code_scanner.py b/llama_toolchain/safety/shields/code_scanner.py
index 8e220b017..a12010282 100644
--- a/llama_toolchain/safety/shields/code_scanner.py
+++ b/llama_toolchain/safety/shields/code_scanner.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from codeshield.cs import CodeShield
 from termcolor import cprint
diff --git a/llama_toolchain/safety/shields/contrib/__init__.py b/llama_toolchain/safety/shields/contrib/__init__.py
index e69de29bb..f51f20815 100644
--- a/llama_toolchain/safety/shields/contrib/__init__.py
+++ b/llama_toolchain/safety/shields/contrib/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/llama_toolchain/safety/shields/contrib/third_party_shield.py b/llama_toolchain/safety/shields/contrib/third_party_shield.py
index 29e65dce8..5f53dd19e 100644
--- a/llama_toolchain/safety/shields/contrib/third_party_shield.py
+++ b/llama_toolchain/safety/shields/contrib/third_party_shield.py
@@ -1,3 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described found in the +# LICENSE file in the root directory of this source tree. + import sys from typing import List @@ -5,7 +17,11 @@ from llama_models.llama3_1.api.datatypes import Message parent_dir = "../.." sys.path.append(parent_dir) -from llama_toolchain.safety.shields.base import OnViolationAction, ShieldBase, ShieldResponse +from llama_toolchain.safety.shields.base import ( + OnViolationAction, + ShieldBase, + ShieldResponse, +) _INSTANCE = None diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py index 790ff4def..94f111fb7 100644 --- a/llama_toolchain/safety/shields/llama_guard.py +++ b/llama_toolchain/safety/shields/llama_guard.py @@ -1,3 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described found in the +# LICENSE file in the root directory of this source tree. + import re from string import Template @@ -100,7 +112,7 @@ class LlamaGuardShield(ShieldBase): def instance( on_violation_action=OnViolationAction.RAISE, model_dir: str = None, - excluded_categories: List[str] = [], + excluded_categories: List[str] = None, disable_input_check: bool = False, disable_output_check: bool = False, ) -> "LlamaGuardShield": @@ -119,7 +131,7 @@ class LlamaGuardShield(ShieldBase): self, on_violation_action: OnViolationAction = OnViolationAction.RAISE, model_dir: str = None, - excluded_categories: List[str] = [], + excluded_categories: List[str] = None, disable_input_check: bool = False, disable_output_check: bool = False, ): @@ -129,6 +141,8 @@ class LlamaGuardShield(ShieldBase): assert model_dir is not None, "Llama Guard model_dir is None" + if excluded_categories is None: + excluded_categories = [] assert len(excluded_categories) == 0 or all( x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]" diff --git a/llama_toolchain/safety/shields/prompt_guard.py b/llama_toolchain/safety/shields/prompt_guard.py index ff720da89..142ce3341 100644 --- a/llama_toolchain/safety/shields/prompt_guard.py +++ b/llama_toolchain/safety/shields/prompt_guard.py @@ -1,3 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described found in the +# LICENSE file in the root directory of this source tree. + from enum import auto, Enum from typing import List diff --git a/llama_toolchain/safety/shields/shield_runner.py b/llama_toolchain/safety/shields/shield_runner.py index 27070d424..396e180d8 100644 --- a/llama_toolchain/safety/shields/shield_runner.py +++ b/llama_toolchain/safety/shields/shield_runner.py @@ -1,3 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the terms described found in the +# LICENSE file in the root directory of this source tree. + import asyncio from typing import List @@ -6,7 +18,7 @@ from llama_models.llama3_1.api.datatypes import Message, Role from .base import OnViolationAction, ShieldBase, ShieldResponse -class SafetyException(Exception): +class SafetyException(Exception): # noqa: N818 def __init__(self, response: ShieldResponse): self.response = response super().__init__(response.violation_return_message) diff --git a/llama_toolchain/spec/generate.py b/llama_toolchain/spec/generate.py index 5f7095017..e76f5912e 100644 --- a/llama_toolchain/spec/generate.py +++ b/llama_toolchain/spec/generate.py @@ -1,3 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described found in the +# LICENSE file in the root directory of this source tree. + from datetime import datetime import yaml diff --git a/llama_toolchain/spec/package.sh b/llama_toolchain/spec/package.sh index 854af6ffd..4d5470a57 100644 --- a/llama_toolchain/spec/package.sh +++ b/llama_toolchain/spec/package.sh @@ -1,5 +1,17 @@ #!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described found in the +# LICENSE file in the root directory of this source tree. + set -euo pipefail TMPDIR=$(mktemp -d) diff --git a/llama_toolchain/spec/post_training_types.py b/llama_toolchain/spec/post_training_types.py index fc7d963cf..b33fd89f0 100644 --- a/llama_toolchain/spec/post_training_types.py +++ b/llama_toolchain/spec/post_training_types.py @@ -1,3 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described found in the +# LICENSE file in the root directory of this source tree. + from enum import Enum from typing import Any, Dict, List diff --git a/llama_toolchain/spec/run_openapi_generator.sh b/llama_toolchain/spec/run_openapi_generator.sh index bb0171fa3..ebce5c458 100644 --- a/llama_toolchain/spec/run_openapi_generator.sh +++ b/llama_toolchain/spec/run_openapi_generator.sh @@ -1,5 +1,17 @@ #!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described found in the +# LICENSE file in the root directory of this source tree. + set -x PYTHONPATH=/data/users/rsm/llama-models:/data/users/rsm/llama-toolchain:/data/users/rsm/llama-agentic-system:../../../oss-ops:../.. 
python -m toolchain.spec.generate diff --git a/llama_toolchain/synthetic_data_generation/api/__init__.py b/llama_toolchain/synthetic_data_generation/api/__init__.py index 38413ff60..2f98dc6f7 100644 --- a/llama_toolchain/synthetic_data_generation/api/__init__.py +++ b/llama_toolchain/synthetic_data_generation/api/__init__.py @@ -1,2 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described found in the +# LICENSE file in the root directory of this source tree. + from .datatypes import * # noqa: F401 F403 from .endpoints import * # noqa: F401 F403 diff --git a/llama_toolchain/synthetic_data_generation/api/datatypes.py b/llama_toolchain/synthetic_data_generation/api/datatypes.py index fd53a74a3..3cd261e49 100644 --- a/llama_toolchain/synthetic_data_generation/api/datatypes.py +++ b/llama_toolchain/synthetic_data_generation/api/datatypes.py @@ -1,3 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described found in the +# LICENSE file in the root directory of this source tree. + from enum import Enum diff --git a/llama_toolchain/synthetic_data_generation/api/endpoints.py b/llama_toolchain/synthetic_data_generation/api/endpoints.py index 1fbec024a..a11be6111 100644 --- a/llama_toolchain/synthetic_data_generation/api/endpoints.py +++ b/llama_toolchain/synthetic_data_generation/api/endpoints.py @@ -1,3 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described found in the +# LICENSE file in the root directory of this source tree. + from typing import Any, Dict, List, Optional, Protocol from pydantic import BaseModel diff --git a/llama_toolchain/utils.py b/llama_toolchain/utils.py index dbc72b0c8..6fabb6e1f 100644 --- a/llama_toolchain/utils.py +++ b/llama_toolchain/utils.py @@ -1,3 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described found in the +# LICENSE file in the root directory of this source tree. + import getpass import os from typing import Optional diff --git a/setup.py b/setup.py index a2fdd040f..0d1a685a1 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described found in the +# LICENSE file in the root directory of this source tree. + from setuptools import find_packages, setup # Function to read the requirements.txt file def read_requirements(): - with open('requirements.txt') as req: + with open("requirements.txt") as req: content = req.readlines() return [line.strip() for line in content] @@ -14,11 +26,7 @@ setup( author="Meta Llama", author_email="rsm@meta.com", description="Llama toolchain", - entry_points={ - "console_scripts": [ - 'llama = llama_toolchain.cli.llama:main' - ] - }, + entry_points={"console_scripts": ["llama = llama_toolchain.cli.llama:main"]}, long_description=open("README.md").read(), long_description_content_type="text/markdown", url="https://github.com/meta-llama/llama-toolchain", @@ -26,5 +34,5 @@ setup( classifiers=[], python_requires=">=3.10", install_requires=read_requirements(), - include_package_data=True + include_package_data=True, )
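
Note on the `llama_toolchain/safety/shields/llama_guard.py` hunks above: replacing the `excluded_categories: List[str] = []` default with `None` plus an explicit guard avoids Python's shared-mutable-default pitfall, where the default list is created once at function definition time and then reused across every call. A minimal sketch of the failure mode and of the same guard the patch adds (function names here are hypothetical, not from the patch):

```python
def exclude_buggy(category, excluded=[]):
    # The [] default is evaluated once and shared by all calls.
    excluded.append(category)
    return excluded


def exclude_fixed(category, excluded=None):
    # Same None-guard pattern the patch adds to LlamaGuardShield.__init__.
    if excluded is None:
        excluded = []
    excluded.append(category)
    return excluded


print(exclude_buggy("S1"))  # ['S1']
print(exclude_buggy("S2"))  # ['S1', 'S2'] -- state leaked across calls
print(exclude_fixed("S1"))  # ['S1']
print(exclude_fixed("S2"))  # ['S2']
```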
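
On the `llama_toolchain/safety/shields/shield_runner.py` hunk: the added `# noqa: N818` suppresses pep8-naming's rule that exception class names end in "Error". Renaming `SafetyException` would break existing callers, so suppressing the lint is the smaller change. Illustrative only:

```python
class SafetyException(Exception):  # noqa: N818 -- name kept for API compatibility
    pass


class SafetyError(Exception):  # the name N818 would prefer for new code
    pass
```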
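
On the `llama_toolchain/synthetic_data_generation/api/__init__.py` hunk: `from .datatypes import *  # noqa: F401 F403` re-exports submodule names from the package root, keeping the public API flat; F401 silences flake8's "imported but unused" and F403 its "star import used" warning. A sketch of the layout this enables (package and class names hypothetical):

```python
# mypkg/datatypes.py
#     class SyntheticDataPoint: ...
#
# mypkg/__init__.py
#     from .datatypes import *  # noqa: F401 F403
#
# Consumers can then import from the package root:
#     from mypkg import SyntheticDataPoint
# rather than reaching into mypkg.datatypes directly.
```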
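
On the `setup.py` hunk: the reformatted `entry_points` still maps the `llama` command to `llama_toolchain.cli.llama:main`. After `pip install`, pip generates a wrapper script roughly equivalent to the stub below; the import path is taken from the patch, while the wrapper shape is standard setuptools console_scripts behavior:

```python
import sys

from llama_toolchain.cli.llama import main

if __name__ == "__main__":
    # pip's generated console script calls main() and exits with its return code.
    sys.exit(main())
```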