mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
updating license for toolchain
This commit is contained in:
parent
0e2fc9966a
commit
86fff23a9e
74 changed files with 512 additions and 94 deletions
|
@ -19,14 +19,14 @@ repos:
|
||||||
# - id: no-commit-to-branch
|
# - id: no-commit-to-branch
|
||||||
# args: ['--branch=main']
|
# args: ['--branch=main']
|
||||||
|
|
||||||
# - repo: https://github.com/Lucas-C/pre-commit-hooks
|
- repo: https://github.com/Lucas-C/pre-commit-hooks
|
||||||
# rev: v1.5.4
|
rev: v1.5.4
|
||||||
# hooks:
|
hooks:
|
||||||
# - id: insert-license
|
- id: insert-license
|
||||||
# files: \.py$|\.sh$
|
files: \.py$|\.sh$
|
||||||
# args:
|
args:
|
||||||
# - --license-filepath
|
- --license-filepath
|
||||||
# - docs/license_header.txt
|
- docs/license_header.txt
|
||||||
|
|
||||||
- repo: https://github.com/pycqa/flake8
|
- repo: https://github.com/pycqa/flake8
|
||||||
rev: 34cbf8ef3950f43d09b85e2e45c15ae5717dc37b
|
rev: 34cbf8ef3950f43d09b85e2e45c15ae5717dc37b
|
||||||
|
|
16
README.md
16
README.md
|
@ -1,7 +1,7 @@
|
||||||
This repo contains the API specifications for various parts of the Llama Stack.
|
This repo contains the API specifications for various parts of the Llama Stack.
|
||||||
The Stack consists of toolchain-apis and agentic-apis.
|
The Stack consists of toolchain-apis and agentic-apis.
|
||||||
|
|
||||||
The toolchain APIs that are covered:
|
The toolchain APIs that are covered:
|
||||||
- inference / batch inference
|
- inference / batch inference
|
||||||
- post training
|
- post training
|
||||||
- reward model scoring
|
- reward model scoring
|
||||||
|
@ -10,7 +10,7 @@ The tool chain apis that are covered --
|
||||||
|
|
||||||
## Running FP8
|
## Running FP8
|
||||||
|
|
||||||
You need the `fbgemm-gpu` package, which requires torch >= 2.4.0 (currently available only in nightly builds; a stable release is expected shortly).
|
You need the `fbgemm-gpu` package, which requires torch >= 2.4.0 (currently available only in nightly builds; a stable release is expected shortly).
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ENV=fp8_env
|
ENV=fp8_env
|
||||||
|
@ -21,19 +21,19 @@ pip3 install -r fp8_requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
### Generate OpenAPI specs
|
### Generate OpenAPI specs
|
||||||
|
|
||||||
Set up virtual environment
|
Set up virtual environment
|
||||||
|
|
||||||
```
|
```
|
||||||
python3 -m venv ~/.venv/toolchain/
|
python3 -m venv ~/.venv/toolchain/
|
||||||
source ~/.venv/toolchain/bin/activate
|
source ~/.venv/toolchain/bin/activate
|
||||||
|
|
||||||
with-proxy pip3 install -r requirements.txt
|
with-proxy pip3 install -r requirements.txt
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Run the generate.sh script
|
Run the generate.sh script
|
||||||
|
|
||||||
```
|
```
|
||||||
cd source && sh generate.sh
|
cd source && sh generate.sh
|
||||||
|
|
5
docs/license_header.txt
Normal file
5
docs/license_header.txt
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
This source code is licensed under the terms described in the
|
||||||
|
LICENSE file in the root directory of this source tree.
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
|
@ -1,10 +1,17 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import pkg_resources
|
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pkg_resources
|
||||||
|
|
||||||
from llama_toolchain.cli.subcommand import Subcommand
|
from llama_toolchain.cli.subcommand import Subcommand
|
||||||
from llama_toolchain.utils import DEFAULT_DUMP_DIR
|
from llama_toolchain.utils import DEFAULT_DUMP_DIR
|
||||||
|
|
||||||
|
@ -36,21 +43,22 @@ class InferenceConfigure(Subcommand):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def read_user_inputs(self):
|
def read_user_inputs(self):
|
||||||
checkpoint_dir = input("Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): ")
|
checkpoint_dir = input(
|
||||||
model_parallel_size = input("Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): ")
|
"Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
|
||||||
assert model_parallel_size.isdigit() and int(model_parallel_size) in {1, 8}, "model parallel size must be 1 or 8"
|
)
|
||||||
|
model_parallel_size = input(
|
||||||
|
"Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
|
||||||
|
)
|
||||||
|
assert model_parallel_size.isdigit() and int(model_parallel_size) in {
|
||||||
|
1,
|
||||||
|
8,
|
||||||
|
}, "model parallel size must be 1 or 8"
|
||||||
|
|
||||||
return checkpoint_dir, model_parallel_size
|
return checkpoint_dir, model_parallel_size
|
||||||
|
|
||||||
def write_output_yaml(
|
def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
|
||||||
self,
|
|
||||||
checkpoint_dir,
|
|
||||||
model_parallel_size,
|
|
||||||
yaml_output_path
|
|
||||||
):
|
|
||||||
default_conf_path = pkg_resources.resource_filename(
|
default_conf_path = pkg_resources.resource_filename(
|
||||||
'llama_toolchain',
|
"llama_toolchain", "data/default_inference_config.yaml"
|
||||||
'data/default_inference_config.yaml'
|
|
||||||
)
|
)
|
||||||
with open(default_conf_path, "r") as f:
|
with open(default_conf_path, "r") as f:
|
||||||
yaml_content = f.read()
|
yaml_content = f.read()
|
||||||
|
@ -60,7 +68,7 @@ class InferenceConfigure(Subcommand):
|
||||||
model_parallel_size=model_parallel_size,
|
model_parallel_size=model_parallel_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(yaml_output_path, 'w') as yaml_file:
|
with open(yaml_output_path, "w") as yaml_file:
|
||||||
yaml_file.write(yaml_content.strip())
|
yaml_file.write(yaml_content.strip())
|
||||||
|
|
||||||
print(f"YAML configuration has been written to {yaml_output_path}")
|
print(f"YAML configuration has been written to {yaml_output_path}")
|
||||||
|
@ -69,8 +77,9 @@ class InferenceConfigure(Subcommand):
|
||||||
checkpoint_dir, model_parallel_size = self.read_user_inputs()
|
checkpoint_dir, model_parallel_size = self.read_user_inputs()
|
||||||
checkpoint_dir = os.path.expanduser(checkpoint_dir)
|
checkpoint_dir = os.path.expanduser(checkpoint_dir)
|
||||||
|
|
||||||
assert Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir(), \
|
assert (
|
||||||
f"{checkpoint_dir} does not exist or it not a directory"
|
Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
|
||||||
|
), f"{checkpoint_dir} does not exist or it not a directory"
|
||||||
|
|
||||||
os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
|
os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
|
||||||
yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml"
|
yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml"
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
@ -40,10 +46,7 @@ class InferenceStart(Subcommand):
|
||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"--config",
|
"--config", type=str, help="Path to config file", default="inference"
|
||||||
type=str,
|
|
||||||
help="Path to config file",
|
|
||||||
default="inference"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_inference_start_cmd(self, args: argparse.Namespace) -> None:
|
def _run_inference_start_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from llama_toolchain.cli.download import Download
|
from llama_toolchain.cli.download import Download
|
||||||
|
@ -27,9 +33,8 @@ class LlamaCLIParser:
|
||||||
|
|
||||||
# Import sub-commands from agentic_system if they exist
|
# Import sub-commands from agentic_system if they exist
|
||||||
try:
|
try:
|
||||||
from llama_agentic_system.cli.subcommand_modules import (
|
from llama_agentic_system.cli.subcommand_modules import SUBCOMMAND_MODULES
|
||||||
SUBCOMMAND_MODULES,
|
|
||||||
)
|
|
||||||
for module in SUBCOMMAND_MODULES:
|
for module in SUBCOMMAND_MODULES:
|
||||||
module.create(subparsers)
|
module.create(subparsers)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,18 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
from llama_models.llama3_1.api.interface import (
|
||||||
|
list_jinja_templates,
|
||||||
|
render_jinja_template,
|
||||||
|
)
|
||||||
|
|
||||||
from llama_toolchain.cli.subcommand import Subcommand
|
from llama_toolchain.cli.subcommand import Subcommand
|
||||||
from llama_models.llama3_1.api.interface import render_jinja_template, list_jinja_templates
|
|
||||||
|
|
||||||
|
|
||||||
class ModelTemplate(Subcommand):
|
class ModelTemplate(Subcommand):
|
||||||
|
|
|
@ -1,3 +1,10 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
class Subcommand:
|
class Subcommand:
|
||||||
"""All llama cli subcommands must inherit from this class"""
|
"""All llama cli subcommands must inherit from this class"""
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,11 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from llama_models.llama3_1.api.datatypes import URL
|
from llama_models.llama3_1.api.datatypes import URL
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class Checkpoint(BaseModel):
|
class Checkpoint(BaseModel):
|
||||||
|
|
|
@ -1,2 +1,8 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F401 F403
|
from .datatypes import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa: F401 F403
|
from .endpoints import * # noqa: F401 F403
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
|
@ -1,2 +1,8 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F401 F403
|
from .datatypes import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa: F401 F403
|
from .endpoints import * # noqa: F401 F403
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Protocol
|
from typing import List, Protocol
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
|
@ -1,2 +1,8 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F401 F403
|
from .datatypes import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa: F401 F403
|
from .endpoints import * # noqa: F401 F403
|
||||||
|
|
|
@ -1,14 +1,20 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
from hydra.core.config_store import ConfigStore
|
from hydra.core.config_store import ConfigStore
|
||||||
|
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from .datatypes import QuantizationConfig
|
from .datatypes import QuantizationConfig
|
||||||
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
|
|
||||||
|
|
||||||
|
|
||||||
class ImplType(Enum):
|
class ImplType(Enum):
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
|
@ -26,9 +32,7 @@ class Fp8QuantizationConfig(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Bf16QuantizationConfig(BaseModel):
|
class Bf16QuantizationConfig(BaseModel):
|
||||||
type: Literal[QuantizationType.bf16.value] = (
|
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
|
||||||
QuantizationType.bf16.value
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
QuantizationConfig = Annotated[
|
QuantizationConfig = Annotated[
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F403
|
from .datatypes import * # noqa: F403
|
||||||
from typing import Optional, Protocol
|
from typing import Optional, Protocol
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from .api.config import ImplType, InferenceConfig
|
from .api.config import ImplType, InferenceConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,15 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from urllib.request import getproxies
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
@ -9,12 +17,16 @@ from .api import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponseStreamChunk,
|
ChatCompletionResponseStreamChunk,
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
InstructModel,
|
|
||||||
Inference,
|
Inference,
|
||||||
|
InstructModel,
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from .event_logger import EventLogger
|
from .event_logger import EventLogger
|
||||||
|
|
||||||
|
print(getproxies())
|
||||||
|
# import sys
|
||||||
|
# sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
class InferenceClient(Inference):
|
class InferenceClient(Inference):
|
||||||
def __init__(self, base_url: str):
|
def __init__(self, base_url: str):
|
||||||
|
|
|
@ -1,8 +1,12 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
from llama_toolchain.inference.api import (
|
|
||||||
ChatCompletionResponseEventType,
|
from llama_toolchain.inference.api import ChatCompletionResponseEventType
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LogEvent:
|
class LogEvent:
|
||||||
|
@ -30,4 +34,3 @@ class EventLogger:
|
||||||
yield LogEvent(event.delta, color="yellow", end="")
|
yield LogEvent(event.delta, color="yellow", end="")
|
||||||
elif event.event_type == ChatCompletionResponseEventType.complete:
|
elif event.event_type == ChatCompletionResponseEventType.complete:
|
||||||
yield LogEvent("")
|
yield LogEvent("")
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
|
|
@ -1,16 +1,17 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from llama_models.llama3_1.api.datatypes import StopReason
|
from llama_models.llama3_1.api.datatypes import StopReason
|
||||||
|
|
||||||
from .api.config import (
|
from .api.config import InlineImplConfig
|
||||||
CheckpointQuantizationFormat,
|
|
||||||
CheckpointType,
|
|
||||||
InlineImplConfig,
|
|
||||||
)
|
|
||||||
from .api.datatypes import (
|
from .api.datatypes import (
|
||||||
ChatCompletionResponseEvent,
|
ChatCompletionResponseEvent,
|
||||||
ChatCompletionResponseEventType,
|
ChatCompletionResponseEventType,
|
||||||
QuantizationConfig,
|
|
||||||
ToolCallDelta,
|
ToolCallDelta,
|
||||||
ToolCallParseStatus,
|
ToolCallParseStatus,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
@ -8,7 +14,7 @@ try:
|
||||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||||
|
|
||||||
print("Using efficient FP8 operators in FBGEMM.")
|
print("Using efficient FP8 operators in FBGEMM.")
|
||||||
except (ImportError, ModuleNotFoundError):
|
except ImportError:
|
||||||
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
|
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@ -57,8 +63,8 @@ def ffn_swiglu(
|
||||||
x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
|
x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
|
||||||
)
|
)
|
||||||
|
|
||||||
(B, T, D) = x.shape
|
(B, T, D) = x.shape # noqa: N806
|
||||||
(HD_L, D_) = w1.shape
|
(HD_L, D_) = w1.shape # noqa: N806
|
||||||
assert D_ == D
|
assert D_ == D
|
||||||
|
|
||||||
assert isinstance(w1, Tensor)
|
assert isinstance(w1, Tensor)
|
||||||
|
@ -153,8 +159,8 @@ def ffn_swiglu_fp8_dynamic(
|
||||||
num_tokens: Optional[Tensor] = None,
|
num_tokens: Optional[Tensor] = None,
|
||||||
is_memory_bounded: bool = False,
|
is_memory_bounded: bool = False,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
(B, T, D) = x.shape
|
(B, T, D) = x.shape # noqa: N806
|
||||||
HD_L = w1.shape[0]
|
HD_L = w1.shape[0] # noqa: N806
|
||||||
assert HD_L == w3.shape[0]
|
assert HD_L == w3.shape[0]
|
||||||
x1 = fc_fp8_dynamic(
|
x1 = fc_fp8_dynamic(
|
||||||
x.view(B * T, D),
|
x.view(B * T, D),
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
@ -9,13 +15,13 @@ import torch
|
||||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||||
from llama_models.llama3_1.api.model import Transformer, TransformerBlock
|
from llama_models.llama3_1.api.model import Transformer, TransformerBlock
|
||||||
|
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_toolchain.inference.api.config import (
|
from llama_toolchain.inference.api.config import (
|
||||||
CheckpointQuantizationFormat,
|
CheckpointQuantizationFormat,
|
||||||
InlineImplConfig,
|
InlineImplConfig,
|
||||||
)
|
)
|
||||||
from llama_toolchain.inference.api.datatypes import QuantizationType
|
from llama_toolchain.inference.api.datatypes import QuantizationType
|
||||||
|
|
||||||
|
from termcolor import cprint
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,7 +30,7 @@ def is_fbgemm_available() -> bool:
|
||||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except (ImportError, ModuleNotFoundError):
|
except ImportError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,11 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
set -x
|
set -x
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
@ -5,7 +11,7 @@ import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from fp8_impls import ffn_swiglu_fp8_dynamic, quantize_fp8, FfnQuantizeMode
|
from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
|
||||||
from hypothesis import given, settings, strategies as st
|
from hypothesis import given, settings, strategies as st
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
@ -26,29 +32,25 @@ class FP8Tests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
def test_fp8_ffn(
|
def test_fp8_ffn(
|
||||||
self,
|
self,
|
||||||
D: int,
|
D: int, # noqa
|
||||||
HD_L: int,
|
HD_L: int,
|
||||||
B: int,
|
B: int,
|
||||||
T: int,
|
T: int,
|
||||||
UB: float,
|
UB: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
|
x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||||
w1 = (
|
w1 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||||
torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
w3 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||||
)
|
|
||||||
w3 = (
|
|
||||||
torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
|
||||||
)
|
|
||||||
w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
|
w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||||
|
|
||||||
x_q = quantize_fp8(x, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
|
x_q = quantize_fp8(x, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||||
w1_q = quantize_fp8(w1, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
|
w1_q = quantize_fp8(w1, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||||
w3_q = quantize_fp8(w3, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
|
w3_q = quantize_fp8(w3, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||||
w2_q = quantize_fp8(w2, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
|
w2_q = quantize_fp8(w2, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||||
|
|
||||||
def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
|
def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
|
||||||
(B, T, D) = x.shape
|
(B, T, D) = x.shape # noqa: N806
|
||||||
(HD_L, D_) = w1.shape
|
(HD_L, D_) = w1.shape # noqa: N806
|
||||||
assert D_ == D
|
assert D_ == D
|
||||||
|
|
||||||
x1 = x.view(B * T, D) @ w1.T
|
x1 = x.view(B * T, D) @ w1.T
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import signal
|
import signal
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
|
@ -1,2 +1,8 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F401 F403
|
from .datatypes import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa: F401 F403
|
from .endpoints import * # noqa: F401 F403
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Protocol
|
from typing import List, Protocol
|
||||||
|
|
||||||
from pyopenapi import webmethod
|
from pyopenapi import webmethod
|
||||||
|
|
|
@ -1,9 +1,14 @@
|
||||||
from typing import Protocol
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from pyopenapi import webmethod
|
from typing import Protocol
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel # noqa: F401
|
||||||
|
|
||||||
|
from pyopenapi import webmethod # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
class Models(Protocol):
|
class Models(Protocol): ...
|
||||||
...
|
|
||||||
|
|
|
@ -1,2 +1,8 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F401 F403
|
from .datatypes import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa: F401 F403
|
from .endpoints import * # noqa: F401 F403
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol
|
from typing import Any, Dict, List, Optional, Protocol
|
||||||
|
|
|
@ -1,2 +1,8 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F401 F403
|
from .datatypes import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa: F401 F403
|
from .endpoints import * # noqa: F401 F403
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Protocol, Union
|
from typing import List, Protocol, Union
|
||||||
from .datatypes import * # noqa: F403
|
from .datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
# supress warnings and spew of logs from hugging face
|
# supress warnings and spew of logs from hugging face
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from codeshield.cs import CodeShield
|
from codeshield.cs import CodeShield
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
@ -5,7 +11,11 @@ from llama_models.llama3_1.api.datatypes import Message
|
||||||
|
|
||||||
parent_dir = "../.."
|
parent_dir = "../.."
|
||||||
sys.path.append(parent_dir)
|
sys.path.append(parent_dir)
|
||||||
from llama_toolchain.safety.shields.base import OnViolationAction, ShieldBase, ShieldResponse
|
from llama_toolchain.safety.shields.base import (
|
||||||
|
OnViolationAction,
|
||||||
|
ShieldBase,
|
||||||
|
ShieldResponse,
|
||||||
|
)
|
||||||
|
|
||||||
_INSTANCE = None
|
_INSTANCE = None
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from string import Template
|
from string import Template
|
||||||
|
@ -100,7 +106,7 @@ class LlamaGuardShield(ShieldBase):
|
||||||
def instance(
|
def instance(
|
||||||
on_violation_action=OnViolationAction.RAISE,
|
on_violation_action=OnViolationAction.RAISE,
|
||||||
model_dir: str = None,
|
model_dir: str = None,
|
||||||
excluded_categories: List[str] = [],
|
excluded_categories: List[str] = None,
|
||||||
disable_input_check: bool = False,
|
disable_input_check: bool = False,
|
||||||
disable_output_check: bool = False,
|
disable_output_check: bool = False,
|
||||||
) -> "LlamaGuardShield":
|
) -> "LlamaGuardShield":
|
||||||
|
@ -119,7 +125,7 @@ class LlamaGuardShield(ShieldBase):
|
||||||
self,
|
self,
|
||||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||||
model_dir: str = None,
|
model_dir: str = None,
|
||||||
excluded_categories: List[str] = [],
|
excluded_categories: List[str] = None,
|
||||||
disable_input_check: bool = False,
|
disable_input_check: bool = False,
|
||||||
disable_output_check: bool = False,
|
disable_output_check: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -129,6 +135,8 @@ class LlamaGuardShield(ShieldBase):
|
||||||
|
|
||||||
assert model_dir is not None, "Llama Guard model_dir is None"
|
assert model_dir is not None, "Llama Guard model_dir is None"
|
||||||
|
|
||||||
|
if excluded_categories is None:
|
||||||
|
excluded_categories = []
|
||||||
assert len(excluded_categories) == 0 or all(
|
assert len(excluded_categories) == 0 or all(
|
||||||
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
|
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
|
||||||
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
|
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from enum import auto, Enum
|
from enum import auto, Enum
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
@ -6,7 +12,7 @@ from llama_models.llama3_1.api.datatypes import Message, Role
|
||||||
from .base import OnViolationAction, ShieldBase, ShieldResponse
|
from .base import OnViolationAction, ShieldBase, ShieldResponse
|
||||||
|
|
||||||
|
|
||||||
class SafetyException(Exception):
|
class SafetyException(Exception): # noqa: N818
|
||||||
def __init__(self, response: ShieldResponse):
|
def __init__(self, response: ShieldResponse):
|
||||||
self.response = response
|
self.response = response
|
||||||
super().__init__(response.violation_return_message)
|
super().__init__(response.violation_return_message)
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
|
@ -1,5 +1,11 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
TMPDIR=$(mktemp -d)
|
TMPDIR=$(mktemp -d)
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,11 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
set -x
|
set -x
|
||||||
|
|
||||||
PYTHONPATH=/data/users/rsm/llama-models:/data/users/rsm/llama-toolchain:/data/users/rsm/llama-agentic-system:../../../oss-ops:../.. python -m toolchain.spec.generate
|
PYTHONPATH=/data/users/rsm/llama-models:/data/users/rsm/llama-toolchain:/data/users/rsm/llama-agentic-system:../../../oss-ops:../.. python -m toolchain.spec.generate
|
||||||
|
|
|
@ -1,2 +1,8 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F401 F403
|
from .datatypes import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa: F401 F403
|
from .endpoints import * # noqa: F401 F403
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol
|
from typing import Any, Dict, List, Optional, Protocol
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import getpass
|
import getpass
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
16
setup.py
16
setup.py
|
@ -1,9 +1,15 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
|
||||||
# Function to read the requirements.txt file
|
# Function to read the requirements.txt file
|
||||||
def read_requirements():
|
def read_requirements():
|
||||||
with open('requirements.txt') as req:
|
with open("requirements.txt") as req:
|
||||||
content = req.readlines()
|
content = req.readlines()
|
||||||
return [line.strip() for line in content]
|
return [line.strip() for line in content]
|
||||||
|
|
||||||
|
@ -14,11 +20,7 @@ setup(
|
||||||
author="Meta Llama",
|
author="Meta Llama",
|
||||||
author_email="rsm@meta.com",
|
author_email="rsm@meta.com",
|
||||||
description="Llama toolchain",
|
description="Llama toolchain",
|
||||||
entry_points={
|
entry_points={"console_scripts": ["llama = llama_toolchain.cli.llama:main"]},
|
||||||
"console_scripts": [
|
|
||||||
'llama = llama_toolchain.cli.llama:main'
|
|
||||||
]
|
|
||||||
},
|
|
||||||
long_description=open("README.md").read(),
|
long_description=open("README.md").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
url="https://github.com/meta-llama/llama-toolchain",
|
url="https://github.com/meta-llama/llama-toolchain",
|
||||||
|
@ -26,5 +28,5 @@ setup(
|
||||||
classifiers=[],
|
classifiers=[],
|
||||||
python_requires=">=3.10",
|
python_requires=">=3.10",
|
||||||
install_requires=read_requirements(),
|
install_requires=read_requirements(),
|
||||||
include_package_data=True
|
include_package_data=True,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue