Initial commit

This commit is contained in:
Ashwin Bharambe 2024-06-25 15:47:57 -07:00 committed by Ashwin Bharambe
commit 5d5acc8ed5
81 changed files with 4458 additions and 0 deletions

29
.flake8 Normal file
View file

@ -0,0 +1,29 @@
[flake8]
# Suggested config from pytorch that we can adapt
select = B,C,E,F,N,P,T4,W,B9,TOR0,TOR1,TOR2
max-line-length = 120
# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
# N812 ignored because import torch.nn.functional as F is PyTorch convention
# N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
# E731 allow usage of assigning lambda expressions
# E701 let black auto-format statements on one line
# E704 let black auto-format statements on one line
ignore =
E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,E701,E704
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
EXE001,
# these ignores are from flake8-bugbear; please fix!
B007,B008,B950
optional-ascii-coding = True
exclude =
./.git,
./docs
./build
./scripts,
./venv,
*.pyi
.pre-commit-config.yaml
*.md
.flake8

4
.gitignore vendored Normal file
View file

@ -0,0 +1,4 @@
__pycache__
dist
*.egg-info
dev_requirements.txt

53
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,53 @@
exclude: 'build'
default_language_version:
python: python3
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 6306a48f7dae5861702d573c9c247e4e9498e867
hooks:
- id: trailing-whitespace
- id: check-ast
- id: check-merge-conflict
- id: check-added-large-files
args: ['--maxkb=1000']
- id: end-of-file-fixer
exclude: '^(.*\.svg)$'
# Temporarily disabling this
# - id: no-commit-to-branch
# args: ['--branch=main']
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.5.4
hooks:
- id: insert-license
files: \.py$|\.sh$
args:
- --license-filepath
- docs/license_header.txt
- repo: https://github.com/pycqa/flake8
rev: 34cbf8ef3950f43d09b85e2e45c15ae5717dc37b
hooks:
- id: flake8
additional_dependencies:
- flake8-bugbear == 22.4.25
- pep8-naming == 0.12.1
- torchfix
args: ['--config=.flake8']
- repo: https://github.com/omnilib/ufmt
rev: v2.7.0
hooks:
- id: ufmt
additional_dependencies:
- black == 24.4.2
- usort == 1.0.8
# - repo: https://github.com/jsh9/pydoclint
# rev: d88180a8632bb1602a4d81344085cf320f288c5a
# hooks:
# - id: pydoclint
# args: [--config=pyproject.toml]

80
CODE_OF_CONDUCT.md Normal file
View file

@ -0,0 +1,80 @@
# Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@meta.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

36
CONTRIBUTING.md Normal file
View file

@ -0,0 +1,36 @@
# Contributing to Llama-Models
We want to make contributing to this project as easy and transparent as
possible.
## Pull Requests
We actively welcome your pull requests.
1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Meta has a [bounty program](http://facebook.com/whitehat/info) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
## Coding Style
* 2 spaces for indentation rather than tabs
* 80 character line length
* ...
## License
By contributing to Llama, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

48
LICENSE Normal file
View file

@ -0,0 +1,48 @@
LLAMA 3.1 COMMUNITY LICENSE AGREEMENT
Llama 3.1 Version Release Date: July 23, 2024
“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Llama Materials set forth herein.
“Documentation” means the specifications, manuals and documentation accompanying Llama 3.1 distributed by Meta at https://llama.meta.com/doc/overview.
“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
“Llama 3.1” means the foundational large language models and software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed by Meta at https://llama.meta.com/llama-downloads.
“Llama Materials” means, collectively, Meta's proprietary Llama 3.1 and Documentation (and any portion thereof) made available under this Agreement.
“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland)
By clicking “I Accept” below or by using or distributing any portion or element of the Llama Materials, you agree to be bound by this Agreement.
1. License Rights and Redistribution.
a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta's intellectual property or other rights owned by Meta embodied in the Llama Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Llama Materials.
b. Redistribution and Use.
i. If you distribute or make available the Llama Materials (or any derivative works thereof), or a product or service (including another AI model) that contains any of them, you shall (A) provide a copy of this Agreement with any such Llama Materials; and (B) prominently display “Built with Llama” on a related website, user interface, blogpost, about page, or product documentation. If you use the Llama Materials or any outputs or results of the Llama Materials to create, train, fine tune, or otherwise improve an AI model, which is distributed or made available, you shall also include “Llama” at the beginning of any such AI model name.
ii. If you receive Llama Materials, or any derivative works thereof, from a Licensee as part of an integrated end user product, then Section 2 of this Agreement will not apply to you. 
iii. You must retain in all copies of the Llama Materials that you distribute the following attribution notice within a “Notice” text file distributed as a part of such copies: “Llama 3.1 is licensed under the Llama 3.1 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.”
iv. Your use of the Llama Materials must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Llama Materials (available at https://llama.meta.com/llama3_1/use-policy), which is hereby incorporated by reference into this Agreement.
2. Additional Commercial Terms. If, on the Llama 3.1 version release date, the monthly active users of the products or services made available by or for Licensee, or Licensee's affiliates, is greater than 700 million monthly active users in the preceding calendar month, you must request a license from Meta, which Meta may grant to you in its sole discretion, and you are not authorized to exercise any of the rights under this Agreement unless or until Meta otherwise expressly grants you such rights.
3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS.
4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
5. Intellectual Property.
a. No trademark licenses are granted under this Agreement, and in connection with the Llama Materials, neither Meta nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Llama Materials or as set forth in this Section 5(a). Meta hereby grants you a license to use “Llama” (the “Mark”) solely as required to comply with the last sentence of Section 1.b.i. You will comply with Meta's brand guidelines (currently accessible at https://about.meta.com/brand/resources/meta/company-brand/). All goodwill arising out of your use of the Mark will inure to the benefit of Meta.
b. Subject to Meta's ownership of Llama Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Llama Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
c. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Llama Materials or Llama 3.1 outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Llama Materials.
6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Llama Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement. 

1
MANIFEST.in Normal file
View file

@ -0,0 +1 @@
include llama_toolchain/data/*.yaml

63
README.md Normal file
View file

@ -0,0 +1,63 @@
# llama-toolchain
This repo contains the API specifications for various components of the Llama Stack as well implementations for some of those APIs like model inference.
The Stack consists of toolchain-apis and agentic-apis. This repo contains the toolchain-apis
## Installation and Setup ##
```bash
mkdir -p ~/local
cd ~/local
git clone git@github.com:meta-llama/llama-toolchain.git
conda create -n toolchain python=3.10
conda activate toolchain
cd llama-toolchain
pip install -e .
```
## Test with cli
We have built a llama cli to make it easy to configure / run parts of the toolchain
```
llama --help
usage: llama [-h] {download,inference,model,agentic_system} ...
Welcome to the LLama cli
options:
-h, --help show this help message and exit
subcommands:
{download,inference,model,agentic_system}
```
There are several subcommands to help get you started
## Start inference server that can run the llama models
```bash
llama inference configure
llama inference start
```
## Test client
```bash
python -m llama_toolchain.inference.client localhost 5000
Initializing client for http://localhost:5000
User>hello world, help me out here
Assistant> Hello! I'd be delighted to help you out. What's on your mind? Do you have a question, a problem, or just need someone to chat with? I'm all ears!
```
## Running FP8
You need `fbgemm-gpu` package which requires torch >= 2.4.0 (currently only in nightly, but releasing shortly...).
```bash
ENV=fp8_env
conda create -n $ENV python=3.10
conda activate $ENV
pip3 install -r fp8_requirements.txt
```

5
docs/license_header.txt Normal file
View file

@ -0,0 +1,5 @@
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the terms described in the LICENSE file in
the root directory of this source tree.

31
fp8_requirements.txt Normal file
View file

@ -0,0 +1,31 @@
--extra-index-url https://download.pytorch.org/whl/nightly/cu121
torch>=2.4.0.dev20240531,<2.4.1
accelerate
black==24.4.2
codeshield
fairscale
fastapi
fire
flake8
huggingface-hub
httpx
hydra-core
hydra-zen
json-strong-typing
matplotlib
omegaconf
pandas
Pillow
pre-commit
pydantic==1.10.13
pydantic_core==2.18.2
python-dotenv
python-openapi
requests
tiktoken
transformers
ufmt==2.7.0
usort==1.0.8
uvicorn
zmq
fbgemm-gpu==0.8.0rc4

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import textwrap
from pathlib import Path
from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.utils import DEFAULT_DUMP_DIR
DEFAULT_CHECKPOINT_DIR = os.path.join(DEFAULT_DUMP_DIR, "checkpoints")
class Download(Subcommand):
"""Llama cli for downloading llama toolchain assets"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"download",
prog="llama download",
description="Download a model from the Hugging Face Hub",
epilog=textwrap.dedent(
"""\
# Here are some examples on how to use this command:
llama download --repo-id meta-llama/Llama-2-7b-hf --hf-token <HF_TOKEN>
llama download --repo-id meta-llama/Llama-2-7b-hf --output-dir /data/my_custom_dir --hf-token <HF_TOKEN>
HF_TOKEN=<HF_TOKEN> llama download --repo-id meta-llama/Llama-2-7b-hf
The output directory will be used to load models and tokenizers for inference.
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_download_cmd)
def _add_arguments(self):
self.parser.add_argument(
"repo_id",
type=str,
help="Name of the repository on Hugging Face Hub eg. llhf/Meta-Llama-3.1-70B-Instruct",
)
self.parser.add_argument(
"--hf-token",
type=str,
required=False,
default=os.getenv("HF_TOKEN", None),
help="Hugging Face API token. Needed for gated models like Llama2. Will also try to read environment variable `HF_TOKEN` as default.",
)
self.parser.add_argument(
"--ignore-patterns",
type=str,
required=False,
default="*.safetensors",
help="If provided, files matching any of the patterns are not downloaded. Defaults to ignoring "
"safetensors files to avoid downloading duplicate weights.",
)
def _run_download_cmd(self, args: argparse.Namespace):
model_name = args.repo_id.split("/")[-1]
output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model_name
os.makedirs(output_dir, exist_ok=True)
try:
true_output_dir = snapshot_download(
args.repo_id,
local_dir=output_dir,
# "auto" will download to cache_dir and symlink files to local_dir
# avoiding unnecessary duplicate copies
local_dir_use_symlinks="auto",
ignore_patterns=args.ignore_patterns,
token=args.hf_token,
)
except GatedRepoError:
self.parser.error(
"It looks like you are trying to access a gated repository. Please ensure you "
"have access to the repository and have provided the proper Hugging Face API token "
"using the option `--hf-token` or by running `huggingface-cli login`."
"You can find your token by visiting https://huggingface.co/settings/tokens"
)
except RepositoryNotFoundError:
self.parser.error(
f"Repository '{args.repo_id}' not found on the Hugging Face Hub."
)
except Exception as e:
self.parser.error(e)
print(f"Successfully downloaded model to {true_output_dir}")

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import textwrap
from pathlib import Path
import pkg_resources
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.utils import DEFAULT_DUMP_DIR
CONFIGS_BASE_DIR = os.path.join(DEFAULT_DUMP_DIR, "configs")
class InferenceConfigure(Subcommand):
"""Llama cli for configuring llama toolchain configs"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"configure",
prog="llama inference configure",
description="Configure llama toolchain inference configs",
epilog=textwrap.dedent(
"""
Example:
llama inference configure
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_inference_configure_cmd)
def _add_arguments(self):
pass
def read_user_inputs(self):
checkpoint_dir = input(
"Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
)
model_parallel_size = input(
"Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
)
assert model_parallel_size.isdigit() and int(model_parallel_size) in {
1,
8,
}, "model parallel size must be 1 or 8"
return checkpoint_dir, model_parallel_size
def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
default_conf_path = pkg_resources.resource_filename(
"llama_toolchain", "data/default_inference_config.yaml"
)
with open(default_conf_path, "r") as f:
yaml_content = f.read()
yaml_content = yaml_content.format(
checkpoint_dir=checkpoint_dir,
model_parallel_size=model_parallel_size,
)
with open(yaml_output_path, "w") as yaml_file:
yaml_file.write(yaml_content.strip())
print(f"YAML configuration has been written to {yaml_output_path}")
def _run_inference_configure_cmd(self, args: argparse.Namespace) -> None:
checkpoint_dir, model_parallel_size = self.read_user_inputs()
checkpoint_dir = os.path.expanduser(checkpoint_dir)
assert (
Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
), f"{checkpoint_dir} does not exist or it not a directory"
os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml"
self.write_output_yaml(
checkpoint_dir,
model_parallel_size,
yaml_output_path,
)

View file

@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import textwrap
from llama_toolchain.cli.inference.configure import InferenceConfigure
from llama_toolchain.cli.inference.start import InferenceStart
from llama_toolchain.cli.subcommand import Subcommand
class InferenceParser(Subcommand):
"""Llama cli for inference apis"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"inference",
prog="llama inference",
description="Run inference on a llama model",
epilog=textwrap.dedent(
"""
Example:
llama inference start <options>
"""
),
)
subparsers = self.parser.add_subparsers(title="inference_subcommands")
# Add sub-commandsa
InferenceStart.create(subparsers)
InferenceConfigure.create(subparsers)

View file

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import textwrap
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.inference.server import main as inference_server_init
class InferenceStart(Subcommand):
"""Llama Inference cli for starting inference server"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"start",
prog="llama inference start",
description="Start an inference server",
epilog=textwrap.dedent(
"""
Example:
llama inference start <options>
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_inference_start_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--port",
type=int,
help="Port to run the server on. Defaults to 5000",
default=5000,
)
self.parser.add_argument(
"--disable-ipv6",
action="store_true",
help="Disable IPv6 support",
default=False,
)
self.parser.add_argument(
"--config", type=str, help="Path to config file", default="inference"
)
def _run_inference_start_cmd(self, args: argparse.Namespace) -> None:
inference_server_init(
config_path=args.config,
port=args.port,
disable_ipv6=args.disable_ipv6,
)

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_toolchain.cli.download import Download
from llama_toolchain.cli.inference.inference import InferenceParser
from llama_toolchain.cli.model.model import ModelParser
class LlamaCLIParser:
"""Defines CLI parser for Llama CLI"""
def __init__(self):
self.parser = argparse.ArgumentParser(
prog="llama",
description="Welcome to the LLama cli",
add_help=True,
)
# Default command is to print help
self.parser.set_defaults(func=lambda args: self.parser.print_help())
subparsers = self.parser.add_subparsers(title="subcommands")
# Add sub-commands
Download.create(subparsers)
InferenceParser.create(subparsers)
ModelParser.create(subparsers)
# Import sub-commands from agentic_system if they exist
try:
from llama_agentic_system.cli.subcommand_modules import SUBCOMMAND_MODULES
for module in SUBCOMMAND_MODULES:
module.create(subparsers)
except ImportError:
pass
def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args()
def run(self, args: argparse.Namespace) -> None:
args.func(args)
def main():
parser = LlamaCLIParser()
args = parser.parse_args()
parser.run(args)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import textwrap
from llama_toolchain.cli.model.template import ModelTemplate
from llama_toolchain.cli.subcommand import Subcommand
class ModelParser(Subcommand):
"""Llama cli for model interface apis"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"model",
prog="llama model",
description="Describe llama model interfaces",
epilog=textwrap.dedent(
"""
Example:
llama model <subcommand> <options>
"""
),
)
subparsers = self.parser.add_subparsers(title="model_subcommands")
# Add sub-commandsa
# ModelDescribe.create(subparsers)
ModelTemplate.create(subparsers)

View file

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import textwrap
from llama_models.llama3_1.api.interface import (
list_jinja_templates,
render_jinja_template,
)
from llama_toolchain.cli.subcommand import Subcommand
class ModelTemplate(Subcommand):
"""Llama model cli for describe a model template (message formats)"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"template",
prog="llama model template",
description="Show llama model message formats",
epilog=textwrap.dedent(
"""
Example:
llama model template <options>
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_template_cmd)
def _add_arguments(self):
self.parser.add_argument(
"-m",
"--model-family",
type=str,
default="llama3_1",
help="Model Family (llama3_1, llama3_X, etc.)",
)
self.parser.add_argument(
"--template",
type=str,
help="Usecase template name (system_message, user_message, assistant_message, tool_message)...",
required=False,
)
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
if args.template:
render_jinja_template(args.template)
else:
list_jinja_templates()

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
class Subcommand:
"""All llama cli subcommands must inherit from this class"""
def __init__(self, *args, **kwargs):
pass
@classmethod
def create(cls, *args, **kwargs):
return cls(*args, **kwargs)
def _add_arguments(self):
pass

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Dict, Optional
from llama_models.llama3_1.api.datatypes import URL
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
@json_schema_type
class RestAPIMethod(Enum):
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
@json_schema_type
class RestAPIExecutionConfig(BaseModel):
url: URL
method: RestAPIMethod
params: Optional[Dict[str, str]] = None
headers: Optional[Dict[str, str]] = None
body: Optional[Dict[str, str]] = None

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.llama3_1.api.datatypes import URL
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
@json_schema_type(schema={"description": "Checkpoint created during training runs"})
class Checkpoint(BaseModel):
iters: int
path: URL
epoch: int

View file

@ -0,0 +1,14 @@
inference_config:
impl_config:
impl_type: "inline"
checkpoint_config:
checkpoint:
checkpoint_type: "pytorch"
checkpoint_dir: {checkpoint_dir}/
tokenizer_path: {checkpoint_dir}/tokenizer.model
model_parallel_size: {model_parallel_size}
quantization_format: bf16
quantization: null
torch_seed: null
max_seq_len: 16384
max_batch_size: 1

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, Optional
from llama_models.llama3_1.api.datatypes import URL
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
@json_schema_type
class TrainEvalDatasetColumnType(Enum):
dialog = "dialog"
text = "text"
media = "media"
number = "number"
json = "json"
@json_schema_type
class TrainEvalDataset(BaseModel):
"""Dataset to be used for training or evaluating language models."""
# TODO(ashwin): figure out if we need to add an enum for a "dataset type"
columns: Dict[str, TrainEvalDatasetColumnType]
content_url: URL
metadata: Optional[Dict[str, Any]] = None

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol
from pydantic import BaseModel
from pyopenapi import webmethod
from strong_typing.schema import json_schema_type
from .datatypes import * # noqa: F403
@json_schema_type
class CreateDatasetRequest(BaseModel):
"""Request to create a dataset."""
uuid: str
dataset: TrainEvalDataset
class Datasets(Protocol):
@webmethod(route="/datasets/create")
def create_dataset(
self,
request: CreateDatasetRequest,
) -> None: ...
@webmethod(route="/datasets/get")
def get_dataset(
self,
dataset_uuid: str,
) -> TrainEvalDataset: ...
@webmethod(route="/datasets/delete")
def delete_dataset(
self,
dataset_uuid: str,
) -> None: ...

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from pydantic import BaseModel
class TextGenerationMetric(Enum):
perplexity = "perplexity"
rouge = "rouge"
bleu = "bleu"
class QuestionAnsweringMetric(Enum):
em = "em"
f1 = "f1"
class SummarizationMetric(Enum):
rouge = "rouge"
bleu = "bleu"
class EvaluationJob(BaseModel):
job_uuid: str
class EvaluationJobLogStream(BaseModel):
job_uuid: str

View file

@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Protocol
from pydantic import BaseModel
from pyopenapi import webmethod
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from .datatypes import * # noqa: F403
from llama_toolchain.dataset.api.datatypes import * # noqa: F403
from llama_toolchain.common.training_types import * # noqa: F403
class EvaluateTaskRequestCommon(BaseModel):
job_uuid: str
dataset: TrainEvalDataset
checkpoint: Checkpoint
# generation params
sampling_params: SamplingParams = SamplingParams()
@json_schema_type
class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon):
"""Request to evaluate text generation."""
metrics: List[TextGenerationMetric]
@json_schema_type
class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon):
"""Request to evaluate question answering."""
metrics: List[QuestionAnsweringMetric]
@json_schema_type
class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
"""Request to evaluate summarization."""
metrics: List[SummarizationMetric]
class EvaluationJobStatusResponse(BaseModel):
job_uuid: str
@json_schema_type
class EvaluationJobArtifactsResponse(BaseModel):
"""Artifacts of a evaluation job."""
job_uuid: str
class Evaluations(Protocol):
@webmethod(route="/evaluate/text_generation/")
def post_evaluate_text_generation(
self,
request: EvaluateTextGenerationRequest,
) -> EvaluationJob: ...
@webmethod(route="/evaluate/question_answering/")
def post_evaluate_question_answering(
self,
request: EvaluateQuestionAnsweringRequest,
) -> EvaluationJob: ...
@webmethod(route="/evaluate/summarization/")
def post_evaluate_summarization(
self,
request: EvaluateSummarizationRequest,
) -> EvaluationJob: ...
@webmethod(route="/evaluate/jobs")
def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
@webmethod(route="/evaluate/job/status")
def get_evaluation_job_status(
self, job_uuid: str
) -> EvaluationJobStatusResponse: ...
# sends SSE stream of logs
@webmethod(route="/evaluate/job/logs")
def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
@webmethod(route="/evaluate/job/cancel")
def cancel_evaluation_job(self, job_uuid: str) -> None: ...
@webmethod(route="/evaluate/job/artifacts")
def get_evaluation_job_artifacts(
self, job_uuid: str
) -> EvaluationJobArtifactsResponse: ...

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@ -0,0 +1,94 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Literal, Optional, Union
from hydra.core.config_store import ConfigStore
from hydra_zen import builds
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type
from typing_extensions import Annotated
from .datatypes import QuantizationConfig
@json_schema_type
class ImplType(Enum):
inline = "inline"
remote = "remote"
@json_schema_type
class CheckpointType(Enum):
pytorch = "pytorch"
huggingface = "huggingface"
@json_schema_type
class PytorchCheckpoint(BaseModel):
checkpoint_type: Literal[CheckpointType.pytorch.value] = (
CheckpointType.pytorch.value
)
checkpoint_dir: str
tokenizer_path: str
model_parallel_size: int
quantization_format: CheckpointQuantizationFormat = (
CheckpointQuantizationFormat.bf16
)
@json_schema_type
class HuggingFaceCheckpoint(BaseModel):
checkpoint_type: Literal[CheckpointType.huggingface.value] = (
CheckpointType.huggingface.value
)
repo_id: str # or model_name ?
model_parallel_size: int
quantization_format: CheckpointQuantizationFormat = (
CheckpointQuantizationFormat.bf16
)
@json_schema_type
class ModelCheckpointConfig(BaseModel):
checkpoint: Annotated[
Union[PytorchCheckpoint, HuggingFaceCheckpoint],
Field(discriminator="checkpoint_type"),
]
@json_schema_type
class InlineImplConfig(BaseModel):
impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
checkpoint_config: ModelCheckpointConfig
quantization: Optional[QuantizationConfig] = None
torch_seed: Optional[int] = None
max_seq_len: int
max_batch_size: int = 1
@json_schema_type
class RemoteImplConfig(BaseModel):
impl_type: Literal[ImplType.remote.value] = ImplType.remote.value
url: str = Field(..., description="The URL of the remote module")
@json_schema_type
class InferenceConfig(BaseModel):
impl_config: Annotated[
Union[InlineImplConfig, RemoteImplConfig],
Field(discriminator="impl_type"),
]
InferenceHydraConfig = builds(InferenceConfig)
cs = ConfigStore.instance()
cs.store(name="inference_config", node=InferenceHydraConfig)

View file

@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type
from typing_extensions import Annotated
from llama_models.llama3_1.api.datatypes import * # noqa: F403
class LogProbConfig(BaseModel):
top_k: Optional[int] = 0
@json_schema_type
class QuantizationType(Enum):
bf16 = "bf16"
fp8 = "fp8"
@json_schema_type
class Fp8QuantizationConfig(BaseModel):
type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
@json_schema_type
class Bf16QuantizationConfig(BaseModel):
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
QuantizationConfig = Annotated[
Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
Field(discriminator="type"),
]
@json_schema_type
class ChatCompletionResponseEventType(Enum):
start = "start"
complete = "complete"
progress = "progress"
@json_schema_type
class ToolCallParseStatus(Enum):
started = "started"
in_progress = "in_progress"
failure = "failure"
success = "success"
@json_schema_type
class ToolCallDelta(BaseModel):
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
"""Chat completion response event."""
event_type: ChatCompletionResponseEventType
delta: Union[str, ToolCallDelta]
logprobs: Optional[List[TokenLogProbs]] = None
stop_reason: Optional[StopReason] = None

View file

@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F403
from typing import Optional, Protocol
# this dependency is annoying and we need a forked up version anyway
from pyopenapi import webmethod
@json_schema_type
class CompletionRequest(BaseModel):
model: PretrainedModel
content: InterleavedTextAttachment
sampling_params: Optional[SamplingParams] = SamplingParams()
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
quantization_config: Optional[QuantizationConfig] = None
@json_schema_type
class CompletionResponse(BaseModel):
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]] = None
@json_schema_type
class CompletionResponseStreamChunk(BaseModel):
"""streamed completion response."""
delta: str
stop_reason: Optional[StopReason] = None
logprobs: Optional[List[TokenLogProbs]] = None
@json_schema_type
class BatchCompletionRequest(BaseModel):
model: PretrainedModel
content_batch: List[InterleavedTextAttachment]
sampling_params: Optional[SamplingParams] = SamplingParams()
logprobs: Optional[LogProbConfig] = None
quantization_config: Optional[QuantizationConfig] = None
@json_schema_type
class BatchCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
@json_schema_type
class ChatCompletionRequest(BaseModel):
model: InstructModel
messages: List[Message]
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot tool definitions as input to the model
available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
quantization_config: Optional[QuantizationConfig] = None
@json_schema_type
class ChatCompletionResponseStreamChunk(BaseModel):
"""SSE-stream of these events."""
event: ChatCompletionResponseEvent
@json_schema_type
class ChatCompletionResponse(BaseModel):
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]] = None
@json_schema_type
class BatchChatCompletionRequest(BaseModel):
model: InstructModel
messages_batch: List[List[Message]]
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot tool definitions as input to the model
available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
logprobs: Optional[LogProbConfig] = None
quantization_config: Optional[QuantizationConfig] = None
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
class Inference(Protocol):
@webmethod(route="/inference/completion")
async def completion(
self,
request: CompletionRequest,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
@webmethod(route="/inference/chat_completion")
async def chat_completion(
self,
request: ChatCompletionRequest,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
@webmethod(route="/inference/batch_completion")
async def batch_completion(
self,
request: BatchCompletionRequest,
) -> BatchCompletionResponse: ...
@webmethod(route="/inference/batch_chat_completion")
async def batch_chat_completion(
self,
request: BatchChatCompletionRequest,
) -> BatchChatCompletionResponse: ...

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .api.config import ImplType, InferenceConfig
async def get_inference_api_instance(config: InferenceConfig):
if config.impl_config.impl_type == ImplType.inline.value:
from .inference import InferenceImpl
return InferenceImpl(config.impl_config)
from .client import InferenceClient
return InferenceClient(config.impl_config.url)

View file

@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
from typing import AsyncGenerator
import fire
import httpx
from termcolor import cprint
from .api import (
ChatCompletionRequest,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
InstructModel,
UserMessage,
)
from .event_logger import EventLogger
class InferenceClient(Inference):
def __init__(self, base_url: str):
print(f"Initializing client for {base_url}")
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
f"{self.base_url}/inference/chat_completion",
data=request.json(),
headers={"Content-Type": "application/json"},
timeout=20,
) as response:
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line[len("data: ") :]
try:
yield ChatCompletionResponseStreamChunk(**json.loads(data))
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
async def run_main(host: str, port: int):
client = InferenceClient(f"http://{host}:{port}")
message = UserMessage(content="hello world, help me out here")
cprint(f"User>{message.content}", "green")
req = ChatCompletionRequest(
model=InstructModel.llama3_70b_chat,
messages=[message],
stream=True,
)
iterator = client.chat_completion(
ChatCompletionRequest(
model=InstructModel.llama3_8b_chat,
messages=[message],
stream=True,
)
)
async for log in EventLogger().log(iterator):
log.print()
def main(host: str, port: int):
asyncio.run(run_main(host, port))
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from termcolor import cprint
from llama_toolchain.inference.api import ChatCompletionResponseEventType
class LogEvent:
def __init__(
self,
content: str = "",
end: str = "\n",
color="white",
):
self.content = content
self.color = color
self.end = "\n" if end is None else end
def print(self, flush=True):
cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
class EventLogger:
async def log(self, event_generator, stream=True):
async for chunk in event_generator:
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
yield LogEvent("Assistant> ", color="cyan", end="")
elif event.event_type == ChatCompletionResponseEventType.progress:
yield LogEvent(event.delta, color="yellow", end="")
elif event.event_type == ChatCompletionResponseEventType.complete:
yield LogEvent("")

View file

@ -0,0 +1,319 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from llama_models.llama3_1.api.args import ModelArgs
from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.model import Transformer
from llama_models.llama3_1.api.tokenizer import Tokenizer
from termcolor import cprint
from .api.config import CheckpointType, InlineImplConfig
from .api.datatypes import QuantizationType
@dataclass
class TokenResult:
token: int
text: str
logprobs: Optional[List[float]] = None
class Llama:
@staticmethod
def build(config: InlineImplConfig):
"""
Build a Llama instance by initializing and loading a model checkpoint.
Note:
This method initializes the distributed process group, sets the device to CUDA,
and loads the pre-trained model and tokenizer.
"""
checkpoint = config.checkpoint_config.checkpoint
if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
raise NotImplementedError("HuggingFace checkpoints not supported yet")
if (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
):
from .quantization.loader import is_fbgemm_available
if not is_fbgemm_available():
raise ImportError("fbgemm-gpu is required for FP8 quantization")
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
model_parallel_size = checkpoint.model_parallel_size
if not model_parallel_is_initialized():
initialize_model_parallel(model_parallel_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
if config.torch_seed is not None:
torch.manual_seed(config.torch_seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
ckpt_dir = checkpoint.checkpoint_dir
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
# TODO(ashwin): this block is so we can load internal checkpoints without additional
# fuss. the final code should _not_ have this blurb
if "model" in params:
params = params["model"]
model_args: ModelArgs = ModelArgs(
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
**params,
)
tokenizer = Tokenizer(model_path=checkpoint.tokenizer_path)
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
fp8 = (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
)
if fp8:
# load on CPU in bf16 so that fp8 conversion does not find an
# unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
if fp8:
# load on CPU first since if we are doing fp8, we probably don't
# have enough memory on GPU for bf16
model.load_state_dict(state_dict, strict=False)
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
if not fp8:
model.load_state_dict(state_dict, strict=False)
if config.quantization:
from .quantization.loader import convert_to_quantized_model
model = convert_to_quantized_model(model, config)
else:
model = model.to("cuda")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
@torch.inference_mode()
def generate(
self,
model_input: ModelInput,
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
include_stop_token: bool = False,
) -> Generator:
params = self.model.params
# cprint("Input to model -> " + self.tokenizer.decode(model_input.tokens), "red")
prompt_tokens = [model_input.tokens]
bsz = 1
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
)
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id
if min_prompt_len == total_len:
# TODO(ashwin): unify this branch with the one below and figure out multimodal crap
logits = self.model.forward(tokens, prev_pos)
token_logprobs = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens,
reduction="none",
ignore_index=pad_id,
)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
for cur_pos in range(min_prompt_len, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens[:, prev_pos + 1 : cur_pos + 1],
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens)
)
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(
token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
if logprobs
else None
),
)
prev_pos = cur_pos
if all(eos_reached):
break
def text_completion(
self,
prompt: str,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator:
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
max_gen_len = self.model.params.max_seq_len - 1
prompt_tokens = self.tokenizer.encode(x, bos=True, eos=False)
yield from self.generate(
model_input=ModelInput(tokens=prompt_tokens),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
echo=echo,
)
def chat_completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
) -> Generator:
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
max_gen_len = self.model.params.max_seq_len - 1
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(messages),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
include_stop_token=True,
)
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token

View file

@ -0,0 +1,159 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from llama_models.llama3_1.api.datatypes import StopReason
from .api.config import InlineImplConfig
from .api.datatypes import (
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ToolCallDelta,
ToolCallParseStatus,
)
from .api.endpoints import (
ChatCompletionRequest,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
)
from .model_parallel import LlamaModelParallelGenerator
class InferenceImpl(Inference):
def __init__(self, config: InlineImplConfig) -> None:
self.config = config
async def initialize(self) -> None:
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
async def shutdown(self) -> None:
self.generator.stop()
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
tokens = []
logprobs = []
stop_reason = None
buffer = ""
ipython = False
for token_result in self.generator.chat_completion(
messages=request.messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
):
buffer += token_result.text
tokens.append(token_result.token)
if not ipython and buffer.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer = buffer[len("<|python_tag|>") :]
continue
if not request.stream:
if request.logprobs:
logprobs.append(token_result.logprob)
continue
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
# TODO(ashwin): parse tool calls separately here and report errors?
# if someone breaks the iteration before coming here we are toast
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
if request.stream:
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
# TODO(ashwin): what else do we need to send out here when everything finishes?
else:
yield ChatCompletionResponse(
content=message.content,
tool_calls=message.tool_calls,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)

View file

@ -0,0 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from typing import Generator, List, Optional
from llama_models.llama3_1.api.chat_format import ChatFormat
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.tokenizer import Tokenizer
from .api.config import InlineImplConfig
from .generation import Llama
from .parallel_utils import ModelParallelProcessGroup
@dataclass
class InferenceArgs:
messages: List[Message]
temperature: float
top_p: float
max_gen_len: int
logprobs: bool
class ModelRunner:
def __init__(self, llama):
self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, task: InferenceArgs):
return self.llama.chat_completion(
task.messages,
task.temperature,
task.top_p,
task.max_gen_len,
task.logprobs,
)
def init_model_cb(config: InlineImplConfig):
llama = Llama.build(config)
return ModelRunner(llama)
class LlamaModelParallelGenerator:
"""
This abstraction exists so
- we can run model parallel code without needing to run the CLIs via torchrun
- this also enables use model parallel code within a notebook context.
A Context Manager is used to ensure that the model parallel process is started and stopped
correctly. This does make the ergonomics a little awkward, because it isn't immediately
clear at the callsite why we need to use a context manager.
"""
def __init__(self, config: InlineImplConfig):
self.config = config
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
# while the tool-use loop is going
checkpoint = self.config.checkpoint_config.checkpoint
self.formatter = ChatFormat(Tokenizer(checkpoint.tokenizer_path))
def start(self):
self.__enter__()
def stop(self):
self.__exit__(None, None, None)
def __enter__(self):
checkpoint = self.config.checkpoint_config.checkpoint
self.group = ModelParallelProcessGroup(
checkpoint.model_parallel_size,
init_model_cb=partial(init_model_cb, self.config),
)
self.group.start()
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
self.group.stop()
def chat_completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
) -> Generator:
req_obj = InferenceArgs(
messages=deepcopy(messages),
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
)
gen = self.group.run_inference(req_obj)
yield from gen

View file

@ -0,0 +1,265 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import multiprocessing
import os
import pickle
import tempfile
import time
import uuid
from typing import Callable, Generator
import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_group,
get_model_parallel_rank,
get_model_parallel_src_rank,
)
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
_END_SENTINEL = "__end_sentinel__"
_CANCEL_SENTINEL = "__cancel_sentinel__"
def mp_rank_0() -> bool:
return get_model_parallel_rank() == 0
def retrieve_requests(reply_socket_url: str):
if mp_rank_0():
context = zmq.Context()
reply_socket = context.socket(zmq.ROUTER)
reply_socket.connect(reply_socket_url)
while True:
client_id, obj = maybe_get_work(reply_socket)
if obj is None:
time.sleep(0.01)
continue
reply_socket.send_multipart([client_id, pickle.dumps("YES READY")])
break
def send_obj(obj):
reply_socket.send_multipart([client_id, pickle.dumps(obj)])
while True:
tasks = [None]
if mp_rank_0():
client_id, task = maybe_get_work(reply_socket)
# there is still an unknown unclean GeneratorExit happening resulting in a
# cancel sentinel getting queued _after_ we have finished sending everything :/
# kind of a hack this is :/
if task != _CANCEL_SENTINEL:
tasks = [task]
torch.distributed.broadcast_object_list(
tasks,
src=get_model_parallel_src_rank(),
group=get_model_parallel_group(),
)
task = tasks[0]
if task is None:
time.sleep(0.1)
else:
try:
out = yield task
if out is None:
break
for obj in out:
updates = [None]
if mp_rank_0():
_, update = maybe_get_work(reply_socket)
if update == _CANCEL_SENTINEL:
updates = [update]
else:
# only send the update if it's not cancelled otherwise the object sits in the socket
# and gets pulled in the next request lol
send_obj(obj)
torch.distributed.broadcast_object_list(
updates,
src=get_model_parallel_src_rank(),
group=get_model_parallel_group(),
)
if updates[0] == _CANCEL_SENTINEL:
print("quitting generation loop because request was cancelled")
break
if mp_rank_0():
send_obj(_END_SENTINEL)
except Exception as e:
print(f"[debug] got exception {e}")
import traceback
traceback.print_exc()
if mp_rank_0():
send_obj(e)
if mp_rank_0():
send_obj("DONE")
def maybe_get_work(sock: zmq.Socket):
message = None
client_id = None
try:
client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
message = pickle.loads(obj)
except zmq.ZMQError as e:
if e.errno != zmq.EAGAIN:
raise e
return client_id, message
def worker_process_entrypoint(
reply_socket_url: str,
init_model_cb: Callable,
) -> None:
model = init_model_cb()
torch.distributed.barrier()
time.sleep(1)
# run the requests co-routine which retrieves requests from the socket
# and sends responses (we provide) back to the caller
req_gen = retrieve_requests(reply_socket_url)
result = None
while True:
try:
task = req_gen.send(result)
if isinstance(task, str) and task == _END_SENTINEL:
break
result = model(task)
except StopIteration:
break
print("[debug] worker process done")
def launch_dist_group(
reply_socket_url: str,
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
) -> None:
id = uuid.uuid4().hex
dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
with tempfile.TemporaryDirectory() as tmpdir:
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
launch_config = LaunchConfig(
max_nodes=1,
min_nodes=1,
nproc_per_node=model_parallel_size,
start_method="fork",
rdzv_backend="c10d",
rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
rdzv_configs={"store_type": "file", "timeout": 90},
max_restarts=0,
monitor_interval=1,
run_id=str(uuid.uuid4()),
)
elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
reply_socket_url,
init_model_cb,
)
def start_model_parallel_process(
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
):
context = zmq.Context()
request_socket = context.socket(zmq.DEALER)
# Binding the request socket to a random port
request_socket.bind("tcp://127.0.0.1:0")
main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)
ctx = multiprocessing.get_context("fork")
process = ctx.Process(
target=launch_dist_group,
args=(
main_process_url,
model_parallel_size,
init_model_cb,
),
kwargs=kwargs,
)
process.start()
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
request_socket.send_pyobj("READY?")
response = request_socket.recv_pyobj()
print(f"Finished model load {response}")
return request_socket, process
class ModelParallelProcessGroup:
def __init__(
self,
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
):
self.model_parallel_size = model_parallel_size
self.init_model_cb = init_model_cb
self.started = False
self.running = False
def start(self):
assert not self.started, "process group already started"
self.request_socket, self.process = start_model_parallel_process(
self.model_parallel_size,
self.init_model_cb,
)
self.started = True
def stop(self):
assert self.started, "process group not started"
if self.process.is_alive():
self.request_socket.send_pyobj(_END_SENTINEL, zmq.NOBLOCK)
self.process.join()
self.started = False
def run_inference(self, request) -> Generator:
assert not self.running, "inference already running"
self.running = True
self.request_socket.send_pyobj(request)
try:
while True:
obj = self.request_socket.recv_pyobj()
if obj == _END_SENTINEL:
break
if isinstance(obj, Exception):
print(f"[debug] got exception {obj}")
raise obj
yield obj
except GeneratorExit as e:
self.request_socket.send_pyobj(_CANCEL_SENTINEL)
while True:
obj = self.request_socket.recv_pyobj()
if obj == _END_SENTINEL:
break
finally:
self.running = False

View file

@ -0,0 +1,184 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import collections
from typing import Optional, Type
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
print("Using efficient FP8 operators in FBGEMM.")
except ImportError:
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
raise
import torch
from torch import nn, Tensor
class Fp8ScaledWeights:
# TODO: Ugly trick so torch allows us to replace parameters
# with our custom Fp8Weights instance. Do this properly.
@property
def __class__(self) -> Type[nn.parameter.Parameter]:
return nn.Parameter
@property
def grad_fn(self) -> None:
return None
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
class Fp8RowwiseWeights(
Fp8ScaledWeights,
collections.namedtuple(
"Fp8RowwiseWeights",
["weight", "scale", "shape", "activation_scale_ub"],
),
):
pass
def ffn_swiglu(
x: Tensor,
w1: Fp8RowwiseWeights,
w3: Fp8RowwiseWeights,
w2: Fp8RowwiseWeights,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
if (
isinstance(w1, Fp8ScaledWeights)
and isinstance(w3, Fp8ScaledWeights)
and isinstance(w2, Fp8ScaledWeights)
):
return ffn_swiglu_fp8_dynamic(
x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
)
(B, T, D) = x.shape # noqa: N806
(HD_L, D_) = w1.shape # noqa: N806
assert D_ == D
assert isinstance(w1, Tensor)
assert isinstance(w3, Tensor)
x1 = x.view(B * T, D) @ w1.T
x2 = x.view(B * T, D) @ w3.T
z = torch.nn.functional.silu(x1) * x2
del x1, x2
assert isinstance(w2, Tensor)
return (z @ w2.T).view(B, T, D)
@torch.inference_mode()
def quantize_fp8(
w: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
) -> Fp8RowwiseWeights:
"""Quantize [n, k] weight tensor.
Args:
w (Tensor): [n, k] input high precision tensor to quantize.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device="cuda",
)
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
del w
return Fp8RowwiseWeights(
weight=wq,
scale=w_scale,
shape=wq.shape,
activation_scale_ub=activation_scale_ub,
)
@torch.inference_mode()
def load_fp8(
w: Tensor,
w_scale: Tensor,
fp8_activation_scale_ub: float,
) -> Fp8RowwiseWeights:
"""Load FP8 [n, k] weight tensor.
Args:
w (Tensor): [n, k] input FP8.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device="cuda",
)
return Fp8RowwiseWeights(
weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
scale=w_scale.to(device="cuda"),
shape=w.shape,
activation_scale_ub=activation_scale_ub,
)
def fc_fp8_dynamic(
x: Tensor,
w: Fp8RowwiseWeights,
activation_scale_ub: Optional[Tensor] = None,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
"""
Single w8a8 fc layer with dynamic row-wise scaling.
"""
if isinstance(w, Fp8RowwiseWeights):
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
x, num_tokens, activation_scale_ub
)
y = torch.ops.fbgemm.f8f8bf16_rowwise(
xq, w.weight, x_scale, w.scale, use_fast_accum=True
)
del xq
return y
def ffn_swiglu_fp8_dynamic(
x: Tensor,
w1: Fp8RowwiseWeights,
w3: Fp8RowwiseWeights,
w2: Fp8RowwiseWeights,
activation_scale_ub: Optional[Tensor] = None,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
(B, T, D) = x.shape # noqa: N806
HD_L = w1.shape[0] # noqa: N806
assert HD_L == w3.shape[0]
x1 = fc_fp8_dynamic(
x.view(B * T, D),
w1,
activation_scale_ub,
num_tokens,
is_memory_bounded,
)
x2 = fc_fp8_dynamic(
x.view(B * T, D),
w3,
activation_scale_ub,
num_tokens,
is_memory_bounded,
)
z = torch.nn.functional.silu(x1) * x2
del x1, x2
z_ = fc_fp8_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
return z_.view(B, T, D)

View file

@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from typing import Optional
import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.llama3_1.api.model import Transformer, TransformerBlock
from llama_toolchain.inference.api.config import (
CheckpointQuantizationFormat,
InlineImplConfig,
)
from llama_toolchain.inference.api.datatypes import QuantizationType
from termcolor import cprint
from torch import Tensor
def is_fbgemm_available() -> bool:
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
return True
except ImportError:
return False
def swiglu_wrapper(
self,
x: Tensor,
):
from .fp8_impls import ffn_swiglu
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
return reduce_from_model_parallel_region(out)
def convert_to_quantized_model(
model: Transformer,
config: InlineImplConfig,
fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:
return model
elif config.quantization.type != QuantizationType.fp8.value:
raise ValueError("Only FP8 quantization is supported")
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
checkpoint = config.checkpoint_config.checkpoint
# Move weights to GPU with quantization
if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
cprint("Loading fp8 scales...", "yellow")
fp8_scales_path = os.path.join(
checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
assert os.path.isfile(
fp8_scales_path
), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = load_fp8(
param.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
else:
cprint("Quantizing fp8 weights from bf16...", "yellow")
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = quantize_fp8(
param.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device="cuda")
return model

View file

@ -0,0 +1,30 @@
#!/bin/bash
if [[ $# -ne 1 ]]; then
echo "Error: Please provide the name of CONDA environment you wish to create"
exit 1
fi
ENV_NAME=$1
set -eu
eval "$(conda shell.bash hook)"
echo "Will build env (or overwrite) named '$ENV_NAME'"
set -x
run_build() {
# Set up the conda environment
yes | conda remove --name $ENV_NAME --all
yes | conda create -n $ENV_NAME python=3.10
conda activate $ENV_NAME
# PT nightly
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
# install dependencies for `llama-agentic-system`
pip install -r fp8_requirements.txt
}
run_build

View file

@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import os
import shutil
import sys
from pathlib import Path
from typing import Optional
import fire
import torch
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8
from llama.model import ModelArgs, Transformer, TransformerBlock
from llama.tokenizer import Tokenizer
from torch.nn.parameter import Parameter
def main(
ckpt_dir: str,
tokenizer_path: str,
quantized_ckpt_dir: str,
max_seq_len: Optional[int] = 512,
max_batch_size: Optional[int] = 4,
model_parallel_size: Optional[int] = None,
ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE,
fp8_activation_scale_ub: Optional[float] = 1200.0,
seed: int = 1,
):
""" """
if not os.path.exists(quantized_ckpt_dir):
os.makedirs(quantized_ckpt_dir)
shutil.copy(
os.path.join(ckpt_dir, "params.json"),
os.path.join(quantized_ckpt_dir, "params.json"),
)
shutil.copy(
os.path.join(ckpt_dir, "tokenizer.model"),
os.path.join(quantized_ckpt_dir, "tokenizer.model"),
)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if not model_parallel_is_initialized():
if model_parallel_size is None:
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(model_parallel_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
**params,
)
tokenizer = Tokenizer(model_path=tokenizer_path)
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
print(ckpt_path)
assert (
quantized_ckpt_dir is not None
), "QUantized checkpoint directory should not be None"
fp8_scales = {}
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
fp8_weight = quantize_fp8(
block.feed_forward.w1.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_weight = quantize_fp8(
block.feed_forward.w3.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_weight = quantize_fp8(
block.feed_forward.w2.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_scales_path = os.path.join(
quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
torch.save(fp8_scales, fp8_scales_path)
ckpt_path = os.path.join(
quantized_ckpt_dir,
"consolidated.{:02d}.pth".format(get_model_parallel_rank()),
)
torch.save(model.state_dict(), ckpt_path)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,31 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
set -x
cd $(git rev-parse --show-toplevel)
MASTER_HOST=$1
RUN_ID=$2
CKPT_DIR=$3
QUANT_CKPT_DIR=$4
TOKENIZER_PATH=$5
NNODES=$6
NPROC=$7
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \
torchrun \
--nnodes=$NNODES --nproc_per_node=$NPROC \
--rdzv_id=$RUN_ID \
--rdzv_conf='timeout=120' \
--rdzv_backend=c10d \
--rdzv_endpoint="${MASTER_HOST}:29502" \
quantize_checkpoint.py $CKPT_DIR $TOKENIZER_PATH $QUANT_CKPT_DIR

View file

@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import unittest
import torch
from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
from hypothesis import given, settings, strategies as st
from torch import Tensor
@unittest.skipIf(
not torch.cuda.is_available()
or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
"Skip when H100 is not available",
)
class FP8Tests(unittest.TestCase):
@settings(deadline=None)
@given(
D=st.sampled_from([4096, 8192]),
HD_L=st.sampled_from([1280, 2560]),
B=st.sampled_from([1, 2]),
T=st.sampled_from([2048, 4096]),
UB=st.sampled_from([1000, 10000]),
)
def test_fp8_ffn(
self,
D: int, # noqa
HD_L: int,
B: int,
T: int,
UB: float,
) -> None:
x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
w1 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
w3 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
x_q = quantize_fp8(x, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
w1_q = quantize_fp8(w1, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
w3_q = quantize_fp8(w3, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
w2_q = quantize_fp8(w2, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
(B, T, D) = x.shape # noqa: N806
(HD_L, D_) = w1.shape # noqa: N806
assert D_ == D
x1 = x.view(B * T, D) @ w1.T
x2 = x.view(B * T, D) @ w3.T
z = torch.nn.functional.silu(x1) * x2
return (z @ w2.T).view(B, T, D).to(torch.bfloat16)
v = ffn_swiglu_fp8_dynamic(x, w1_q, w3_q, w2_q)
# Fake quant
x = x_q.weight.bfloat16() * x_q.scale.unsqueeze(-1)
w1 = w1_q.weight.bfloat16() * w1_q.scale.unsqueeze(-1)
w3 = w3_q.weight.bfloat16() * w3_q.scale.unsqueeze(-1)
w2 = w2_q.weight.bfloat16() * w2_q.scale.unsqueeze(-1)
v_ref = ref_ffn(x, w1, w3, w2)
torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import signal
import fire
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from hydra_zen import instantiate
from llama_toolchain.utils import get_default_config_dir, parse_config
from .api.endpoints import ChatCompletionRequest, ChatCompletionResponseStreamChunk
from .api_instance import get_inference_api_instance
load_dotenv()
GLOBAL_CONFIG = None
def get_config():
return GLOBAL_CONFIG
def handle_sigint(*args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully", args)
loop = asyncio.get_event_loop()
for task in asyncio.all_tasks(loop):
task.cancel()
loop.stop()
app = FastAPI()
@app.on_event("startup")
async def startup():
global InferenceApiInstance
config = get_config()
inference_config = instantiate(config["inference_config"])
InferenceApiInstance = await get_inference_api_instance(
inference_config,
)
await InferenceApiInstance.initialize()
@app.on_event("shutdown")
async def shutdown():
global InferenceApiInstance
print("shutting down")
await InferenceApiInstance.shutdown()
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
semaphore = asyncio.Semaphore(1)
@app.post(
"/inference/chat_completion", response_model=ChatCompletionResponseStreamChunk
)
def chat_completion(request: Request, exec_request: ChatCompletionRequest):
if semaphore.locked():
raise HTTPException(
status_code=429,
detail="Only a single concurrent request allowed right now.",
)
async def sse_generator(event_gen):
try:
async for event in event_gen:
yield f"data: {event.json()}\n\n"
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
finally:
semaphore.release()
async def event_gen():
async for event in InferenceApiInstance.chat_completion(exec_request):
yield event
return StreamingResponse(
sse_generator(event_gen()),
media_type="text/event-stream",
)
def main(config_path: str, port: int = 5000, disable_ipv6: bool = False):
global GLOBAL_CONFIG
config_dir = get_default_config_dir()
GLOBAL_CONFIG = parse_config(config_dir, config_path)
signal.signal(signal.SIGINT, handle_sigint)
import uvicorn
# FYI this does not do hot-reloads
listen_host = "::" if not disable_ipv6 else "0.0.0.0"
print(f"Listening on {listen_host}:{port}")
uvicorn.run(app, host=listen_host, port=port)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
@json_schema_type
class MemoryBank(BaseModel):
memory_bank_id: str
memory_bank_name: str
@json_schema_type
class MemoryBankDocument(BaseModel):
document_id: str
content: bytes
metadata: Dict[str, Any]
mime_type: str

View file

@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Protocol
from pyopenapi import webmethod
from .datatypes import * # noqa: F403
class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/create")
def post_create_memory_bank(
self,
bank_id: str,
bank_name: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_banks/list")
def get_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/get")
def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/drop")
def delete_memory_bank(
self,
bank_id: str,
) -> str: ...
@webmethod(route="/memory_bank/insert")
def post_insert_memory_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_bank/update")
def post_update_memory_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_bank/get")
def get_memory_documents(
self,
bank_id: str,
document_uuids: List[str],
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory_bank/delete")
def delete_memory_documents(
self,
bank_id: str,
document_uuids: List[str],
) -> List[str]: ...

View file

@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol
from pydantic import BaseModel # noqa: F401
from pyopenapi import webmethod # noqa: F401
class Models(Protocol): ...

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@ -0,0 +1,94 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
class OptimizerType(Enum):
adam = "adam"
adamw = "adamw"
sgd = "sgd"
@json_schema_type
class OptimizerConfig(BaseModel):
optimizer_type: OptimizerType
lr: float
lr_min: float
weight_decay: float
@json_schema_type
class TrainingConfig(BaseModel):
n_epochs: int
batch_size: int
shuffle: bool
n_iters: int
enable_activation_checkpointing: bool
memory_efficient_fsdp_wrap: bool
fsdp_cpu_offload: bool
@json_schema_type
class FinetuningAlgorithm(Enum):
full = "full"
lora = "lora"
qlora = "qlora"
dora = "dora"
@json_schema_type
class LoraFinetuningConfig(BaseModel):
lora_attn_modules: List[str]
apply_lora_to_mlp: bool
apply_lora_to_output: bool
rank: int
alpha: int
@json_schema_type
class QLoraFinetuningConfig(LoraFinetuningConfig):
pass
@json_schema_type
class DoraFinetuningConfig(LoraFinetuningConfig):
pass
@json_schema_type
class PostTrainingJobLogStream(BaseModel):
"""Stream of logs from a finetuning job."""
job_uuid: str
log_lines: List[str]
@json_schema_type
class PostTrainingJobStatus(Enum):
running = "running"
completed = "completed"
failed = "failed"
scheduled = "scheduled"
@json_schema_type
class RLHFAlgorithm(Enum):
dpo = "dpo"
@json_schema_type
class DPOAlignmentConfig(BaseModel):
reward_scale: float
reward_clip: float
epsilon: float
gamma: float

View file

@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import Any, Dict, List, Optional, Protocol
from pydantic import BaseModel, Field
from pyopenapi import webmethod
from strong_typing.schema import json_schema_type
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from llama_toolchain.dataset.api.datatypes import * # noqa: F403
from llama_toolchain.common.training_types import * # noqa: F403
from .datatypes import * # noqa: F403
@json_schema_type
class PostTrainingSFTRequest(BaseModel):
"""Request to finetune a model."""
job_uuid: str
model: PretrainedModel
dataset: TrainEvalDataset
validation_dataset: TrainEvalDataset
algorithm: FinetuningAlgorithm
algorithm_config: Union[
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
]
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
"""Request to finetune a model."""
job_uuid: str
finetuned_model: URL
dataset: TrainEvalDataset
validation_dataset: TrainEvalDataset
algorithm: RLHFAlgorithm
algorithm_config: Union[DPOAlignmentConfig]
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
class PostTrainingJob(BaseModel):
job_uuid: str
@json_schema_type
class PostTrainingJobStatusResponse(BaseModel):
"""Status of a finetuning job."""
job_uuid: str
status: PostTrainingJobStatus
scheduled_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
resources_allocated: Optional[Dict[str, Any]] = None
checkpoints: List[Checkpoint] = Field(default_factory=list)
@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
"""Artifacts of a finetuning job."""
job_uuid: str
checkpoints: List[Checkpoint] = Field(default_factory=list)
# TODO(ashwin): metrics, evals
class PostTraining(Protocol):
@webmethod(route="/post_training/supervised_fine_tune")
def post_supervised_fine_tune(
self,
request: PostTrainingSFTRequest,
) -> PostTrainingJob: ...
@webmethod(route="/post_training/preference_optimize")
def post_preference_optimize(
self,
request: PostTrainingRLHFRequest,
) -> PostTrainingJob: ...
@webmethod(route="/post_training/jobs")
def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs
@webmethod(route="/post_training/job/logs")
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
@webmethod(route="/post_training/job/status")
def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post_training/job/cancel")
def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post_training/job/artifacts")
def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ...

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
from llama_models.llama3_1.api.datatypes import * # noqa: F403
@json_schema_type
class ScoredMessage(BaseModel):
message: Message
score: float
@json_schema_type
class DialogGenerations(BaseModel):
dialog: List[Message]
sampled_generations: List[Message]
@json_schema_type
class ScoredDialogGenerations(BaseModel):
dialog: List[Message]
scored_generations: List[ScoredMessage]

View file

@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Protocol, Union
from .datatypes import * # noqa: F403
from pyopenapi import webmethod
@json_schema_type
class RewardScoringRequest(BaseModel):
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
dialog_generations: List[DialogGenerations]
model: RewardModel
@json_schema_type
class RewardScoringResponse(BaseModel):
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
scored_generations: List[ScoredDialogGenerations]
class RewardScoring(Protocol):
@webmethod(route="/reward_scoring/score")
def post_score(
self,
request: RewardScoringRequest,
) -> Union[RewardScoringResponse]: ...

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
from pydantic import BaseModel
class LlamaGuardShieldConfig(BaseModel):
model_dir: str
excluded_categories: List[str]
disable_input_check: bool = False
disable_output_check: bool = False
class PromptGuardShieldConfig(BaseModel):
model_dir: str
class SafetyConfig(BaseModel):
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
prompt_guard_shield: Optional[PromptGuardShieldConfig] = None

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Dict, Optional, Union
from llama_models.llama3_1.api.datatypes import ToolParamDefinition
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
@json_schema_type
class BuiltinShield(Enum):
llama_guard = "llama_guard"
code_scanner_guard = "code_scanner_guard"
third_party_shield = "third_party_shield"
injection_shield = "injection_shield"
jailbreak_shield = "jailbreak_shield"
ShieldType = Union[BuiltinShield, str]
@json_schema_type
class OnViolationAction(Enum):
IGNORE = 0
WARN = 1
RAISE = 2
@json_schema_type
class ShieldDefinition(BaseModel):
shield_type: ShieldType
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
on_violation_action: OnViolationAction = OnViolationAction.RAISE
execution_config: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class ShieldCall(BaseModel):
call_id: str
shield_type: ShieldType
arguments: Dict[str, str]
@json_schema_type
class ShieldResponse(BaseModel):
shield_type: ShieldType
# TODO(ashwin): clean this up
is_violation: bool
violation_type: Optional[str] = None
violation_return_message: Optional[str] = None

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# supress warnings and spew of logs from hugging face
import transformers
from .base import ( # noqa: F401
DummyShield,
OnViolationAction,
ShieldBase,
ShieldResponse,
TextShield,
)
from .code_scanner import CodeScannerShield # noqa: F401
from .contrib.third_party_shield import ThirdPartyShield # noqa: F401
from .llama_guard import LlamaGuardShield # noqa: F401
from .prompt_guard import ( # noqa: F401
InjectionShield,
JailbreakShield,
PromptGuardShield,
)
from .shield_runner import SafetyException, ShieldRunnerMixin # noqa: F401
transformers.logging.set_verbosity_error()
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings
warnings.filterwarnings("ignore")

View file

@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import List, Union
from llama_models.llama3_1.api.datatypes import Attachment, Message
from llama_toolchain.safety.api.datatypes import * # noqa: F403
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
class ShieldBase(ABC):
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
self.on_violation_action = on_violation_action
@abstractmethod
def get_shield_type(self) -> ShieldType:
raise NotImplementedError()
@abstractmethod
async def run(self, messages: List[Message]) -> ShieldResponse:
raise NotImplementedError()
def message_content_as_str(message: Message) -> str:
def _to_str(content: Union[str, Attachment]) -> str:
if isinstance(content, str):
return content
elif isinstance(content, Attachment):
return f"File: {str(content.url)}"
else:
raise
if isinstance(message.content, list) or isinstance(message.content, tuple):
return "\n".join([_to_str(c) for c in message.content])
else:
return _to_str(message.content)
# For shields that operate on simple strings
class TextShield(ShieldBase):
def convert_messages_to_text(self, messages: List[Message]) -> str:
return "\n".join([message_content_as_str(m) for m in messages])
async def run(self, messages: List[Message]) -> ShieldResponse:
text = self.convert_messages_to_text(messages)
return await self.run_impl(text)
@abstractmethod
async def run_impl(self, text: str) -> ShieldResponse:
raise NotImplementedError()
class DummyShield(TextShield):
def get_shield_type(self) -> ShieldType:
return "dummy"
async def run_impl(self, text: str) -> ShieldResponse:
# Dummy return LOW to test e2e
return ShieldResponse(
shield_type=BuiltinShield.third_party_shield, is_violation=False
)

View file

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from codeshield.cs import CodeShield
from termcolor import cprint
from .base import ShieldResponse, TextShield
from llama_toolchain.safety.api.datatypes import * # noqa: F403
class CodeScannerShield(TextShield):
def get_shield_type(self) -> ShieldType:
return BuiltinShield.code_scanner_guard
async def run_impl(self, text: str) -> ShieldResponse:
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
result = await CodeShield.scan_code(text)
if result.is_insecure:
return ShieldResponse(
shield_type=BuiltinShield.code_scanner_guard,
is_violation=True,
violation_type=",".join(
[issue.pattern_id for issue in result.issues_found]
),
violation_return_message="Sorry, I found security concerns in the code.",
)
else:
return ShieldResponse(
shield_type=BuiltinShield.code_scanner_guard, is_violation=False
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sys
from typing import List
from llama_models.llama3_1.api.datatypes import Message
parent_dir = "../.."
sys.path.append(parent_dir)
from llama_toolchain.safety.shields.base import (
OnViolationAction,
ShieldBase,
ShieldResponse,
)
_INSTANCE = None
class ThirdPartyShield(ShieldBase):
@staticmethod
def instance(on_violation_action=OnViolationAction.RAISE) -> "ThirdPartyShield":
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = ThirdPartyShield(on_violation_action)
return _INSTANCE
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
async def run(self, messages: List[Message]) -> ShieldResponse:
super.run() # will raise NotImplementedError

View file

@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from string import Template
from typing import List, Optional
import torch
from llama_models.llama3_1.api.datatypes import Message, Role
from transformers import AutoModelForCausalLM, AutoTokenizer
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_toolchain.safety.api.datatypes import * # noqa: F403
SAFE_RESPONSE = "safe"
_INSTANCE = None
CAT_VIOLENT_CRIMES = "Violent Crimes"
CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes"
CAT_SEX_CRIMES = "Sex Crimes"
CAT_CHILD_EXPLOITATION = "Child Exploitation"
CAT_DEFAMATION = "Defamation"
CAT_SPECIALIZED_ADVICE = "Specialized Advice"
CAT_PRIVACY = "Privacy"
CAT_INTELLECTUAL_PROPERTY = "Intellectual Property"
CAT_INDISCRIMINATE_WEAPONS = "Indiscriminate Weapons"
CAT_HATE = "Hate"
CAT_SELF_HARM = "Self-Harm"
CAT_SEXUAL_CONTENT = "Sexual Content"
CAT_ELECTIONS = "Elections"
CAT_CODE_INTERPRETER_ABUSE = "Code Interpreter Abuse"
SAFETY_CATEGORIES_TO_CODE_MAP = {
CAT_VIOLENT_CRIMES: "S1",
CAT_NON_VIOLENT_CRIMES: "S2",
CAT_SEX_CRIMES: "S3",
CAT_CHILD_EXPLOITATION: "S4",
CAT_DEFAMATION: "S5",
CAT_SPECIALIZED_ADVICE: "S6",
CAT_PRIVACY: "S7",
CAT_INTELLECTUAL_PROPERTY: "S8",
CAT_INDISCRIMINATE_WEAPONS: "S9",
CAT_HATE: "S10",
CAT_SELF_HARM: "S11",
CAT_SEXUAL_CONTENT: "S12",
CAT_ELECTIONS: "S13",
CAT_CODE_INTERPRETER_ABUSE: "S14",
}
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_VIOLENT_CRIMES,
CAT_NON_VIOLENT_CRIMES,
CAT_SEX_CRIMES,
CAT_CHILD_EXPLOITATION,
CAT_DEFAMATION,
CAT_SPECIALIZED_ADVICE,
CAT_PRIVACY,
CAT_INTELLECTUAL_PROPERTY,
CAT_INDISCRIMINATE_WEAPONS,
CAT_HATE,
CAT_SELF_HARM,
CAT_SEXUAL_CONTENT,
CAT_ELECTIONS,
CAT_CODE_INTERPRETER_ABUSE,
]
PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
SAFETY_CATEGORIES = """
<BEGIN UNSAFE CONTENT CATEGORIES>
$categories
<END UNSAFE CONTENT CATEGORIES>
"""
PROMPT_CONVERSATION = """
<BEGIN CONVERSATION>
$conversations
<END CONVERSATION>
"""
PROMPT_INSTRUCTIONS = """
Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories."""
PROMPT_TEMPLATE = Template(
f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}"
)
class LlamaGuardShield(ShieldBase):
@staticmethod
def instance(
on_violation_action=OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
) -> "LlamaGuardShield":
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = LlamaGuardShield(
on_violation_action,
model_dir,
excluded_categories,
disable_input_check,
disable_output_check,
)
return _INSTANCE
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
):
super().__init__(on_violation_action)
dtype = torch.bfloat16
assert model_dir is not None, "Llama Guard model_dir is None"
if excluded_categories is None:
excluded_categories = []
assert len(excluded_categories) == 0 or all(
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
self.device = "cuda"
self.excluded_categories = excluded_categories
self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check
# load model
torch_dtype = torch.bfloat16
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForCausalLM.from_pretrained(
model_dir, torch_dtype=torch_dtype, device_map=self.device
)
def get_shield_type(self) -> ShieldType:
return BuiltinShield.llama_guard
def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response)
if match:
# extracts the unsafe code
extracted = match.group(1)
return extracted
return None
def get_safety_categories(self) -> List[str]:
excluded_categories = self.excluded_categories
if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
excluded_categories = []
categories = []
for cat in DEFAULT_LG_V3_SAFETY_CATEGORIES:
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
if cat_code in excluded_categories:
continue
categories.append(f"{cat_code}: {cat}.")
return categories
def build_prompt(self, messages: List[Message]) -> str:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
[f"{m.role.capitalize()}: {m.content}" for m in messages]
)
return PROMPT_TEMPLATE.substitute(
agent_type=messages[-1].role.capitalize(),
categories=categories_str,
conversations=conversations_str,
)
def get_shield_response(self, response: str) -> ShieldResponse:
if response == SAFE_RESPONSE:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
unsafe_code = self.check_unsafe_response(response)
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
return ShieldResponse(
shield_type=BuiltinShield.llama_guard,
is_violation=True,
violation_type=unsafe_code,
violation_return_message=CANNED_RESPONSE_TEXT,
)
raise ValueError(f"Unexpected response: {response}")
async def run(self, messages: List[Message]) -> ShieldResponse:
if self.disable_input_check and messages[-1].role == Role.user.value:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard,
is_violation=False,
)
else:
prompt = self.build_prompt(messages)
llama_guard_input = {
"role": "user",
"content": prompt,
}
input_ids = self.tokenizer.apply_chat_template(
[llama_guard_input], return_tensors="pt", tokenize=True
).to(self.device)
prompt_len = input_ids.shape[1]
output = self.model.generate(
input_ids=input_ids,
max_new_tokens=20,
output_scores=True,
return_dict_in_generate=True,
pad_token_id=0,
)
generated_tokens = output.sequences[:, prompt_len:]
response = self.tokenizer.decode(
generated_tokens[0], skip_special_tokens=True
)
response = response.strip()
shield_response = self.get_shield_response(response)
return shield_response

View file

@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import auto, Enum
from typing import List
import torch
from llama_models.llama3_1.api.datatypes import Message
from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
from llama_toolchain.safety.api.datatypes import * # noqa: F403
class PromptGuardShield(TextShield):
class Mode(Enum):
INJECTION = auto()
JAILBREAK = auto()
_instances = {}
_model_cache = None
@staticmethod
def instance(
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
on_violation_action=OnViolationAction.RAISE,
) -> "PromptGuardShield":
action_value = on_violation_action.value
key = (model_dir, threshold, temperature, mode, action_value)
if key not in PromptGuardShield._instances:
PromptGuardShield._instances[key] = PromptGuardShield(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=mode,
on_violation_action=on_violation_action,
)
return PromptGuardShield._instances[key]
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
assert (
model_dir is not None
), "Must provide a model directory for prompt injection shield"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.device = "cuda"
if PromptGuardShield._model_cache is None:
# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(
model_dir, device_map=self.device
)
PromptGuardShield._model_cache = (tokenizer, model)
self.tokenizer, self.model = PromptGuardShield._model_cache
self.temperature = temperature
self.threshold = threshold
self.mode = mode
def get_shield_type(self) -> ShieldType:
return (
BuiltinShield.jailbreak_shield
if self.mode == self.Mode.JAILBREAK
else BuiltinShield.injection_shield
)
def convert_messages_to_text(self, messages: List[Message]) -> str:
return message_content_as_str(messages[-1])
async def run_impl(self, text: str) -> ShieldResponse:
# run model on messages and return response
inputs = self.tokenizer(text, return_tensors="pt")
inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs[0]
probabilities = torch.softmax(logits / self.temperature, dim=-1)
score_embedded = probabilities[0, 1].item()
score_malicious = probabilities[0, 2].item()
cprint(
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
color="magenta",
)
if self.mode == self.Mode.INJECTION and (
score_embedded + score_malicious > self.threshold
):
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=True,
violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=True,
violation_type=f"prompt_injection:malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=False,
)
class JailbreakShield(PromptGuardShield):
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=PromptGuardShield.Mode.JAILBREAK,
on_violation_action=on_violation_action,
)
class InjectionShield(PromptGuardShield):
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=PromptGuardShield.Mode.INJECTION,
on_violation_action=on_violation_action,
)

View file

@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import List
from llama_models.llama3_1.api.datatypes import Message, Role
from .base import OnViolationAction, ShieldBase, ShieldResponse
class SafetyException(Exception): # noqa: N818
def __init__(self, response: ShieldResponse):
self.response = response
super().__init__(response.violation_return_message)
class ShieldRunnerMixin:
def __init__(
self,
input_shields: List[ShieldBase] = None,
output_shields: List[ShieldBase] = None,
):
self.input_shields = input_shields
self.output_shields = output_shields
async def run_shields(
self, messages: List[Message], shields: List[ShieldBase]
) -> List[ShieldResponse]:
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
# TODO(ashwin): we need to change the type of the message, this kind of modification
# is no longer appropriate
messages[0].role = Role.user.value
results = await asyncio.gather(*[s.run(messages) for s in shields])
for shield, r in zip(shields, results):
if r.is_violation:
if shield.on_violation_action == OnViolationAction.RAISE:
raise SafetyException(r)
elif shield.on_violation_action == OnViolationAction.WARN:
cprint(
f"[Warn]{shield.__class__.__name__} raised a warning",
color="red",
)
return results

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
class FilteringFunction(Enum):
"""The type of filtering function."""
none = "none"
random = "random"
top_k = "top_k"
top_p = "top_p"
top_k_top_p = "top_k_top_p"
sigmoid = "sigmoid"

View file

@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol
from pydantic import BaseModel
from pyopenapi import webmethod
from strong_typing.schema import json_schema_type
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from llama_toolchain.reward_scoring.api.datatypes import * # noqa: F403
from .datatypes import * # noqa: F403
@json_schema_type
class SyntheticDataGenerationRequest(BaseModel):
"""Request to generate synthetic data. A small batch of prompts and a filtering function"""
dialogs: List[Message]
filtering_function: FilteringFunction = FilteringFunction.none
model: Optional[RewardModel] = None
@json_schema_type
class SyntheticDataGenerationResponse(BaseModel):
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
synthetic_data: List[ScoredDialogGenerations]
statistics: Optional[Dict[str, Any]] = None
class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic_data_generation/generate")
def post_generate(
self,
request: SyntheticDataGenerationRequest,
) -> Union[SyntheticDataGenerationResponse]: ...

64
llama_toolchain/utils.py Normal file
View file

@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import getpass
import os
from typing import Optional
from hydra import compose, initialize, MissingConfigException
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf
DEFAULT_DUMP_DIR = os.path.expanduser("~/.llama/")
def get_root_directory():
current_dir = os.path.dirname(os.path.abspath(__file__))
while os.path.isfile(os.path.join(current_dir, "__init__.py")):
current_dir = os.path.dirname(current_dir)
return current_dir
def get_default_config_dir():
return os.path.join(DEFAULT_DUMP_DIR, "configs")
def parse_config(config_dir: str, config_path: Optional[str] = None) -> str:
# Configs can be
# 1. relative paths in {config_dir}/
# 2. or default to file {config_dir}/{user}.yaml
# 3. or ultimate default to {config_dir}/default.yaml
# Get the relative path from the current file to the config directory
current_file_directory = os.path.dirname(os.path.abspath(__file__))
relative_path = os.path.relpath(config_dir, current_file_directory)
GlobalHydra.instance().clear()
initialize(config_path=relative_path)
if config_path is None:
try:
user = getpass.getuser()
config_name = user
except MissingConfigException:
print(f"No user-specific {user}.yaml, using default")
config_name = "default"
else:
config_name = config_path
config_abs_path = os.path.abspath(os.path.join(config_dir, f"{config_name}.yaml"))
print(f"Loading config from : {config_abs_path}")
config = compose(config_name=config_name)
print("Yaml config:")
print("------------------------")
print(OmegaConf.to_yaml(config, resolve=True))
print("------------------------")
return config

3
pyproject.toml Normal file
View file

@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

32
requirements.txt Normal file
View file

@ -0,0 +1,32 @@
accelerate
black==24.4.2
blobfile
codeshield
fairscale
fastapi
fire
flake8
huggingface-hub
httpx
hydra-core
hydra-zen
json-strong-typing
matplotlib
omegaconf
pandas
Pillow
pre-commit
pydantic==1.10.13
pydantic_core==2.18.2
python-dotenv
python-openapi
requests
tiktoken
torch
transformers
ufmt==2.7.0
usort==1.0.8
uvicorn
zmq
llama_models[llama3_1] @ git+ssh://git@github.com/meta-llama/llama-models.git

32
setup.py Normal file
View file

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from setuptools import find_packages, setup
# Function to read the requirements.txt file
def read_requirements():
with open("requirements.txt") as req:
content = req.readlines()
return [line.strip() for line in content]
setup(
name="llama_toolchain",
version="0.0.0.1",
author="Meta Llama",
author_email="llama-oss@meta.com",
description="Llama toolchain",
entry_points={"console_scripts": ["llama = llama_toolchain.cli.llama:main"]},
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
url="https://github.com/meta-llama/llama-toolchain",
packages=find_packages(),
classifiers=[],
python_requires=">=3.10",
install_requires=read_requirements(),
include_package_data=True,
)