updating license for toolchain

Ashwin Bharambe 2024-07-22 20:31:42 -07:00
parent 0e2fc9966a
commit 86fff23a9e
74 changed files with 512 additions and 94 deletions

View file

@@ -19,14 +19,14 @@ repos:
 # - id: no-commit-to-branch
 #   args: ['--branch=main']
-# - repo: https://github.com/Lucas-C/pre-commit-hooks
-#   rev: v1.5.4
-#   hooks:
-#     - id: insert-license
-#       files: \.py$|\.sh$
-#       args:
-#         - --license-filepath
-#         - docs/license_header.txt
+- repo: https://github.com/Lucas-C/pre-commit-hooks
+  rev: v1.5.4
+  hooks:
+    - id: insert-license
+      files: \.py$|\.sh$
+      args:
+        - --license-filepath
+        - docs/license_header.txt
 - repo: https://github.com/pycqa/flake8
   rev: 34cbf8ef3950f43d09b85e2e45c15ae5717dc37b
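
For context, the re-enabled insert-license hook (from Lucas-C/pre-commit-hooks) prepends the text of docs/license_header.txt, commented with "#", to any .py or .sh file that does not already carry it; it can be applied across the tree with "pre-commit run insert-license --all-files". A rough Python sketch of that effect follows (an illustration only, not the hook's actual implementation; the helper name is made up):

    # prepend_license.py: rough sketch of what the insert-license hook does to one file
    from pathlib import Path

    def prepend_license(path: Path, header_file: Path = Path("docs/license_header.txt")) -> None:
        header_lines = header_file.read_text().splitlines()
        # Comment each header line the way the hook formats it for .py/.sh files.
        commented = "\n".join(f"# {line}" if line else "#" for line in header_lines) + "\n\n"
        original = path.read_text()
        if not original.startswith(commented):
            path.write_text(commented + original)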

docs/license_header.txt Normal file (+5 lines)
View file

@@ -0,0 +1,5 @@
+Copyright (c) Meta Platforms, Inc. and affiliates.
+All rights reserved.
+
+This source code is licensed under the terms described found in the
+LICENSE file in the root directory of this source tree.

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 import os
 import textwrap

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.

View file

@@ -1,10 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 import os
-import pkg_resources
 import textwrap
 from pathlib import Path
+import pkg_resources
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.utils import DEFAULT_DUMP_DIR
@@ -36,21 +43,22 @@ class InferenceConfigure(Subcommand):
         pass

     def read_user_inputs(self):
-        checkpoint_dir = input("Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): ")
-        model_parallel_size = input("Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): ")
-        assert model_parallel_size.isdigit() and int(model_parallel_size) in {1, 8}, "model parallel size must be 1 or 8"
+        checkpoint_dir = input(
+            "Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
+        )
+        model_parallel_size = input(
+            "Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
+        )
+        assert model_parallel_size.isdigit() and int(model_parallel_size) in {
+            1,
+            8,
+        }, "model parallel size must be 1 or 8"

         return checkpoint_dir, model_parallel_size

-    def write_output_yaml(
-        self,
-        checkpoint_dir,
-        model_parallel_size,
-        yaml_output_path
-    ):
+    def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
         default_conf_path = pkg_resources.resource_filename(
-            'llama_toolchain',
-            'data/default_inference_config.yaml'
+            "llama_toolchain", "data/default_inference_config.yaml"
         )
         with open(default_conf_path, "r") as f:
             yaml_content = f.read()
@@ -60,7 +68,7 @@ class InferenceConfigure(Subcommand):
             model_parallel_size=model_parallel_size,
         )
-        with open(yaml_output_path, 'w') as yaml_file:
+        with open(yaml_output_path, "w") as yaml_file:
             yaml_file.write(yaml_content.strip())

         print(f"YAML configuration has been written to {yaml_output_path}")
@@ -69,8 +77,9 @@ class InferenceConfigure(Subcommand):
         checkpoint_dir, model_parallel_size = self.read_user_inputs()
         checkpoint_dir = os.path.expanduser(checkpoint_dir)
-        assert Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir(), \
-            f"{checkpoint_dir} does not exist or it not a directory"
+        assert (
+            Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
+        ), f"{checkpoint_dir} does not exist or it not a directory"
         os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
         yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml"
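
Side note on the pkg_resources.resource_filename call kept in this file: on Python 3.9+ the same packaged data file can be read with the standard library instead. A sketch only, assuming the file ships inside the llama_toolchain package exactly as the diff shows:

    from importlib import resources

    # Locate and read the packaged default config without pkg_resources.
    ref = resources.files("llama_toolchain") / "data" / "default_inference_config.yaml"
    yaml_content = ref.read_text()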

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 import textwrap

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 import textwrap
@@ -40,10 +46,7 @@ class InferenceStart(Subcommand):
             default=False,
         )
         self.parser.add_argument(
-            "--config",
-            type=str,
-            help="Path to config file",
-            default="inference"
+            "--config", type=str, help="Path to config file", default="inference"
         )

     def _run_inference_start_cmd(self, args: argparse.Namespace) -> None:

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 from llama_toolchain.cli.download import Download
@@ -27,9 +33,8 @@ class LlamaCLIParser:
         # Import sub-commands from agentic_system if they exist
         try:
-            from llama_agentic_system.cli.subcommand_modules import (
-                SUBCOMMAND_MODULES,
-            )
+            from llama_agentic_system.cli.subcommand_modules import SUBCOMMAND_MODULES

             for module in SUBCOMMAND_MODULES:
                 module.create(subparsers)

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 import textwrap

View file

@@ -1,8 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 import textwrap
+from llama_models.llama3_1.api.interface import (
+    list_jinja_templates,
+    render_jinja_template,
+)
 from llama_toolchain.cli.subcommand import Subcommand
-from llama_models.llama3_1.api.interface import render_jinja_template, list_jinja_templates

 class ModelTemplate(Subcommand):

View file

@@ -1,3 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
+
 class Subcommand:
     """All llama cli subcommands must inherit from this class"""

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from enum import Enum
 from typing import Dict, Optional

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from llama_models.llama3_1.api.datatypes import URL
 from pydantic import BaseModel

View file

@@ -1,2 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .datatypes import * # noqa: F401 F403
 from .endpoints import * # noqa: F401 F403

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from enum import Enum
 from typing import Any, Dict, Optional

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import Protocol
 from pydantic import BaseModel

View file

@@ -1,2 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .datatypes import * # noqa: F401 F403
 from .endpoints import * # noqa: F401 F403

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from enum import Enum
 from pydantic import BaseModel

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import List, Protocol
 from pydantic import BaseModel

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.

View file

@@ -1,2 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .datatypes import * # noqa: F401 F403
 from .endpoints import * # noqa: F401 F403

View file

@@ -1,14 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from dataclasses import dataclass
 from enum import Enum
 from typing import Literal, Optional, Union
 from hydra.core.config_store import ConfigStore
+from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 from .datatypes import QuantizationConfig
-from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat

 class ImplType(Enum):

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from enum import Enum
 from typing import List, Literal, Optional, Union
@@ -26,9 +32,7 @@ class Fp8QuantizationConfig(BaseModel):

 @json_schema_type
 class Bf16QuantizationConfig(BaseModel):
-    type: Literal[QuantizationType.bf16.value] = (
-        QuantizationType.bf16.value
-    )
+    type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value

 QuantizationConfig = Annotated[

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .datatypes import * # noqa: F403
 from typing import Optional, Protocol

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .api.config import ImplType, InferenceConfig

View file

@@ -1,7 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import asyncio
 import json
 from typing import AsyncGenerator
+from urllib.request import getproxies
 import fire
 import httpx
@@ -9,12 +17,16 @@ from .api import (
     ChatCompletionRequest,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
-    InstructModel,
     Inference,
+    InstructModel,
     UserMessage,
 )
 from .event_logger import EventLogger
+print(getproxies())
+# import sys
+# sys.exit(0)

 class InferenceClient(Inference):
     def __init__(self, base_url: str):

View file

@@ -1,8 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from termcolor import cprint
-from llama_toolchain.inference.api import (
-    ChatCompletionResponseEventType,
-)
+from llama_toolchain.inference.api import ChatCompletionResponseEventType

 class LogEvent:
@@ -30,4 +34,3 @@ class EventLogger:
                 yield LogEvent(event.delta, color="yellow", end="")
             elif event.event_type == ChatCompletionResponseEventType.complete:
                 yield LogEvent("")
-

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

View file

@@ -1,16 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import AsyncGenerator
 from llama_models.llama3_1.api.datatypes import StopReason
-from .api.config import (
-    CheckpointQuantizationFormat,
-    CheckpointType,
-    InlineImplConfig,
-)
+from .api.config import InlineImplConfig
 from .api.datatypes import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
-    QuantizationConfig,
     ToolCallDelta,
     ToolCallParseStatus,
 )

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import multiprocessing
 import os
 import pickle

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
@@ -8,7 +14,7 @@ try:
     import fbgemm_gpu.experimental.gen_ai # noqa: F401

     print("Using efficient FP8 operators in FBGEMM.")
-except (ImportError, ModuleNotFoundError):
+except ImportError:
     print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
     raise
@@ -57,8 +63,8 @@ def ffn_swiglu(
         x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
     )
-    (B, T, D) = x.shape
-    (HD_L, D_) = w1.shape
+    (B, T, D) = x.shape # noqa: N806
+    (HD_L, D_) = w1.shape # noqa: N806
     assert D_ == D
     assert isinstance(w1, Tensor)
@@ -153,8 +159,8 @@ def ffn_swiglu_fp8_dynamic(
     num_tokens: Optional[Tensor] = None,
     is_memory_bounded: bool = False,
 ) -> Tensor:
-    (B, T, D) = x.shape
-    HD_L = w1.shape[0]
+    (B, T, D) = x.shape # noqa: N806
+    HD_L = w1.shape[0] # noqa: N806
     assert HD_L == w3.shape[0]
     x1 = fc_fp8_dynamic(
         x.view(B * T, D),
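
The narrowed except clause above (and the matching one in the quantization loader below) is behavior-preserving because ModuleNotFoundError has been a subclass of ImportError since Python 3.6, so catching ImportError alone still covers a missing fbgemm_gpu module. A quick check (the module name here is deliberately fictitious):

    # ModuleNotFoundError is a subclass of ImportError, so the narrower clause loses nothing.
    assert issubclass(ModuleNotFoundError, ImportError)

    try:
        import definitely_not_installed_module  # hypothetical missing package
    except ImportError as exc:
        print(type(exc).__name__)  # prints "ModuleNotFoundError"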

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
@@ -9,13 +15,13 @@ import torch
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 from llama_models.llama3_1.api.model import Transformer, TransformerBlock
+from termcolor import cprint
 from llama_toolchain.inference.api.config import (
     CheckpointQuantizationFormat,
     InlineImplConfig,
 )
 from llama_toolchain.inference.api.datatypes import QuantizationType
-from termcolor import cprint
 from torch import Tensor
@@ -24,7 +30,7 @@ def is_fbgemm_available() -> bool:
         import fbgemm_gpu.experimental.gen_ai # noqa: F401

         return True
-    except (ImportError, ModuleNotFoundError):
+    except ImportError:
         return False

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

View file

@@ -1,5 +1,11 @@
 #!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 set -euo pipefail
 set -x

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
@@ -5,7 +11,7 @@ import unittest
 import torch
-from fp8_impls import ffn_swiglu_fp8_dynamic, quantize_fp8, FfnQuantizeMode
+from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
 from hypothesis import given, settings, strategies as st
 from torch import Tensor
@@ -26,29 +32,25 @@ class FP8Tests(unittest.TestCase):
     )
     def test_fp8_ffn(
         self,
-        D: int,
+        D: int, # noqa
         HD_L: int,
         B: int,
         T: int,
         UB: float,
     ) -> None:
         x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
-        w1 = (
-            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
-        )
-        w3 = (
-            torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
-        )
+        w1 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
+        w3 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
         w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
-        x_q = quantize_fp8(x, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
-        w1_q = quantize_fp8(w1, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
-        w3_q = quantize_fp8(w3, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
-        w2_q = quantize_fp8(w2, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
+        x_q = quantize_fp8(x, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
+        w1_q = quantize_fp8(w1, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
+        w3_q = quantize_fp8(w3, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
+        w2_q = quantize_fp8(w2, UB, mode=FfnQuantizeMode.FP8_ROWWISE)

         def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
-            (B, T, D) = x.shape
-            (HD_L, D_) = w1.shape
+            (B, T, D) = x.shape # noqa: N806
+            (HD_L, D_) = w1.shape # noqa: N806
             assert D_ == D
             x1 = x.view(B * T, D) @ w1.T

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import asyncio
 import signal

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.

View file

@@ -1,2 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .datatypes import * # noqa: F401 F403
 from .endpoints import * # noqa: F401 F403

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import Any, Dict
 from pydantic import BaseModel

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import List, Protocol
 from pyopenapi import webmethod

View file

@@ -1,9 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import Protocol
-from pyopenapi import webmethod
-from pydantic import BaseModel
+from pydantic import BaseModel # noqa: F401
+from pyopenapi import webmethod # noqa: F401

-class Models(Protocol):
-    ...
+class Models(Protocol): ...

View file

@@ -1,2 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .datatypes import * # noqa: F401 F403
 from .endpoints import * # noqa: F401 F403

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from enum import Enum
 from typing import List

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from datetime import datetime
 from typing import Any, Dict, List, Optional, Protocol

View file

@@ -1,2 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .datatypes import * # noqa: F401 F403
 from .endpoints import * # noqa: F401 F403

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import List
 from pydantic import BaseModel

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import List, Protocol, Union
 from .datatypes import * # noqa: F403

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import List, Optional
 from pydantic import BaseModel

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from enum import Enum
 from typing import Dict, Optional, Union

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 # supress warnings and spew of logs from hugging face
 import transformers

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from abc import ABC, abstractmethod
 from typing import List, Union

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from codeshield.cs import CodeShield
 from termcolor import cprint

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import sys
 from typing import List
@@ -5,7 +11,11 @@ from llama_models.llama3_1.api.datatypes import Message
 parent_dir = "../.."
 sys.path.append(parent_dir)
-from llama_toolchain.safety.shields.base import OnViolationAction, ShieldBase, ShieldResponse
+from llama_toolchain.safety.shields.base import (
+    OnViolationAction,
+    ShieldBase,
+    ShieldResponse,
+)

 _INSTANCE = None

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import re
 from string import Template
@@ -100,7 +106,7 @@ class LlamaGuardShield(ShieldBase):
     def instance(
         on_violation_action=OnViolationAction.RAISE,
         model_dir: str = None,
-        excluded_categories: List[str] = [],
+        excluded_categories: List[str] = None,
         disable_input_check: bool = False,
         disable_output_check: bool = False,
     ) -> "LlamaGuardShield":
@@ -119,7 +125,7 @@ class LlamaGuardShield(ShieldBase):
         self,
         on_violation_action: OnViolationAction = OnViolationAction.RAISE,
         model_dir: str = None,
-        excluded_categories: List[str] = [],
+        excluded_categories: List[str] = None,
         disable_input_check: bool = False,
         disable_output_check: bool = False,
     ):
@@ -129,6 +135,8 @@ class LlamaGuardShield(ShieldBase):
         assert model_dir is not None, "Llama Guard model_dir is None"
+        if excluded_categories is None:
+            excluded_categories = []
         assert len(excluded_categories) == 0 or all(
             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from enum import auto, Enum
 from typing import List

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import asyncio
 from typing import List
@@ -6,7 +12,7 @@ from llama_models.llama3_1.api.datatypes import Message, Role
 from .base import OnViolationAction, ShieldBase, ShieldResponse

-class SafetyException(Exception):
+class SafetyException(Exception): # noqa: N818
     def __init__(self, response: ShieldResponse):
         self.response = response
         super().__init__(response.violation_return_message)

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from datetime import datetime
 import yaml

View file

@@ -1,5 +1,11 @@
 #!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 set -euo pipefail
 TMPDIR=$(mktemp -d)

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from enum import Enum
 from typing import Any, Dict, List

View file

@@ -1,5 +1,11 @@
 #!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 set -x
 PYTHONPATH=/data/users/rsm/llama-models:/data/users/rsm/llama-toolchain:/data/users/rsm/llama-agentic-system:../../../oss-ops:../.. python -m toolchain.spec.generate

View file

@@ -1,2 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from .datatypes import * # noqa: F401 F403
 from .endpoints import * # noqa: F401 F403

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from enum import Enum

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import Any, Dict, List, Optional, Protocol
 from pydantic import BaseModel

View file

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 import getpass
 import os
 from typing import Optional

View file

@@ -1,9 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described found in the
+# LICENSE file in the root directory of this source tree.
+
 from setuptools import find_packages, setup

 # Function to read the requirements.txt file
 def read_requirements():
-    with open('requirements.txt') as req:
+    with open("requirements.txt") as req:
         content = req.readlines()
         return [line.strip() for line in content]
@@ -14,11 +20,7 @@ setup(
     author="Meta Llama",
     author_email="rsm@meta.com",
     description="Llama toolchain",
-    entry_points={
-        "console_scripts": [
-            'llama = llama_toolchain.cli.llama:main'
-        ]
-    },
+    entry_points={"console_scripts": ["llama = llama_toolchain.cli.llama:main"]},
     long_description=open("README.md").read(),
     long_description_content_type="text/markdown",
     url="https://github.com/meta-llama/llama-toolchain",
@@ -26,5 +28,5 @@ setup(
     classifiers=[],
     python_requires=">=3.10",
     install_requires=read_requirements(),
-    include_package_data=True
+    include_package_data=True,
 )