Initial commit

Ashwin Bharambe 2024-06-25 15:47:57 -07:00 committed by Ashwin Bharambe
commit 5d5acc8ed5
81 changed files with 4458 additions and 0 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import textwrap
from pathlib import Path
from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.utils import DEFAULT_DUMP_DIR
DEFAULT_CHECKPOINT_DIR = os.path.join(DEFAULT_DUMP_DIR, "checkpoints")
class Download(Subcommand):
"""Llama cli for downloading llama toolchain assets"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"download",
prog="llama download",
description="Download a model from the Hugging Face Hub",
epilog=textwrap.dedent(
"""\
# Here are some examples on how to use this command:
llama download --repo-id meta-llama/Llama-2-7b-hf --hf-token <HF_TOKEN>
llama download --repo-id meta-llama/Llama-2-7b-hf --output-dir /data/my_custom_dir --hf-token <HF_TOKEN>
HF_TOKEN=<HF_TOKEN> llama download --repo-id meta-llama/Llama-2-7b-hf
The output directory will be used to load models and tokenizers for inference.
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_download_cmd)
def _add_arguments(self):
self.parser.add_argument(
"repo_id",
type=str,
help="Name of the repository on Hugging Face Hub eg. llhf/Meta-Llama-3.1-70B-Instruct",
)
self.parser.add_argument(
"--hf-token",
type=str,
required=False,
default=os.getenv("HF_TOKEN", None),
help="Hugging Face API token. Needed for gated models like Llama2. Will also try to read environment variable `HF_TOKEN` as default.",
)
self.parser.add_argument(
"--ignore-patterns",
type=str,
required=False,
default="*.safetensors",
help="If provided, files matching any of the patterns are not downloaded. Defaults to ignoring "
"safetensors files to avoid downloading duplicate weights.",
)
def _run_download_cmd(self, args: argparse.Namespace):
model_name = args.repo_id.split("/")[-1]
output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model_name
os.makedirs(output_dir, exist_ok=True)
try:
true_output_dir = snapshot_download(
args.repo_id,
local_dir=output_dir,
# "auto" will download to cache_dir and symlink files to local_dir
# avoiding unnecessary duplicate copies
local_dir_use_symlinks="auto",
ignore_patterns=args.ignore_patterns,
token=args.hf_token,
)
except GatedRepoError:
self.parser.error(
"It looks like you are trying to access a gated repository. Please ensure you "
"have access to the repository and have provided the proper Hugging Face API token "
"using the option `--hf-token` or by running `huggingface-cli login`."
"You can find your token by visiting https://huggingface.co/settings/tokens"
)
except RepositoryNotFoundError:
self.parser.error(
f"Repository '{args.repo_id}' not found on the Hugging Face Hub."
)
except Exception as e:
self.parser.error(e)
print(f"Successfully downloaded model to {true_output_dir}")

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import textwrap
from pathlib import Path
import pkg_resources
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.utils import DEFAULT_DUMP_DIR
CONFIGS_BASE_DIR = os.path.join(DEFAULT_DUMP_DIR, "configs")
class InferenceConfigure(Subcommand):
"""Llama cli for configuring llama toolchain configs"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"configure",
prog="llama inference configure",
description="Configure llama toolchain inference configs",
epilog=textwrap.dedent(
"""
Example:
llama inference configure
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_inference_configure_cmd)
def _add_arguments(self):
pass
def read_user_inputs(self):
checkpoint_dir = input(
"Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
)
model_parallel_size = input(
"Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
)
assert model_parallel_size.isdigit() and int(model_parallel_size) in {
1,
8,
}, "model parallel size must be 1 or 8"
return checkpoint_dir, model_parallel_size
def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
default_conf_path = pkg_resources.resource_filename(
"llama_toolchain", "data/default_inference_config.yaml"
)
with open(default_conf_path, "r") as f:
yaml_content = f.read()
yaml_content = yaml_content.format(
checkpoint_dir=checkpoint_dir,
model_parallel_size=model_parallel_size,
)
with open(yaml_output_path, "w") as yaml_file:
yaml_file.write(yaml_content.strip())
print(f"YAML configuration has been written to {yaml_output_path}")
def _run_inference_configure_cmd(self, args: argparse.Namespace) -> None:
checkpoint_dir, model_parallel_size = self.read_user_inputs()
checkpoint_dir = os.path.expanduser(checkpoint_dir)
assert (
Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
), f"{checkpoint_dir} does not exist or it not a directory"
os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml"
self.write_output_yaml(
checkpoint_dir,
model_parallel_size,
yaml_output_path,
)

View file

@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import textwrap
from llama_toolchain.cli.inference.configure import InferenceConfigure
from llama_toolchain.cli.inference.start import InferenceStart
from llama_toolchain.cli.subcommand import Subcommand
class InferenceParser(Subcommand):
"""Llama cli for inference apis"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"inference",
prog="llama inference",
description="Run inference on a llama model",
epilog=textwrap.dedent(
"""
Example:
llama inference start <options>
"""
),
)
subparsers = self.parser.add_subparsers(title="inference_subcommands")
# Add sub-commands
InferenceStart.create(subparsers)
InferenceConfigure.create(subparsers)

View file

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import textwrap
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.inference.server import main as inference_server_init
class InferenceStart(Subcommand):
"""Llama Inference cli for starting inference server"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"start",
prog="llama inference start",
description="Start an inference server",
epilog=textwrap.dedent(
"""
Example:
llama inference start <options>
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_inference_start_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--port",
type=int,
help="Port to run the server on. Defaults to 5000",
default=5000,
)
self.parser.add_argument(
"--disable-ipv6",
action="store_true",
help="Disable IPv6 support",
default=False,
)
self.parser.add_argument(
"--config", type=str, help="Path to config file", default="inference"
)
def _run_inference_start_cmd(self, args: argparse.Namespace) -> None:
inference_server_init(
config_path=args.config,
port=args.port,
disable_ipv6=args.disable_ipv6,
)

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_toolchain.cli.download import Download
from llama_toolchain.cli.inference.inference import InferenceParser
from llama_toolchain.cli.model.model import ModelParser
class LlamaCLIParser:
"""Defines CLI parser for Llama CLI"""
def __init__(self):
self.parser = argparse.ArgumentParser(
prog="llama",
description="Welcome to the LLama cli",
add_help=True,
)
# Default command is to print help
self.parser.set_defaults(func=lambda args: self.parser.print_help())
subparsers = self.parser.add_subparsers(title="subcommands")
# Add sub-commands
Download.create(subparsers)
InferenceParser.create(subparsers)
ModelParser.create(subparsers)
# Import sub-commands from agentic_system if they exist
try:
from llama_agentic_system.cli.subcommand_modules import SUBCOMMAND_MODULES
for module in SUBCOMMAND_MODULES:
module.create(subparsers)
except ImportError:
pass
def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args()
def run(self, args: argparse.Namespace) -> None:
args.func(args)
def main():
parser = LlamaCLIParser()
args = parser.parse_args()
parser.run(args)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import textwrap
from llama_toolchain.cli.model.template import ModelTemplate
from llama_toolchain.cli.subcommand import Subcommand
class ModelParser(Subcommand):
"""Llama cli for model interface apis"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"model",
prog="llama model",
description="Describe llama model interfaces",
epilog=textwrap.dedent(
"""
Example:
llama model <subcommand> <options>
"""
),
)
subparsers = self.parser.add_subparsers(title="model_subcommands")
# Add sub-commands
# ModelDescribe.create(subparsers)
ModelTemplate.create(subparsers)

View file

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import textwrap
from llama_models.llama3_1.api.interface import (
list_jinja_templates,
render_jinja_template,
)
from llama_toolchain.cli.subcommand import Subcommand
class ModelTemplate(Subcommand):
"""Llama model cli for describe a model template (message formats)"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"template",
prog="llama model template",
description="Show llama model message formats",
epilog=textwrap.dedent(
"""
Example:
llama model template <options>
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_model_template_cmd)
def _add_arguments(self):
self.parser.add_argument(
"-m",
"--model-family",
type=str,
default="llama3_1",
help="Model Family (llama3_1, llama3_X, etc.)",
)
self.parser.add_argument(
"--template",
type=str,
help="Usecase template name (system_message, user_message, assistant_message, tool_message)...",
required=False,
)
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
if args.template:
render_jinja_template(args.template)
else:
list_jinja_templates()

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
class Subcommand:
"""All llama cli subcommands must inherit from this class"""
def __init__(self, *args, **kwargs):
pass
@classmethod
def create(cls, *args, **kwargs):
return cls(*args, **kwargs)
def _add_arguments(self):
pass

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Dict, Optional
from llama_models.llama3_1.api.datatypes import URL
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
@json_schema_type
class RestAPIMethod(Enum):
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
@json_schema_type
class RestAPIExecutionConfig(BaseModel):
url: URL
method: RestAPIMethod
params: Optional[Dict[str, str]] = None
headers: Optional[Dict[str, str]] = None
body: Optional[Dict[str, str]] = None

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.llama3_1.api.datatypes import URL
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
@json_schema_type(schema={"description": "Checkpoint created during training runs"})
class Checkpoint(BaseModel):
iters: int
path: URL
epoch: int

View file

@ -0,0 +1,14 @@
inference_config:
impl_config:
impl_type: "inline"
checkpoint_config:
checkpoint:
checkpoint_type: "pytorch"
checkpoint_dir: {checkpoint_dir}/
tokenizer_path: {checkpoint_dir}/tokenizer.model
model_parallel_size: {model_parallel_size}
quantization_format: bf16
quantization: null
torch_seed: null
max_seq_len: 16384
max_batch_size: 1

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, Optional
from llama_models.llama3_1.api.datatypes import URL
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
@json_schema_type
class TrainEvalDatasetColumnType(Enum):
dialog = "dialog"
text = "text"
media = "media"
number = "number"
json = "json"
@json_schema_type
class TrainEvalDataset(BaseModel):
"""Dataset to be used for training or evaluating language models."""
# TODO(ashwin): figure out if we need to add an enum for a "dataset type"
columns: Dict[str, TrainEvalDatasetColumnType]
content_url: URL
metadata: Optional[Dict[str, Any]] = None

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol
from pydantic import BaseModel
from pyopenapi import webmethod
from strong_typing.schema import json_schema_type
from .datatypes import * # noqa: F403
@json_schema_type
class CreateDatasetRequest(BaseModel):
"""Request to create a dataset."""
uuid: str
dataset: TrainEvalDataset
class Datasets(Protocol):
@webmethod(route="/datasets/create")
def create_dataset(
self,
request: CreateDatasetRequest,
) -> None: ...
@webmethod(route="/datasets/get")
def get_dataset(
self,
dataset_uuid: str,
) -> TrainEvalDataset: ...
@webmethod(route="/datasets/delete")
def delete_dataset(
self,
dataset_uuid: str,
) -> None: ...

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from pydantic import BaseModel
class TextGenerationMetric(Enum):
perplexity = "perplexity"
rouge = "rouge"
bleu = "bleu"
class QuestionAnsweringMetric(Enum):
em = "em"
f1 = "f1"
class SummarizationMetric(Enum):
rouge = "rouge"
bleu = "bleu"
class EvaluationJob(BaseModel):
job_uuid: str
class EvaluationJobLogStream(BaseModel):
job_uuid: str

View file

@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Protocol
from pydantic import BaseModel
from pyopenapi import webmethod
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from .datatypes import * # noqa: F403
from llama_toolchain.dataset.api.datatypes import * # noqa: F403
from llama_toolchain.common.training_types import * # noqa: F403
class EvaluateTaskRequestCommon(BaseModel):
job_uuid: str
dataset: TrainEvalDataset
checkpoint: Checkpoint
# generation params
sampling_params: SamplingParams = SamplingParams()
@json_schema_type
class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon):
"""Request to evaluate text generation."""
metrics: List[TextGenerationMetric]
@json_schema_type
class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon):
"""Request to evaluate question answering."""
metrics: List[QuestionAnsweringMetric]
@json_schema_type
class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
"""Request to evaluate summarization."""
metrics: List[SummarizationMetric]
class EvaluationJobStatusResponse(BaseModel):
job_uuid: str
@json_schema_type
class EvaluationJobArtifactsResponse(BaseModel):
"""Artifacts of a evaluation job."""
job_uuid: str
class Evaluations(Protocol):
@webmethod(route="/evaluate/text_generation/")
def post_evaluate_text_generation(
self,
request: EvaluateTextGenerationRequest,
) -> EvaluationJob: ...
@webmethod(route="/evaluate/question_answering/")
def post_evaluate_question_answering(
self,
request: EvaluateQuestionAnsweringRequest,
) -> EvaluationJob: ...
@webmethod(route="/evaluate/summarization/")
def post_evaluate_summarization(
self,
request: EvaluateSummarizationRequest,
) -> EvaluationJob: ...
@webmethod(route="/evaluate/jobs")
def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
@webmethod(route="/evaluate/job/status")
def get_evaluation_job_status(
self, job_uuid: str
) -> EvaluationJobStatusResponse: ...
# sends SSE stream of logs
@webmethod(route="/evaluate/job/logs")
def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
@webmethod(route="/evaluate/job/cancel")
def cancel_evaluation_job(self, job_uuid: str) -> None: ...
@webmethod(route="/evaluate/job/artifacts")
def get_evaluation_job_artifacts(
self, job_uuid: str
) -> EvaluationJobArtifactsResponse: ...

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@ -0,0 +1,94 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Literal, Optional, Union
from hydra.core.config_store import ConfigStore
from hydra_zen import builds
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type
from typing_extensions import Annotated
from .datatypes import QuantizationConfig
@json_schema_type
class ImplType(Enum):
inline = "inline"
remote = "remote"
@json_schema_type
class CheckpointType(Enum):
pytorch = "pytorch"
huggingface = "huggingface"
@json_schema_type
class PytorchCheckpoint(BaseModel):
checkpoint_type: Literal[CheckpointType.pytorch.value] = (
CheckpointType.pytorch.value
)
checkpoint_dir: str
tokenizer_path: str
model_parallel_size: int
quantization_format: CheckpointQuantizationFormat = (
CheckpointQuantizationFormat.bf16
)
@json_schema_type
class HuggingFaceCheckpoint(BaseModel):
checkpoint_type: Literal[CheckpointType.huggingface.value] = (
CheckpointType.huggingface.value
)
repo_id: str # or model_name ?
model_parallel_size: int
quantization_format: CheckpointQuantizationFormat = (
CheckpointQuantizationFormat.bf16
)
@json_schema_type
class ModelCheckpointConfig(BaseModel):
checkpoint: Annotated[
Union[PytorchCheckpoint, HuggingFaceCheckpoint],
Field(discriminator="checkpoint_type"),
]
@json_schema_type
class InlineImplConfig(BaseModel):
impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
checkpoint_config: ModelCheckpointConfig
quantization: Optional[QuantizationConfig] = None
torch_seed: Optional[int] = None
max_seq_len: int
max_batch_size: int = 1
@json_schema_type
class RemoteImplConfig(BaseModel):
impl_type: Literal[ImplType.remote.value] = ImplType.remote.value
url: str = Field(..., description="The URL of the remote module")
@json_schema_type
class InferenceConfig(BaseModel):
impl_config: Annotated[
Union[InlineImplConfig, RemoteImplConfig],
Field(discriminator="impl_type"),
]
InferenceHydraConfig = builds(InferenceConfig)
cs = ConfigStore.instance()
cs.store(name="inference_config", node=InferenceHydraConfig)
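
For illustration, a minimal sketch of how the config types above compose into a full InferenceConfig (placeholder checkpoint paths; the import path matches the one used by the quantization loader later in this commit):

from llama_toolchain.inference.api.config import (
    InferenceConfig,
    InlineImplConfig,
    ModelCheckpointConfig,
    PytorchCheckpoint,
)

config = InferenceConfig(
    impl_config=InlineImplConfig(
        checkpoint_config=ModelCheckpointConfig(
            checkpoint=PytorchCheckpoint(
                checkpoint_dir="~/.llama/checkpoints/Meta-Llama-3-8B/",
                tokenizer_path="~/.llama/checkpoints/Meta-Llama-3-8B/tokenizer.model",
                model_parallel_size=1,
            )
        ),
        max_seq_len=16384,
    )
)
# impl_type and checkpoint_type fall back to their Literal defaults, so the
# discriminated unions resolve to the inline / pytorch variants.
assert config.impl_config.impl_type == "inline"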

View file

@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type
from typing_extensions import Annotated
from llama_models.llama3_1.api.datatypes import * # noqa: F403
class LogProbConfig(BaseModel):
top_k: Optional[int] = 0
@json_schema_type
class QuantizationType(Enum):
bf16 = "bf16"
fp8 = "fp8"
@json_schema_type
class Fp8QuantizationConfig(BaseModel):
type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
@json_schema_type
class Bf16QuantizationConfig(BaseModel):
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
QuantizationConfig = Annotated[
Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
Field(discriminator="type"),
]
@json_schema_type
class ChatCompletionResponseEventType(Enum):
start = "start"
complete = "complete"
progress = "progress"
@json_schema_type
class ToolCallParseStatus(Enum):
started = "started"
in_progress = "in_progress"
failure = "failure"
success = "success"
@json_schema_type
class ToolCallDelta(BaseModel):
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
"""Chat completion response event."""
event_type: ChatCompletionResponseEventType
delta: Union[str, ToolCallDelta]
logprobs: Optional[List[TokenLogProbs]] = None
stop_reason: Optional[StopReason] = None

View file

@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F403
from typing import Optional, Protocol
# this dependency is annoying and we need a forked up version anyway
from pyopenapi import webmethod
@json_schema_type
class CompletionRequest(BaseModel):
model: PretrainedModel
content: InterleavedTextAttachment
sampling_params: Optional[SamplingParams] = SamplingParams()
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
quantization_config: Optional[QuantizationConfig] = None
@json_schema_type
class CompletionResponse(BaseModel):
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]] = None
@json_schema_type
class CompletionResponseStreamChunk(BaseModel):
"""streamed completion response."""
delta: str
stop_reason: Optional[StopReason] = None
logprobs: Optional[List[TokenLogProbs]] = None
@json_schema_type
class BatchCompletionRequest(BaseModel):
model: PretrainedModel
content_batch: List[InterleavedTextAttachment]
sampling_params: Optional[SamplingParams] = SamplingParams()
logprobs: Optional[LogProbConfig] = None
quantization_config: Optional[QuantizationConfig] = None
@json_schema_type
class BatchCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
@json_schema_type
class ChatCompletionRequest(BaseModel):
model: InstructModel
messages: List[Message]
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot tool definitions as input to the model
available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
quantization_config: Optional[QuantizationConfig] = None
@json_schema_type
class ChatCompletionResponseStreamChunk(BaseModel):
"""SSE-stream of these events."""
event: ChatCompletionResponseEvent
@json_schema_type
class ChatCompletionResponse(BaseModel):
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]] = None
@json_schema_type
class BatchChatCompletionRequest(BaseModel):
model: InstructModel
messages_batch: List[List[Message]]
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot tool definitions as input to the model
available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
logprobs: Optional[LogProbConfig] = None
quantization_config: Optional[QuantizationConfig] = None
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
class Inference(Protocol):
@webmethod(route="/inference/completion")
async def completion(
self,
request: CompletionRequest,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
@webmethod(route="/inference/chat_completion")
async def chat_completion(
self,
request: ChatCompletionRequest,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
@webmethod(route="/inference/batch_completion")
async def batch_completion(
self,
request: BatchCompletionRequest,
) -> BatchCompletionResponse: ...
@webmethod(route="/inference/batch_chat_completion")
async def batch_chat_completion(
self,
request: BatchChatCompletionRequest,
) -> BatchChatCompletionResponse: ...

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .api.config import ImplType, InferenceConfig
async def get_inference_api_instance(config: InferenceConfig):
if config.impl_config.impl_type == ImplType.inline.value:
from .inference import InferenceImpl
return InferenceImpl(config.impl_config)
from .client import InferenceClient
return InferenceClient(config.impl_config.url)
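
As a rough sketch of the intended flow (assuming PyYAML is available, that DEFAULT_DUMP_DIR is ~/.llama, and with the factory function above in scope; the config path mirrors the file written by `llama inference configure`):

import asyncio
import os

import yaml

from llama_toolchain.inference.api.config import InferenceConfig

with open(os.path.expanduser("~/.llama/configs/inference.yaml")) as f:
    data = yaml.safe_load(f)

config = InferenceConfig(**data["inference_config"])

async def load():
    # returns InferenceImpl for inline configs, InferenceClient for remote ones
    impl = await get_inference_api_instance(config)
    await impl.initialize()
    return impl

impl = asyncio.run(load())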

View file

@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
from typing import AsyncGenerator
import fire
import httpx
from termcolor import cprint
from .api import (
ChatCompletionRequest,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
InstructModel,
UserMessage,
)
from .event_logger import EventLogger
class InferenceClient(Inference):
def __init__(self, base_url: str):
print(f"Initializing client for {base_url}")
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
f"{self.base_url}/inference/chat_completion",
data=request.json(),
headers={"Content-Type": "application/json"},
timeout=20,
) as response:
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line[len("data: ") :]
try:
yield ChatCompletionResponseStreamChunk(**json.loads(data))
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
async def run_main(host: str, port: int):
client = InferenceClient(f"http://{host}:{port}")
message = UserMessage(content="hello world, help me out here")
cprint(f"User>{message.content}", "green")
req = ChatCompletionRequest(
model=InstructModel.llama3_70b_chat,
messages=[message],
stream=True,
)
iterator = client.chat_completion(
ChatCompletionRequest(
model=InstructModel.llama3_8b_chat,
messages=[message],
stream=True,
)
)
async for log in EventLogger().log(iterator):
log.print()
def main(host: str, port: int):
asyncio.run(run_main(host, port))
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from termcolor import cprint
from llama_toolchain.inference.api import ChatCompletionResponseEventType
class LogEvent:
def __init__(
self,
content: str = "",
end: str = "\n",
color="white",
):
self.content = content
self.color = color
self.end = "\n" if end is None else end
def print(self, flush=True):
cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
class EventLogger:
async def log(self, event_generator, stream=True):
async for chunk in event_generator:
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
yield LogEvent("Assistant> ", color="cyan", end="")
elif event.event_type == ChatCompletionResponseEventType.progress:
yield LogEvent(event.delta, color="yellow", end="")
elif event.event_type == ChatCompletionResponseEventType.complete:
yield LogEvent("")

View file

@ -0,0 +1,319 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from llama_models.llama3_1.api.args import ModelArgs
from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.model import Transformer
from llama_models.llama3_1.api.tokenizer import Tokenizer
from termcolor import cprint
from .api.config import CheckpointType, InlineImplConfig
from .api.datatypes import QuantizationType
@dataclass
class TokenResult:
token: int
text: str
logprobs: Optional[List[float]] = None
class Llama:
@staticmethod
def build(config: InlineImplConfig):
"""
Build a Llama instance by initializing and loading a model checkpoint.
Note:
This method initializes the distributed process group, sets the device to CUDA,
and loads the pre-trained model and tokenizer.
"""
checkpoint = config.checkpoint_config.checkpoint
if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
raise NotImplementedError("HuggingFace checkpoints not supported yet")
if (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
):
from .quantization.loader import is_fbgemm_available
if not is_fbgemm_available():
raise ImportError("fbgemm-gpu is required for FP8 quantization")
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
model_parallel_size = checkpoint.model_parallel_size
if not model_parallel_is_initialized():
initialize_model_parallel(model_parallel_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
if config.torch_seed is not None:
torch.manual_seed(config.torch_seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
ckpt_dir = checkpoint.checkpoint_dir
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
# TODO(ashwin): this block is so we can load internal checkpoints without additional
# fuss. the final code should _not_ have this blurb
if "model" in params:
params = params["model"]
model_args: ModelArgs = ModelArgs(
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
**params,
)
tokenizer = Tokenizer(model_path=checkpoint.tokenizer_path)
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
fp8 = (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
)
if fp8:
# load on CPU in bf16 so that fp8 conversion does not find an
# unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
if fp8:
# load on CPU first since if we are doing fp8, we probably don't
# have enough memory on GPU for bf16
model.load_state_dict(state_dict, strict=False)
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
if not fp8:
model.load_state_dict(state_dict, strict=False)
if config.quantization:
from .quantization.loader import convert_to_quantized_model
model = convert_to_quantized_model(model, config)
else:
model = model.to("cuda")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
@torch.inference_mode()
def generate(
self,
model_input: ModelInput,
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
include_stop_token: bool = False,
) -> Generator:
params = self.model.params
# cprint("Input to model -> " + self.tokenizer.decode(model_input.tokens), "red")
prompt_tokens = [model_input.tokens]
bsz = 1
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
)
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id
if min_prompt_len == total_len:
# TODO(ashwin): unify this branch with the one below and figure out multimodal crap
logits = self.model.forward(tokens, prev_pos)
token_logprobs = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens,
reduction="none",
ignore_index=pad_id,
)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
for cur_pos in range(min_prompt_len, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens[:, prev_pos + 1 : cur_pos + 1],
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens)
)
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(
token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
if logprobs
else None
),
)
prev_pos = cur_pos
if all(eos_reached):
break
def text_completion(
self,
prompt: str,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator:
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
max_gen_len = self.model.params.max_seq_len - 1
prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
yield from self.generate(
model_input=ModelInput(tokens=prompt_tokens),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
echo=echo,
)
def chat_completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
) -> Generator:
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
max_gen_len = self.model.params.max_seq_len - 1
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(messages),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
include_stop_token=True,
)
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
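
A quick worked example of the nucleus-sampling cutoff above: with probabilities [0.5, 0.3, 0.15, 0.05] and p = 0.7, the cumulative sums are [0.5, 0.8, 0.95, 1.0], the mask (cumsum - prob > p) zeroes out the last two entries, and sampling happens over {0: 0.625, 1: 0.375} after renormalization. A small sanity check with the function above in scope:

import torch

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
token = sample_top_p(probs, p=0.7)
# only the two highest-probability indices can be drawn
assert token.item() in (0, 1)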

View file

@ -0,0 +1,159 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from llama_models.llama3_1.api.datatypes import StopReason
from .api.config import InlineImplConfig
from .api.datatypes import (
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ToolCallDelta,
ToolCallParseStatus,
)
from .api.endpoints import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
)
from .model_parallel import LlamaModelParallelGenerator
class InferenceImpl(Inference):
def __init__(self, config: InlineImplConfig) -> None:
self.config = config
async def initialize(self) -> None:
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
async def shutdown(self) -> None:
self.generator.stop()
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
tokens = []
logprobs = []
stop_reason = None
buffer = ""
ipython = False
for token_result in self.generator.chat_completion(
messages=request.messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
):
buffer += token_result.text
tokens.append(token_result.token)
if not ipython and buffer.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer = buffer[len("<|python_tag|>") :]
continue
if not request.stream:
if request.logprobs:
logprobs.append(token_result.logprobs)
continue
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
# TODO(ashwin): parse tool calls separately here and report errors?
# if someone breaks the iteration before coming here we are toast
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
if request.stream:
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
# TODO(ashwin): what else do we need to send out here when everything finishes?
else:
yield ChatCompletionResponse(
completion_message=message,
logprobs=logprobs if request.logprobs else None,
)

View file

@ -0,0 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from typing import Generator, List, Optional
from llama_models.llama3_1.api.chat_format import ChatFormat
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.tokenizer import Tokenizer
from .api.config import InlineImplConfig
from .generation import Llama
from .parallel_utils import ModelParallelProcessGroup
@dataclass
class InferenceArgs:
messages: List[Message]
temperature: float
top_p: float
max_gen_len: int
logprobs: bool
class ModelRunner:
def __init__(self, llama):
self.llama = llama
# the `task` object is the same one that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, task: InferenceArgs):
return self.llama.chat_completion(
task.messages,
task.temperature,
task.top_p,
task.max_gen_len,
task.logprobs,
)
def init_model_cb(config: InlineImplConfig):
llama = Llama.build(config)
return ModelRunner(llama)
class LlamaModelParallelGenerator:
"""
This abstraction exists so
- we can run model parallel code without needing to run the CLIs via torchrun
- this also enables using model parallel code within a notebook context.
A Context Manager is used to ensure that the model parallel process is started and stopped
correctly. This does make the ergonomics a little awkward, because it isn't immediately
clear at the callsite why we need to use a context manager.
"""
def __init__(self, config: InlineImplConfig):
self.config = config
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
# while the tool-use loop is going
checkpoint = self.config.checkpoint_config.checkpoint
self.formatter = ChatFormat(Tokenizer(checkpoint.tokenizer_path))
def start(self):
self.__enter__()
def stop(self):
self.__exit__(None, None, None)
def __enter__(self):
checkpoint = self.config.checkpoint_config.checkpoint
self.group = ModelParallelProcessGroup(
checkpoint.model_parallel_size,
init_model_cb=partial(init_model_cb, self.config),
)
self.group.start()
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
self.group.stop()
def chat_completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
) -> Generator:
req_obj = InferenceArgs(
messages=deepcopy(messages),
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
)
gen = self.group.run_inference(req_obj)
yield from gen
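
For illustration, a rough usage sketch of the generator above (assuming an InlineImplConfig built as in the config datatypes; UserMessage is assumed importable from the same llama_models datatypes module that provides Message):

from llama_models.llama3_1.api.datatypes import UserMessage

generator = LlamaModelParallelGenerator(config)  # config: InlineImplConfig
generator.start()
try:
    # tokens stream back as TokenResult objects from the worker group
    for token_result in generator.chat_completion(
        messages=[UserMessage(content="hello world")],
        temperature=0.6,
        top_p=0.9,
    ):
        print(token_result.text, end="")
finally:
    generator.stop()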

View file

@ -0,0 +1,265 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import multiprocessing
import os
import pickle
import tempfile
import time
import uuid
from typing import Callable, Generator
import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_group,
get_model_parallel_rank,
get_model_parallel_src_rank,
)
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
_END_SENTINEL = "__end_sentinel__"
_CANCEL_SENTINEL = "__cancel_sentinel__"
def mp_rank_0() -> bool:
return get_model_parallel_rank() == 0
def retrieve_requests(reply_socket_url: str):
if mp_rank_0():
context = zmq.Context()
reply_socket = context.socket(zmq.ROUTER)
reply_socket.connect(reply_socket_url)
while True:
client_id, obj = maybe_get_work(reply_socket)
if obj is None:
time.sleep(0.01)
continue
reply_socket.send_multipart([client_id, pickle.dumps("YES READY")])
break
def send_obj(obj):
reply_socket.send_multipart([client_id, pickle.dumps(obj)])
while True:
tasks = [None]
if mp_rank_0():
client_id, task = maybe_get_work(reply_socket)
# there is still an unknown unclean GeneratorExit happening resulting in a
# cancel sentinel getting queued _after_ we have finished sending everything :/
# kind of a hack this is :/
if task != _CANCEL_SENTINEL:
tasks = [task]
torch.distributed.broadcast_object_list(
tasks,
src=get_model_parallel_src_rank(),
group=get_model_parallel_group(),
)
task = tasks[0]
if task is None:
time.sleep(0.1)
else:
try:
out = yield task
if out is None:
break
for obj in out:
updates = [None]
if mp_rank_0():
_, update = maybe_get_work(reply_socket)
if update == _CANCEL_SENTINEL:
updates = [update]
else:
# only send the update if it's not cancelled otherwise the object sits in the socket
# and gets pulled in the next request lol
send_obj(obj)
torch.distributed.broadcast_object_list(
updates,
src=get_model_parallel_src_rank(),
group=get_model_parallel_group(),
)
if updates[0] == _CANCEL_SENTINEL:
print("quitting generation loop because request was cancelled")
break
if mp_rank_0():
send_obj(_END_SENTINEL)
except Exception as e:
print(f"[debug] got exception {e}")
import traceback
traceback.print_exc()
if mp_rank_0():
send_obj(e)
if mp_rank_0():
send_obj("DONE")
def maybe_get_work(sock: zmq.Socket):
message = None
client_id = None
try:
client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
message = pickle.loads(obj)
except zmq.ZMQError as e:
if e.errno != zmq.EAGAIN:
raise e
return client_id, message
def worker_process_entrypoint(
reply_socket_url: str,
init_model_cb: Callable,
) -> None:
model = init_model_cb()
torch.distributed.barrier()
time.sleep(1)
# run the requests co-routine which retrieves requests from the socket
# and sends responses (we provide) back to the caller
req_gen = retrieve_requests(reply_socket_url)
result = None
while True:
try:
task = req_gen.send(result)
if isinstance(task, str) and task == _END_SENTINEL:
break
result = model(task)
except StopIteration:
break
print("[debug] worker process done")
def launch_dist_group(
reply_socket_url: str,
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
) -> None:
id = uuid.uuid4().hex
dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
with tempfile.TemporaryDirectory() as tmpdir:
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
launch_config = LaunchConfig(
max_nodes=1,
min_nodes=1,
nproc_per_node=model_parallel_size,
start_method="fork",
rdzv_backend="c10d",
rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
rdzv_configs={"store_type": "file", "timeout": 90},
max_restarts=0,
monitor_interval=1,
run_id=str(uuid.uuid4()),
)
elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
reply_socket_url,
init_model_cb,
)
def start_model_parallel_process(
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
):
context = zmq.Context()
request_socket = context.socket(zmq.DEALER)
# Binding the request socket to a random port
request_socket.bind("tcp://127.0.0.1:0")
main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)
ctx = multiprocessing.get_context("fork")
process = ctx.Process(
target=launch_dist_group,
args=(
main_process_url,
model_parallel_size,
init_model_cb,
),
kwargs=kwargs,
)
process.start()
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
request_socket.send_pyobj("READY?")
response = request_socket.recv_pyobj()
print(f"Finished model load {response}")
return request_socket, process
class ModelParallelProcessGroup:
def __init__(
self,
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
):
self.model_parallel_size = model_parallel_size
self.init_model_cb = init_model_cb
self.started = False
self.running = False
def start(self):
assert not self.started, "process group already started"
self.request_socket, self.process = start_model_parallel_process(
self.model_parallel_size,
self.init_model_cb,
)
self.started = True
def stop(self):
assert self.started, "process group not started"
if self.process.is_alive():
self.request_socket.send_pyobj(_END_SENTINEL, zmq.NOBLOCK)
self.process.join()
self.started = False
def run_inference(self, request) -> Generator:
assert not self.running, "inference already running"
self.running = True
self.request_socket.send_pyobj(request)
try:
while True:
obj = self.request_socket.recv_pyobj()
if obj == _END_SENTINEL:
break
if isinstance(obj, Exception):
print(f"[debug] got exception {obj}")
raise obj
yield obj
except GeneratorExit as e:
self.request_socket.send_pyobj(_CANCEL_SENTINEL)
while True:
obj = self.request_socket.recv_pyobj()
if obj == _END_SENTINEL:
break
finally:
self.running = False

View file

@ -0,0 +1,184 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import collections
from typing import Optional, Type
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
print("Using efficient FP8 operators in FBGEMM.")
except ImportError:
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
raise
import torch
from torch import nn, Tensor
class Fp8ScaledWeights:
# TODO: Ugly trick so torch allows us to replace parameters
# with our custom Fp8Weights instance. Do this properly.
@property
def __class__(self) -> Type[nn.parameter.Parameter]:
return nn.Parameter
@property
def grad_fn(self) -> None:
return None
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
class Fp8RowwiseWeights(
Fp8ScaledWeights,
collections.namedtuple(
"Fp8RowwiseWeights",
["weight", "scale", "shape", "activation_scale_ub"],
),
):
pass
def ffn_swiglu(
x: Tensor,
w1: Fp8RowwiseWeights,
w3: Fp8RowwiseWeights,
w2: Fp8RowwiseWeights,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
if (
isinstance(w1, Fp8ScaledWeights)
and isinstance(w3, Fp8ScaledWeights)
and isinstance(w2, Fp8ScaledWeights)
):
return ffn_swiglu_fp8_dynamic(
x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
)
(B, T, D) = x.shape # noqa: N806
(HD_L, D_) = w1.shape # noqa: N806
assert D_ == D
assert isinstance(w1, Tensor)
assert isinstance(w3, Tensor)
x1 = x.view(B * T, D) @ w1.T
x2 = x.view(B * T, D) @ w3.T
z = torch.nn.functional.silu(x1) * x2
del x1, x2
assert isinstance(w2, Tensor)
return (z @ w2.T).view(B, T, D)
@torch.inference_mode()
def quantize_fp8(
w: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
) -> Fp8RowwiseWeights:
"""Quantize [n, k] weight tensor.
Args:
w (Tensor): [n, k] input high precision tensor to quantize.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device="cuda",
)
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
del w
return Fp8RowwiseWeights(
weight=wq,
scale=w_scale,
shape=wq.shape,
activation_scale_ub=activation_scale_ub,
)
@torch.inference_mode()
def load_fp8(
w: Tensor,
w_scale: Tensor,
fp8_activation_scale_ub: float,
) -> Fp8RowwiseWeights:
"""Load FP8 [n, k] weight tensor.
Args:
w (Tensor): [n, k] input FP8.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device="cuda",
)
return Fp8RowwiseWeights(
weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
scale=w_scale.to(device="cuda"),
shape=w.shape,
activation_scale_ub=activation_scale_ub,
)
def fc_fp8_dynamic(
x: Tensor,
w: Fp8RowwiseWeights,
activation_scale_ub: Optional[Tensor] = None,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
"""
Single w8a8 fc layer with dynamic row-wise scaling.
"""
if isinstance(w, Fp8RowwiseWeights):
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
x, num_tokens, activation_scale_ub
)
y = torch.ops.fbgemm.f8f8bf16_rowwise(
xq, w.weight, x_scale, w.scale, use_fast_accum=True
)
del xq
return y
def ffn_swiglu_fp8_dynamic(
x: Tensor,
w1: Fp8RowwiseWeights,
w3: Fp8RowwiseWeights,
w2: Fp8RowwiseWeights,
activation_scale_ub: Optional[Tensor] = None,
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
(B, T, D) = x.shape # noqa: N806
HD_L = w1.shape[0] # noqa: N806
assert HD_L == w3.shape[0]
x1 = fc_fp8_dynamic(
x.view(B * T, D),
w1,
activation_scale_ub,
num_tokens,
is_memory_bounded,
)
x2 = fc_fp8_dynamic(
x.view(B * T, D),
w3,
activation_scale_ub,
num_tokens,
is_memory_bounded,
)
z = torch.nn.functional.silu(x1) * x2
del x1, x2
z_ = fc_fp8_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
return z_.view(B, T, D)
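
A minimal usage sketch for the row-wise FP8 path above (a sketch only: it assumes a Hopper-class GPU with the FBGEMM FP8 kernels installed and that this module is importable as `fp8_impls`; the shapes and the 1200.0 activation bound are illustrative):

import torch
from fp8_impls import ffn_swiglu, quantize_fp8

B, T, D, HD_L = 2, 128, 4096, 14336  # illustrative shapes
x = torch.randn(B, T, D, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(HD_L, D, dtype=torch.bfloat16, device="cuda")
w3 = torch.randn(HD_L, D, dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(D, HD_L, dtype=torch.bfloat16, device="cuda")

# row-wise FP8 quantization of the FFN weights
w1_q = quantize_fp8(w1, 1200.0)
w3_q = quantize_fp8(w3, 1200.0)
w2_q = quantize_fp8(w2, 1200.0)

# all three weights are Fp8ScaledWeights, so ffn_swiglu takes the FP8 dynamic path
y = ffn_swiglu(x, w1_q, w3_q, w2_q)
print(y.shape)  # torch.Size([2, 128, 4096])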

View file

@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from typing import Optional
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.llama3_1.api.model import Transformer, TransformerBlock
from llama_toolchain.inference.api.config import (
CheckpointQuantizationFormat,
InlineImplConfig,
)
from llama_toolchain.inference.api.datatypes import QuantizationType
from termcolor import cprint
from torch import Tensor
def is_fbgemm_available() -> bool:
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
return True
except ImportError:
return False
def swiglu_wrapper(
self,
x: Tensor,
):
from .fp8_impls import ffn_swiglu
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
return reduce_from_model_parallel_region(out)
def convert_to_quantized_model(
model: Transformer,
config: InlineImplConfig,
fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:
return model
elif config.quantization.type != QuantizationType.fp8.value:
raise ValueError("Only FP8 quantization is supported")
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
checkpoint = config.checkpoint_config.checkpoint
# Move weights to GPU with quantization
if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
cprint("Loading fp8 scales...", "yellow")
fp8_scales_path = os.path.join(
checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
assert os.path.isfile(
fp8_scales_path
), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = load_fp8(
param.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
else:
cprint("Quantizing fp8 weights from bf16...", "yellow")
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = quantize_fp8(
param.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device="cuda")
return model
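
A small standalone illustration (not part of the toolchain) of the `__get__` trick used by `swiglu_wrapper.__get__(block.feed_forward)` above: it binds a free function as a method on one specific instance without modifying the class.

class FeedForward:
    def forward(self, x):
        return x * 2

def patched_forward(self, x):
    return x * 3

ff = FeedForward()
ff.forward = patched_forward.__get__(ff)  # bound method for this instance only
assert ff.forward(2) == 6
assert FeedForward().forward(2) == 4  # other instances keep the original forward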

View file

@ -0,0 +1,30 @@
#!/bin/bash
if [[ $# -ne 1 ]]; then
echo "Error: Please provide the name of CONDA environment you wish to create"
exit 1
fi
ENV_NAME=$1
set -eu
eval "$(conda shell.bash hook)"
echo "Will build env (or overwrite) named '$ENV_NAME'"
set -x
run_build() {
# Set up the conda environment
yes | conda remove --name $ENV_NAME --all
yes | conda create -n $ENV_NAME python=3.10
conda activate $ENV_NAME
# PT nightly
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
# install FP8 dependencies for llama-toolchain
pip install -r fp8_requirements.txt
}
run_build

View file

@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import os
import shutil
import sys
from pathlib import Path
from typing import Optional
import fire
import torch
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8
from llama.model import ModelArgs, Transformer, TransformerBlock
from llama.tokenizer import Tokenizer
from torch.nn.parameter import Parameter
def main(
ckpt_dir: str,
tokenizer_path: str,
quantized_ckpt_dir: str,
max_seq_len: Optional[int] = 512,
max_batch_size: Optional[int] = 4,
model_parallel_size: Optional[int] = None,
ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE,
fp8_activation_scale_ub: Optional[float] = 1200.0,
seed: int = 1,
):
""" """
if not os.path.exists(quantized_ckpt_dir):
os.makedirs(quantized_ckpt_dir)
shutil.copy(
os.path.join(ckpt_dir, "params.json"),
os.path.join(quantized_ckpt_dir, "params.json"),
)
shutil.copy(
os.path.join(ckpt_dir, "tokenizer.model"),
os.path.join(quantized_ckpt_dir, "tokenizer.model"),
)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if not model_parallel_is_initialized():
if model_parallel_size is None:
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(model_parallel_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
**params,
)
tokenizer = Tokenizer(model_path=tokenizer_path)
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (e.g. fp32) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
print(ckpt_path)
assert (
quantized_ckpt_dir is not None
), "Quantized checkpoint directory should not be None"
fp8_scales = {}
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
fp8_weight = quantize_fp8(
block.feed_forward.w1.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_weight = quantize_fp8(
block.feed_forward.w3.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_weight = quantize_fp8(
block.feed_forward.w2.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_scales_path = os.path.join(
quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
torch.save(fp8_scales, fp8_scales_path)
ckpt_path = os.path.join(
quantized_ckpt_dir,
"consolidated.{:02d}.pth".format(get_model_parallel_rank()),
)
torch.save(model.state_dict(), ckpt_path)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,31 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
set -x
cd $(git rev-parse --show-toplevel)
MASTER_HOST=$1
RUN_ID=$2
CKPT_DIR=$3
QUANT_CKPT_DIR=$4
TOKENIZER_PATH=$5
NNODES=$6
NPROC=$7
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \
torchrun \
--nnodes=$NNODES --nproc_per_node=$NPROC \
--rdzv_id=$RUN_ID \
--rdzv_conf='timeout=120' \
--rdzv_backend=c10d \
--rdzv_endpoint="${MASTER_HOST}:29502" \
quantize_checkpoint.py $CKPT_DIR $TOKENIZER_PATH $QUANT_CKPT_DIR

View file

@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import unittest
import torch
from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
from hypothesis import given, settings, strategies as st
from torch import Tensor
@unittest.skipIf(
not torch.cuda.is_available()
or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
"Skip when H100 is not available",
)
class FP8Tests(unittest.TestCase):
@settings(deadline=None)
@given(
D=st.sampled_from([4096, 8192]),
HD_L=st.sampled_from([1280, 2560]),
B=st.sampled_from([1, 2]),
T=st.sampled_from([2048, 4096]),
UB=st.sampled_from([1000, 10000]),
)
def test_fp8_ffn(
self,
D: int, # noqa
HD_L: int,
B: int,
T: int,
UB: float,
) -> None:
x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
w1 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
w3 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
x_q = quantize_fp8(x, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
w1_q = quantize_fp8(w1, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
w3_q = quantize_fp8(w3, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
w2_q = quantize_fp8(w2, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
(B, T, D) = x.shape # noqa: N806
(HD_L, D_) = w1.shape # noqa: N806
assert D_ == D
x1 = x.view(B * T, D) @ w1.T
x2 = x.view(B * T, D) @ w3.T
z = torch.nn.functional.silu(x1) * x2
return (z @ w2.T).view(B, T, D).to(torch.bfloat16)
v = ffn_swiglu_fp8_dynamic(x, w1_q, w3_q, w2_q)
# Fake quant
x = x_q.weight.bfloat16() * x_q.scale.unsqueeze(-1)
w1 = w1_q.weight.bfloat16() * w1_q.scale.unsqueeze(-1)
w3 = w3_q.weight.bfloat16() * w3_q.scale.unsqueeze(-1)
w2 = w2_q.weight.bfloat16() * w2_q.scale.unsqueeze(-1)
v_ref = ref_ffn(x, w1, w3, w2)
torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import signal
import fire
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from hydra_zen import instantiate
from llama_toolchain.utils import get_default_config_dir, parse_config
from .api.endpoints import ChatCompletionRequest, ChatCompletionResponseStreamChunk
from .api_instance import get_inference_api_instance
load_dotenv()
GLOBAL_CONFIG = None
def get_config():
return GLOBAL_CONFIG
def handle_sigint(*args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully", args)
loop = asyncio.get_event_loop()
for task in asyncio.all_tasks(loop):
task.cancel()
loop.stop()
app = FastAPI()
@app.on_event("startup")
async def startup():
global InferenceApiInstance
config = get_config()
inference_config = instantiate(config["inference_config"])
InferenceApiInstance = await get_inference_api_instance(
inference_config,
)
await InferenceApiInstance.initialize()
@app.on_event("shutdown")
async def shutdown():
global InferenceApiInstance
print("shutting down")
await InferenceApiInstance.shutdown()
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
semaphore = asyncio.Semaphore(1)
@app.post(
"/inference/chat_completion", response_model=ChatCompletionResponseStreamChunk
)
async def chat_completion(request: Request, exec_request: ChatCompletionRequest):
if semaphore.locked():
raise HTTPException(
status_code=429,
detail="Only a single concurrent request allowed right now.",
)
# hold the semaphore for the lifetime of the stream; it is released in the
# generator's finally block below
await semaphore.acquire()
async def sse_generator(event_gen):
try:
async for event in event_gen:
yield f"data: {event.json()}\n\n"
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
finally:
semaphore.release()
async def event_gen():
async for event in InferenceApiInstance.chat_completion(exec_request):
yield event
return StreamingResponse(
sse_generator(event_gen()),
media_type="text/event-stream",
)
def main(config_path: str, port: int = 5000, disable_ipv6: bool = False):
global GLOBAL_CONFIG
config_dir = get_default_config_dir()
GLOBAL_CONFIG = parse_config(config_dir, config_path)
signal.signal(signal.SIGINT, handle_sigint)
import uvicorn
# FYI this does not do hot-reloads
listen_host = "::" if not disable_ipv6 else "0.0.0.0"
print(f"Listening on {listen_host}:{port}")
uvicorn.run(app, host=listen_host, port=port)
if __name__ == "__main__":
fire.Fire(main)
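
A hedged client-side sketch for consuming the SSE stream exposed above; only the route, default port, and the `data: <json>` framing come from the server code, while the payload fields are placeholders whose real shape is defined by ChatCompletionRequest:

import json
import requests  # any streaming-capable HTTP client works

payload = {"messages": [{"role": "user", "content": "Hello"}]}  # hypothetical fields
with requests.post(
    "http://localhost:5000/inference/chat_completion", json=payload, stream=True
) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            chunk = json.loads(line[len(b"data: "):])
            print(chunk)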

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
@json_schema_type
class MemoryBank(BaseModel):
memory_bank_id: str
memory_bank_name: str
@json_schema_type
class MemoryBankDocument(BaseModel):
document_id: str
content: bytes
metadata: Dict[str, Any]
mime_type: str

View file

@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Protocol
from pyopenapi import webmethod
from .datatypes import * # noqa: F403
class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/create")
def post_create_memory_bank(
self,
bank_id: str,
bank_name: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_banks/list")
def get_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/get")
def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/drop")
def delete_memory_bank(
self,
bank_id: str,
) -> str: ...
@webmethod(route="/memory_bank/insert")
def post_insert_memory_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_bank/update")
def post_update_memory_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_bank/get")
def get_memory_documents(
self,
bank_id: str,
document_uuids: List[str],
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory_bank/delete")
def delete_memory_documents(
self,
bank_id: str,
document_uuids: List[str],
) -> List[str]: ...
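
A toy, non-persistent sketch of a class satisfying part of the MemoryBanks protocol (hypothetical, not a provider shipped in this commit; MemoryBank and MemoryBankDocument are the datatypes defined above):

from typing import Dict, List

class InMemoryMemoryBanks:
    def __init__(self) -> None:
        self._banks: Dict[str, MemoryBank] = {}
        self._docs: Dict[str, Dict[str, MemoryBankDocument]] = {}

    def post_create_memory_bank(
        self, bank_id: str, bank_name: str, documents: List[MemoryBankDocument]
    ) -> None:
        self._banks[bank_id] = MemoryBank(
            memory_bank_id=bank_id, memory_bank_name=bank_name
        )
        self._docs[bank_id] = {d.document_id: d for d in documents}

    def get_memory_banks(self) -> List[MemoryBank]:
        return list(self._banks.values())

    def get_memory_documents(
        self, bank_id: str, document_uuids: List[str]
    ) -> List[MemoryBankDocument]:
        docs = self._docs.get(bank_id, {})
        return [docs[u] for u in document_uuids if u in docs]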

View file

@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol
from pydantic import BaseModel # noqa: F401
from pyopenapi import webmethod # noqa: F401
class Models(Protocol): ...

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@ -0,0 +1,94 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
class OptimizerType(Enum):
adam = "adam"
adamw = "adamw"
sgd = "sgd"
@json_schema_type
class OptimizerConfig(BaseModel):
optimizer_type: OptimizerType
lr: float
lr_min: float
weight_decay: float
@json_schema_type
class TrainingConfig(BaseModel):
n_epochs: int
batch_size: int
shuffle: bool
n_iters: int
enable_activation_checkpointing: bool
memory_efficient_fsdp_wrap: bool
fsdp_cpu_offload: bool
@json_schema_type
class FinetuningAlgorithm(Enum):
full = "full"
lora = "lora"
qlora = "qlora"
dora = "dora"
@json_schema_type
class LoraFinetuningConfig(BaseModel):
lora_attn_modules: List[str]
apply_lora_to_mlp: bool
apply_lora_to_output: bool
rank: int
alpha: int
@json_schema_type
class QLoraFinetuningConfig(LoraFinetuningConfig):
pass
@json_schema_type
class DoraFinetuningConfig(LoraFinetuningConfig):
pass
@json_schema_type
class PostTrainingJobLogStream(BaseModel):
"""Stream of logs from a finetuning job."""
job_uuid: str
log_lines: List[str]
@json_schema_type
class PostTrainingJobStatus(Enum):
running = "running"
completed = "completed"
failed = "failed"
scheduled = "scheduled"
@json_schema_type
class RLHFAlgorithm(Enum):
dpo = "dpo"
@json_schema_type
class DPOAlignmentConfig(BaseModel):
reward_scale: float
reward_clip: float
epsilon: float
gamma: float
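
An illustrative construction of the finetuning-related configs above (all values are arbitrary examples, not recommended hyperparameters):

lora = LoraFinetuningConfig(
    lora_attn_modules=["q_proj", "v_proj"],
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    rank=8,
    alpha=16,
)
optimizer = OptimizerConfig(
    optimizer_type=OptimizerType.adamw,
    lr=3e-4,
    lr_min=3e-5,
    weight_decay=0.1,
)
training = TrainingConfig(
    n_epochs=1,
    batch_size=4,
    shuffle=True,
    n_iters=1000,
    enable_activation_checkpointing=True,
    memory_efficient_fsdp_wrap=True,
    fsdp_cpu_offload=False,
)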

View file

@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import Any, Dict, List, Optional, Protocol, Union
from pydantic import BaseModel, Field
from pyopenapi import webmethod
from strong_typing.schema import json_schema_type
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from llama_toolchain.dataset.api.datatypes import * # noqa: F403
from llama_toolchain.common.training_types import * # noqa: F403
from .datatypes import * # noqa: F403
@json_schema_type
class PostTrainingSFTRequest(BaseModel):
"""Request to finetune a model."""
job_uuid: str
model: PretrainedModel
dataset: TrainEvalDataset
validation_dataset: TrainEvalDataset
algorithm: FinetuningAlgorithm
algorithm_config: Union[
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
]
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
"""Request to finetune a model."""
job_uuid: str
finetuned_model: URL
dataset: TrainEvalDataset
validation_dataset: TrainEvalDataset
algorithm: RLHFAlgorithm
algorithm_config: DPOAlignmentConfig
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
class PostTrainingJob(BaseModel):
job_uuid: str
@json_schema_type
class PostTrainingJobStatusResponse(BaseModel):
"""Status of a finetuning job."""
job_uuid: str
status: PostTrainingJobStatus
scheduled_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
resources_allocated: Optional[Dict[str, Any]] = None
checkpoints: List[Checkpoint] = Field(default_factory=list)
@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
"""Artifacts of a finetuning job."""
job_uuid: str
checkpoints: List[Checkpoint] = Field(default_factory=list)
# TODO(ashwin): metrics, evals
class PostTraining(Protocol):
@webmethod(route="/post_training/supervised_fine_tune")
def post_supervised_fine_tune(
self,
request: PostTrainingSFTRequest,
) -> PostTrainingJob: ...
@webmethod(route="/post_training/preference_optimize")
def post_preference_optimize(
self,
request: PostTrainingRLHFRequest,
) -> PostTrainingJob: ...
@webmethod(route="/post_training/jobs")
def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs
@webmethod(route="/post_training/job/logs")
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
@webmethod(route="/post_training/job/status")
def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post_training/job/cancel")
def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post_training/job/artifacts")
def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ...

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
from llama_models.llama3_1.api.datatypes import * # noqa: F403
@json_schema_type
class ScoredMessage(BaseModel):
message: Message
score: float
@json_schema_type
class DialogGenerations(BaseModel):
dialog: List[Message]
sampled_generations: List[Message]
@json_schema_type
class ScoredDialogGenerations(BaseModel):
dialog: List[Message]
scored_generations: List[ScoredMessage]

View file

@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Protocol
from .datatypes import * # noqa: F403
from pyopenapi import webmethod
@json_schema_type
class RewardScoringRequest(BaseModel):
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
dialog_generations: List[DialogGenerations]
model: RewardModel
@json_schema_type
class RewardScoringResponse(BaseModel):
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
scored_generations: List[ScoredDialogGenerations]
class RewardScoring(Protocol):
@webmethod(route="/reward_scoring/score")
def post_score(
self,
request: RewardScoringRequest,
) -> RewardScoringResponse: ...

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
from pydantic import BaseModel
class LlamaGuardShieldConfig(BaseModel):
model_dir: str
excluded_categories: List[str]
disable_input_check: bool = False
disable_output_check: bool = False
class PromptGuardShieldConfig(BaseModel):
model_dir: str
class SafetyConfig(BaseModel):
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
prompt_guard_shield: Optional[PromptGuardShieldConfig] = None
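
An illustrative SafetyConfig instance built from the models above (the checkpoint paths are placeholders):

safety_config = SafetyConfig(
    llama_guard_shield=LlamaGuardShieldConfig(
        model_dir="/checkpoints/llama-guard-3-8b",  # placeholder path
        excluded_categories=[],
    ),
    prompt_guard_shield=PromptGuardShieldConfig(
        model_dir="/checkpoints/prompt-guard-86m",  # placeholder path
    ),
)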

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Dict, Optional, Union
from llama_models.llama3_1.api.datatypes import ToolParamDefinition
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
@json_schema_type
class BuiltinShield(Enum):
llama_guard = "llama_guard"
code_scanner_guard = "code_scanner_guard"
third_party_shield = "third_party_shield"
injection_shield = "injection_shield"
jailbreak_shield = "jailbreak_shield"
ShieldType = Union[BuiltinShield, str]
@json_schema_type
class OnViolationAction(Enum):
IGNORE = 0
WARN = 1
RAISE = 2
@json_schema_type
class ShieldDefinition(BaseModel):
shield_type: ShieldType
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
on_violation_action: OnViolationAction = OnViolationAction.RAISE
execution_config: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class ShieldCall(BaseModel):
call_id: str
shield_type: ShieldType
arguments: Dict[str, str]
@json_schema_type
class ShieldResponse(BaseModel):
shield_type: ShieldType
# TODO(ashwin): clean this up
is_violation: bool
violation_type: Optional[str] = None
violation_return_message: Optional[str] = None

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# suppress warnings and log spew from Hugging Face
import transformers
from .base import ( # noqa: F401
DummyShield,
OnViolationAction,
ShieldBase,
ShieldResponse,
TextShield,
)
from .code_scanner import CodeScannerShield # noqa: F401
from .contrib.third_party_shield import ThirdPartyShield # noqa: F401
from .llama_guard import LlamaGuardShield # noqa: F401
from .prompt_guard import ( # noqa: F401
InjectionShield,
JailbreakShield,
PromptGuardShield,
)
from .shield_runner import SafetyException, ShieldRunnerMixin # noqa: F401
transformers.logging.set_verbosity_error()
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings
warnings.filterwarnings("ignore")

View file

@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import List, Union
from llama_models.llama3_1.api.datatypes import Attachment, Message
from llama_toolchain.safety.api.datatypes import * # noqa: F403
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
class ShieldBase(ABC):
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
self.on_violation_action = on_violation_action
@abstractmethod
def get_shield_type(self) -> ShieldType:
raise NotImplementedError()
@abstractmethod
async def run(self, messages: List[Message]) -> ShieldResponse:
raise NotImplementedError()
def message_content_as_str(message: Message) -> str:
def _to_str(content: Union[str, Attachment]) -> str:
if isinstance(content, str):
return content
elif isinstance(content, Attachment):
return f"File: {str(content.url)}"
else:
raise ValueError(f"Unsupported content type: {type(content)}")
if isinstance(message.content, list) or isinstance(message.content, tuple):
return "\n".join([_to_str(c) for c in message.content])
else:
return _to_str(message.content)
# For shields that operate on simple strings
class TextShield(ShieldBase):
def convert_messages_to_text(self, messages: List[Message]) -> str:
return "\n".join([message_content_as_str(m) for m in messages])
async def run(self, messages: List[Message]) -> ShieldResponse:
text = self.convert_messages_to_text(messages)
return await self.run_impl(text)
@abstractmethod
async def run_impl(self, text: str) -> ShieldResponse:
raise NotImplementedError()
class DummyShield(TextShield):
def get_shield_type(self) -> ShieldType:
return "dummy"
async def run_impl(self, text: str) -> ShieldResponse:
# the dummy shield never flags a violation; useful for end-to-end testing
return ShieldResponse(
shield_type=self.get_shield_type(), is_violation=False
)
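
A hedged example of building a custom shield on top of TextShield (the banned keyword list is illustrative; only get_shield_type and run_impl need to be provided):

class BannedWordsShield(TextShield):
    def __init__(
        self,
        banned=("secret-project",),  # illustrative keywords
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
        super().__init__(on_violation_action)
        self.banned = banned

    def get_shield_type(self) -> ShieldType:
        return "banned_words"

    async def run_impl(self, text: str) -> ShieldResponse:
        hits = [w for w in self.banned if w in text.lower()]
        if hits:
            return ShieldResponse(
                shield_type=self.get_shield_type(),
                is_violation=True,
                violation_type=",".join(hits),
                violation_return_message=CANNED_RESPONSE_TEXT,
            )
        return ShieldResponse(shield_type=self.get_shield_type(), is_violation=False)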

View file

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from codeshield.cs import CodeShield
from termcolor import cprint
from .base import ShieldResponse, TextShield
from llama_toolchain.safety.api.datatypes import * # noqa: F403
class CodeScannerShield(TextShield):
def get_shield_type(self) -> ShieldType:
return BuiltinShield.code_scanner_guard
async def run_impl(self, text: str) -> ShieldResponse:
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
result = await CodeShield.scan_code(text)
if result.is_insecure:
return ShieldResponse(
shield_type=BuiltinShield.code_scanner_guard,
is_violation=True,
violation_type=",".join(
[issue.pattern_id for issue in result.issues_found]
),
violation_return_message="Sorry, I found security concerns in the code.",
)
else:
return ShieldResponse(
shield_type=BuiltinShield.code_scanner_guard, is_violation=False
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sys
from typing import List
from llama_models.llama3_1.api.datatypes import Message
parent_dir = "../.."
sys.path.append(parent_dir)
from llama_toolchain.safety.shields.base import (
OnViolationAction,
ShieldBase,
ShieldResponse,
)
_INSTANCE = None
class ThirdPartyShield(ShieldBase):
@staticmethod
def instance(on_violation_action=OnViolationAction.RAISE) -> "ThirdPartyShield":
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = ThirdPartyShield(on_violation_action)
return _INSTANCE
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
async def run(self, messages: List[Message]) -> ShieldResponse:
return await super().run(messages)  # ShieldBase.run is abstract and raises NotImplementedError

View file

@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from string import Template
from typing import List, Optional
import torch
from llama_models.llama3_1.api.datatypes import Message, Role
from transformers import AutoModelForCausalLM, AutoTokenizer
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_toolchain.safety.api.datatypes import * # noqa: F403
SAFE_RESPONSE = "safe"
_INSTANCE = None
CAT_VIOLENT_CRIMES = "Violent Crimes"
CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes"
CAT_SEX_CRIMES = "Sex Crimes"
CAT_CHILD_EXPLOITATION = "Child Exploitation"
CAT_DEFAMATION = "Defamation"
CAT_SPECIALIZED_ADVICE = "Specialized Advice"
CAT_PRIVACY = "Privacy"
CAT_INTELLECTUAL_PROPERTY = "Intellectual Property"
CAT_INDISCRIMINATE_WEAPONS = "Indiscriminate Weapons"
CAT_HATE = "Hate"
CAT_SELF_HARM = "Self-Harm"
CAT_SEXUAL_CONTENT = "Sexual Content"
CAT_ELECTIONS = "Elections"
CAT_CODE_INTERPRETER_ABUSE = "Code Interpreter Abuse"
SAFETY_CATEGORIES_TO_CODE_MAP = {
CAT_VIOLENT_CRIMES: "S1",
CAT_NON_VIOLENT_CRIMES: "S2",
CAT_SEX_CRIMES: "S3",
CAT_CHILD_EXPLOITATION: "S4",
CAT_DEFAMATION: "S5",
CAT_SPECIALIZED_ADVICE: "S6",
CAT_PRIVACY: "S7",
CAT_INTELLECTUAL_PROPERTY: "S8",
CAT_INDISCRIMINATE_WEAPONS: "S9",
CAT_HATE: "S10",
CAT_SELF_HARM: "S11",
CAT_SEXUAL_CONTENT: "S12",
CAT_ELECTIONS: "S13",
CAT_CODE_INTERPRETER_ABUSE: "S14",
}
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_VIOLENT_CRIMES,
CAT_NON_VIOLENT_CRIMES,
CAT_SEX_CRIMES,
CAT_CHILD_EXPLOITATION,
CAT_DEFAMATION,
CAT_SPECIALIZED_ADVICE,
CAT_PRIVACY,
CAT_INTELLECTUAL_PROPERTY,
CAT_INDISCRIMINATE_WEAPONS,
CAT_HATE,
CAT_SELF_HARM,
CAT_SEXUAL_CONTENT,
CAT_ELECTIONS,
CAT_CODE_INTERPRETER_ABUSE,
]
PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
SAFETY_CATEGORIES = """
<BEGIN UNSAFE CONTENT CATEGORIES>
$categories
<END UNSAFE CONTENT CATEGORIES>
"""
PROMPT_CONVERSATION = """
<BEGIN CONVERSATION>
$conversations
<END CONVERSATION>
"""
PROMPT_INSTRUCTIONS = """
Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories."""
PROMPT_TEMPLATE = Template(
f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}"
)
class LlamaGuardShield(ShieldBase):
@staticmethod
def instance(
on_violation_action=OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
) -> "LlamaGuardShield":
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = LlamaGuardShield(
on_violation_action,
model_dir,
excluded_categories,
disable_input_check,
disable_output_check,
)
return _INSTANCE
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
):
super().__init__(on_violation_action)
dtype = torch.bfloat16
assert model_dir is not None, "Llama Guard model_dir is None"
if excluded_categories is None:
excluded_categories = []
assert len(excluded_categories) == 0 or all(
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
self.device = "cuda"
self.excluded_categories = excluded_categories
self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check
# load model
torch_dtype = torch.bfloat16
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForCausalLM.from_pretrained(
model_dir, torch_dtype=torch_dtype, device_map=self.device
)
def get_shield_type(self) -> ShieldType:
return BuiltinShield.llama_guard
def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response)
if match:
# extracts the unsafe code
extracted = match.group(1)
return extracted
return None
def get_safety_categories(self) -> List[str]:
excluded_categories = self.excluded_categories
if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
excluded_categories = []
categories = []
for cat in DEFAULT_LG_V3_SAFETY_CATEGORIES:
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
if cat_code in excluded_categories:
continue
categories.append(f"{cat_code}: {cat}.")
return categories
def build_prompt(self, messages: List[Message]) -> str:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
[f"{m.role.capitalize()}: {m.content}" for m in messages]
)
return PROMPT_TEMPLATE.substitute(
agent_type=messages[-1].role.capitalize(),
categories=categories_str,
conversations=conversations_str,
)
def get_shield_response(self, response: str) -> ShieldResponse:
if response == SAFE_RESPONSE:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
unsafe_code = self.check_unsafe_response(response)
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
return ShieldResponse(
shield_type=BuiltinShield.llama_guard,
is_violation=True,
violation_type=unsafe_code,
violation_return_message=CANNED_RESPONSE_TEXT,
)
raise ValueError(f"Unexpected response: {response}")
async def run(self, messages: List[Message]) -> ShieldResponse:
if self.disable_input_check and messages[-1].role == Role.user.value:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard,
is_violation=False,
)
else:
prompt = self.build_prompt(messages)
llama_guard_input = {
"role": "user",
"content": prompt,
}
input_ids = self.tokenizer.apply_chat_template(
[llama_guard_input], return_tensors="pt", tokenize=True
).to(self.device)
prompt_len = input_ids.shape[1]
output = self.model.generate(
input_ids=input_ids,
max_new_tokens=20,
output_scores=True,
return_dict_in_generate=True,
pad_token_id=0,
)
generated_tokens = output.sequences[:, prompt_len:]
response = self.tokenizer.decode(
generated_tokens[0], skip_special_tokens=True
)
response = response.strip()
shield_response = self.get_shield_response(response)
return shield_response
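
A hedged usage sketch for the shield above (model_dir is a placeholder path to a local Llama Guard 3 checkpoint in Hugging Face format; messages is a List[Message] built with llama_models):

import asyncio

shield = LlamaGuardShield.instance(model_dir="/checkpoints/llama-guard-3-8b")

async def moderate(messages):
    response = await shield.run(messages)
    # response.is_violation is a bool; on violation, violation_type carries the
    # category codes (e.g. "S1,S10") and violation_return_message the canned refusal
    return response

# result = asyncio.run(moderate(messages))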

View file

@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import auto, Enum
from typing import List
import torch
from llama_models.llama3_1.api.datatypes import Message
from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
from llama_toolchain.safety.api.datatypes import * # noqa: F403
class PromptGuardShield(TextShield):
class Mode(Enum):
INJECTION = auto()
JAILBREAK = auto()
_instances = {}
_model_cache = None
@staticmethod
def instance(
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
on_violation_action=OnViolationAction.RAISE,
) -> "PromptGuardShield":
action_value = on_violation_action.value
key = (model_dir, threshold, temperature, mode, action_value)
if key not in PromptGuardShield._instances:
PromptGuardShield._instances[key] = PromptGuardShield(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=mode,
on_violation_action=on_violation_action,
)
return PromptGuardShield._instances[key]
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
assert (
model_dir is not None
), "Must provide a model directory for prompt injection shield"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.device = "cuda"
if PromptGuardShield._model_cache is None:
# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(
model_dir, device_map=self.device
)
PromptGuardShield._model_cache = (tokenizer, model)
self.tokenizer, self.model = PromptGuardShield._model_cache
self.temperature = temperature
self.threshold = threshold
self.mode = mode
def get_shield_type(self) -> ShieldType:
return (
BuiltinShield.jailbreak_shield
if self.mode == self.Mode.JAILBREAK
else BuiltinShield.injection_shield
)
def convert_messages_to_text(self, messages: List[Message]) -> str:
return message_content_as_str(messages[-1])
async def run_impl(self, text: str) -> ShieldResponse:
# run model on messages and return response
inputs = self.tokenizer(text, return_tensors="pt")
inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs[0]
probabilities = torch.softmax(logits / self.temperature, dim=-1)
score_embedded = probabilities[0, 1].item()
score_malicious = probabilities[0, 2].item()
cprint(
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
color="magenta",
)
if self.mode == self.Mode.INJECTION and (
score_embedded + score_malicious > self.threshold
):
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=True,
violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=True,
violation_type=f"prompt_injection:malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=False,
)
class JailbreakShield(PromptGuardShield):
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=PromptGuardShield.Mode.JAILBREAK,
on_violation_action=on_violation_action,
)
class InjectionShield(PromptGuardShield):
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=PromptGuardShield.Mode.INJECTION,
on_violation_action=on_violation_action,
)

View file

@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import List
from llama_models.llama3_1.api.datatypes import Message, Role
from termcolor import cprint
from .base import OnViolationAction, ShieldBase, ShieldResponse
class SafetyException(Exception): # noqa: N818
def __init__(self, response: ShieldResponse):
self.response = response
super().__init__(response.violation_return_message)
class ShieldRunnerMixin:
def __init__(
self,
input_shields: List[ShieldBase] = None,
output_shields: List[ShieldBase] = None,
):
self.input_shields = input_shields
self.output_shields = output_shields
async def run_shields(
self, messages: List[Message], shields: List[ShieldBase]
) -> List[ShieldResponse]:
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
# TODO(ashwin): we need to change the type of the message, this kind of modification
# is no longer appropriate
messages[0].role = Role.user.value
results = await asyncio.gather(*[s.run(messages) for s in shields])
for shield, r in zip(shields, results):
if r.is_violation:
if shield.on_violation_action == OnViolationAction.RAISE:
raise SafetyException(r)
elif shield.on_violation_action == OnViolationAction.WARN:
cprint(
f"[Warn]{shield.__class__.__name__} raised a warning",
color="red",
)
return results

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
class FilteringFunction(Enum):
"""The type of filtering function."""
none = "none"
random = "random"
top_k = "top_k"
top_p = "top_p"
top_k_top_p = "top_k_top_p"
sigmoid = "sigmoid"

View file

@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol
from pydantic import BaseModel
from pyopenapi import webmethod
from strong_typing.schema import json_schema_type
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from llama_toolchain.reward_scoring.api.datatypes import * # noqa: F403
from .datatypes import * # noqa: F403
@json_schema_type
class SyntheticDataGenerationRequest(BaseModel):
"""Request to generate synthetic data. A small batch of prompts and a filtering function"""
dialogs: List[Message]
filtering_function: FilteringFunction = FilteringFunction.none
model: Optional[RewardModel] = None
@json_schema_type
class SyntheticDataGenerationResponse(BaseModel):
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
synthetic_data: List[ScoredDialogGenerations]
statistics: Optional[Dict[str, Any]] = None
class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic_data_generation/generate")
def post_generate(
self,
request: SyntheticDataGenerationRequest,
) -> SyntheticDataGenerationResponse: ...

llama_toolchain/utils.py
View file

@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import getpass
import os
from typing import Optional
from hydra import compose, initialize, MissingConfigException
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, OmegaConf
DEFAULT_DUMP_DIR = os.path.expanduser("~/.llama/")
def get_root_directory():
current_dir = os.path.dirname(os.path.abspath(__file__))
while os.path.isfile(os.path.join(current_dir, "__init__.py")):
current_dir = os.path.dirname(current_dir)
return current_dir
def get_default_config_dir():
return os.path.join(DEFAULT_DUMP_DIR, "configs")
def parse_config(config_dir: str, config_path: Optional[str] = None) -> DictConfig:
# Configs can be
# 1. relative paths in {config_dir}/
# 2. or default to file {config_dir}/{user}.yaml
# 3. or ultimate default to {config_dir}/default.yaml
# Get the relative path from the current file to the config directory
current_file_directory = os.path.dirname(os.path.abspath(__file__))
relative_path = os.path.relpath(config_dir, current_file_directory)
GlobalHydra.instance().clear()
initialize(config_path=relative_path)
if config_path is None:
config_name = getpass.getuser()
else:
config_name = config_path
config_abs_path = os.path.abspath(os.path.join(config_dir, f"{config_name}.yaml"))
print(f"Loading config from : {config_abs_path}")
try:
config = compose(config_name=config_name)
except MissingConfigException:
# fall back to default.yaml when no config named after the user exists
print(f"No config named {config_name}.yaml, using default")
config_name = "default"
config = compose(config_name=config_name)
print("Yaml config:")
print("------------------------")
print(OmegaConf.to_yaml(config, resolve=True))
print("------------------------")
return config
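
A hedged usage sketch of the helpers above ("inline" is an illustrative config name; it resolves to ~/.llama/configs/inline.yaml, and with config_path=None the current username is tried first):

config_dir = get_default_config_dir()
config = parse_config(config_dir, "inline")
print(config["inference_config"])  # key depends on the YAML contents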