mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
updating license for toolchain
This commit is contained in:
parent
0e2fc9966a
commit
86fff23a9e
74 changed files with 512 additions and 94 deletions
|
@ -19,14 +19,14 @@ repos:
|
|||
# - id: no-commit-to-branch
|
||||
# args: ['--branch=main']
|
||||
|
||||
# - repo: https://github.com/Lucas-C/pre-commit-hooks
|
||||
# rev: v1.5.4
|
||||
# hooks:
|
||||
# - id: insert-license
|
||||
# files: \.py$|\.sh$
|
||||
# args:
|
||||
# - --license-filepath
|
||||
# - docs/license_header.txt
|
||||
- repo: https://github.com/Lucas-C/pre-commit-hooks
|
||||
rev: v1.5.4
|
||||
hooks:
|
||||
- id: insert-license
|
||||
files: \.py$|\.sh$
|
||||
args:
|
||||
- --license-filepath
|
||||
- docs/license_header.txt
|
||||
|
||||
- repo: https://github.com/pycqa/flake8
|
||||
rev: 34cbf8ef3950f43d09b85e2e45c15ae5717dc37b
|
||||
|
|
16
README.md
16
README.md
|
@ -1,7 +1,7 @@
|
|||
This repo contains the API specifications for various parts of the Llama Stack.
|
||||
The Stack consists of toolchain-apis and agentic-apis.
|
||||
The Stack consists of toolchain-apis and agentic-apis.
|
||||
|
||||
The tool chain apis that are covered --
|
||||
The tool chain apis that are covered --
|
||||
- inference / batch inference
|
||||
- post training
|
||||
- reward model scoring
|
||||
|
@ -10,7 +10,7 @@ The tool chain apis that are covered --
|
|||
|
||||
## Running FP8
|
||||
|
||||
You need `fbgemm-gpu` package which requires torch >= 2.4.0 (currently only in nightly, but releasing shortly...).
|
||||
You need `fbgemm-gpu` package which requires torch >= 2.4.0 (currently only in nightly, but releasing shortly...).
|
||||
|
||||
```bash
|
||||
ENV=fp8_env
|
||||
|
@ -21,19 +21,19 @@ pip3 install -r fp8_requirements.txt
|
|||
```
|
||||
|
||||
|
||||
### Generate OpenAPI specs
|
||||
### Generate OpenAPI specs
|
||||
|
||||
Set up virtual environment
|
||||
Set up virtual environment
|
||||
|
||||
```
|
||||
python3 -m venv ~/.venv/toolchain/
|
||||
python3 -m venv ~/.venv/toolchain/
|
||||
source ~/.venv/toolchain/bin/activate
|
||||
|
||||
with-proxy pip3 install -r requirements.txt
|
||||
with-proxy pip3 install -r requirements.txt
|
||||
|
||||
```
|
||||
|
||||
Run the generate.sh script
|
||||
Run the generate.sh script
|
||||
|
||||
```
|
||||
cd source && sh generate.sh
|
||||
|
|
5
docs/license_header.txt
Normal file
5
docs/license_header.txt
Normal file
|
@ -0,0 +1,5 @@
|
|||
Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the terms described found in the
|
||||
LICENSE file in the root directory of this source tree.
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import textwrap
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
|
@ -1,10 +1,17 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pkg_resources
|
||||
import textwrap
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from llama_toolchain.cli.subcommand import Subcommand
|
||||
from llama_toolchain.utils import DEFAULT_DUMP_DIR
|
||||
|
||||
|
@ -36,21 +43,22 @@ class InferenceConfigure(Subcommand):
|
|||
pass
|
||||
|
||||
def read_user_inputs(self):
|
||||
checkpoint_dir = input("Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): ")
|
||||
model_parallel_size = input("Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): ")
|
||||
assert model_parallel_size.isdigit() and int(model_parallel_size) in {1, 8}, "model parallel size must be 1 or 8"
|
||||
checkpoint_dir = input(
|
||||
"Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
|
||||
)
|
||||
model_parallel_size = input(
|
||||
"Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
|
||||
)
|
||||
assert model_parallel_size.isdigit() and int(model_parallel_size) in {
|
||||
1,
|
||||
8,
|
||||
}, "model parallel size must be 1 or 8"
|
||||
|
||||
return checkpoint_dir, model_parallel_size
|
||||
|
||||
def write_output_yaml(
|
||||
self,
|
||||
checkpoint_dir,
|
||||
model_parallel_size,
|
||||
yaml_output_path
|
||||
):
|
||||
def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
|
||||
default_conf_path = pkg_resources.resource_filename(
|
||||
'llama_toolchain',
|
||||
'data/default_inference_config.yaml'
|
||||
"llama_toolchain", "data/default_inference_config.yaml"
|
||||
)
|
||||
with open(default_conf_path, "r") as f:
|
||||
yaml_content = f.read()
|
||||
|
@ -60,7 +68,7 @@ class InferenceConfigure(Subcommand):
|
|||
model_parallel_size=model_parallel_size,
|
||||
)
|
||||
|
||||
with open(yaml_output_path, 'w') as yaml_file:
|
||||
with open(yaml_output_path, "w") as yaml_file:
|
||||
yaml_file.write(yaml_content.strip())
|
||||
|
||||
print(f"YAML configuration has been written to {yaml_output_path}")
|
||||
|
@ -69,8 +77,9 @@ class InferenceConfigure(Subcommand):
|
|||
checkpoint_dir, model_parallel_size = self.read_user_inputs()
|
||||
checkpoint_dir = os.path.expanduser(checkpoint_dir)
|
||||
|
||||
assert Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir(), \
|
||||
f"{checkpoint_dir} does not exist or it not a directory"
|
||||
assert (
|
||||
Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
|
||||
), f"{checkpoint_dir} does not exist or it not a directory"
|
||||
|
||||
os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
|
||||
yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml"
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import textwrap
|
||||
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import textwrap
|
||||
|
||||
|
@ -40,10 +46,7 @@ class InferenceStart(Subcommand):
|
|||
default=False,
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
help="Path to config file",
|
||||
default="inference"
|
||||
"--config", type=str, help="Path to config file", default="inference"
|
||||
)
|
||||
|
||||
def _run_inference_start_cmd(self, args: argparse.Namespace) -> None:
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_toolchain.cli.download import Download
|
||||
|
@ -27,9 +33,8 @@ class LlamaCLIParser:
|
|||
|
||||
# Import sub-commands from agentic_system if they exist
|
||||
try:
|
||||
from llama_agentic_system.cli.subcommand_modules import (
|
||||
SUBCOMMAND_MODULES,
|
||||
)
|
||||
from llama_agentic_system.cli.subcommand_modules import SUBCOMMAND_MODULES
|
||||
|
||||
for module in SUBCOMMAND_MODULES:
|
||||
module.create(subparsers)
|
||||
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import textwrap
|
||||
|
||||
|
|
|
@ -1,8 +1,18 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import textwrap
|
||||
|
||||
from llama_models.llama3_1.api.interface import (
|
||||
list_jinja_templates,
|
||||
render_jinja_template,
|
||||
)
|
||||
|
||||
from llama_toolchain.cli.subcommand import Subcommand
|
||||
from llama_models.llama3_1.api.interface import render_jinja_template, list_jinja_templates
|
||||
|
||||
|
||||
class ModelTemplate(Subcommand):
|
||||
|
|
|
@ -1,3 +1,10 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
class Subcommand:
|
||||
"""All llama cli subcommands must inherit from this class"""
|
||||
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
|
|
@ -1,5 +1,11 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from llama_models.llama3_1.api.datatypes import URL
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Checkpoint(BaseModel):
|
||||
|
|
|
@ -1,2 +1,8 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .datatypes import * # noqa: F401 F403
|
||||
from .endpoints import * # noqa: F401 F403
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
|
@ -1,2 +1,8 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .datatypes import * # noqa: F401 F403
|
||||
from .endpoints import * # noqa: F401 F403
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import List, Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
|
@ -1,2 +1,8 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .datatypes import * # noqa: F401 F403
|
||||
from .endpoints import * # noqa: F401 F403
|
||||
|
|
|
@ -1,14 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from hydra.core.config_store import ConfigStore
|
||||
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from .datatypes import QuantizationConfig
|
||||
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
|
||||
|
||||
|
||||
class ImplType(Enum):
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
|
@ -26,9 +32,7 @@ class Fp8QuantizationConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class Bf16QuantizationConfig(BaseModel):
|
||||
type: Literal[QuantizationType.bf16.value] = (
|
||||
QuantizationType.bf16.value
|
||||
)
|
||||
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
|
||||
|
||||
|
||||
QuantizationConfig = Annotated[
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .datatypes import * # noqa: F403
|
||||
from typing import Optional, Protocol
|
||||
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .api.config import ImplType, InferenceConfig
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from urllib.request import getproxies
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
||||
|
@ -9,12 +17,16 @@ from .api import (
|
|||
ChatCompletionRequest,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionRequest,
|
||||
InstructModel,
|
||||
Inference,
|
||||
InstructModel,
|
||||
UserMessage,
|
||||
)
|
||||
from .event_logger import EventLogger
|
||||
|
||||
print(getproxies())
|
||||
# import sys
|
||||
# sys.exit(0)
|
||||
|
||||
|
||||
class InferenceClient(Inference):
|
||||
def __init__(self, base_url: str):
|
||||
|
|
|
@ -1,8 +1,12 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from termcolor import cprint
|
||||
from llama_toolchain.inference.api import (
|
||||
ChatCompletionResponseEventType,
|
||||
)
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.inference.api import ChatCompletionResponseEventType
|
||||
|
||||
|
||||
class LogEvent:
|
||||
|
@ -30,4 +34,3 @@ class EventLogger:
|
|||
yield LogEvent(event.delta, color="yellow", end="")
|
||||
elif event.event_type == ChatCompletionResponseEventType.complete:
|
||||
yield LogEvent("")
|
||||
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from llama_models.llama3_1.api.datatypes import StopReason
|
||||
|
||||
from .api.config import (
|
||||
CheckpointQuantizationFormat,
|
||||
CheckpointType,
|
||||
InlineImplConfig,
|
||||
)
|
||||
from .api.config import InlineImplConfig
|
||||
from .api.datatypes import (
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
QuantizationConfig,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
|
@ -8,7 +14,7 @@ try:
|
|||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||
|
||||
print("Using efficient FP8 operators in FBGEMM.")
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
except ImportError:
|
||||
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
|
||||
raise
|
||||
|
||||
|
@ -57,8 +63,8 @@ def ffn_swiglu(
|
|||
x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
|
||||
)
|
||||
|
||||
(B, T, D) = x.shape
|
||||
(HD_L, D_) = w1.shape
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
(HD_L, D_) = w1.shape # noqa: N806
|
||||
assert D_ == D
|
||||
|
||||
assert isinstance(w1, Tensor)
|
||||
|
@ -153,8 +159,8 @@ def ffn_swiglu_fp8_dynamic(
|
|||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
(B, T, D) = x.shape
|
||||
HD_L = w1.shape[0]
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
HD_L = w1.shape[0] # noqa: N806
|
||||
assert HD_L == w3.shape[0]
|
||||
x1 = fc_fp8_dynamic(
|
||||
x.view(B * T, D),
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
|
@ -9,13 +15,13 @@ import torch
|
|||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from llama_models.llama3_1.api.model import Transformer, TransformerBlock
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.inference.api.config import (
|
||||
CheckpointQuantizationFormat,
|
||||
InlineImplConfig,
|
||||
)
|
||||
from llama_toolchain.inference.api.datatypes import QuantizationType
|
||||
|
||||
from termcolor import cprint
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
|
@ -24,7 +30,7 @@ def is_fbgemm_available() -> bool:
|
|||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||
|
||||
return True
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
|
|
|
@ -1,5 +1,11 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
set -euo pipefail
|
||||
set -x
|
||||
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
|
@ -5,7 +11,7 @@ import unittest
|
|||
|
||||
import torch
|
||||
|
||||
from fp8_impls import ffn_swiglu_fp8_dynamic, quantize_fp8, FfnQuantizeMode
|
||||
from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
|
||||
from hypothesis import given, settings, strategies as st
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -26,29 +32,25 @@ class FP8Tests(unittest.TestCase):
|
|||
)
|
||||
def test_fp8_ffn(
|
||||
self,
|
||||
D: int,
|
||||
D: int, # noqa
|
||||
HD_L: int,
|
||||
B: int,
|
||||
T: int,
|
||||
UB: float,
|
||||
) -> None:
|
||||
x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||
w1 = (
|
||||
torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||
)
|
||||
w3 = (
|
||||
torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||
)
|
||||
w1 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||
w3 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||
w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||
|
||||
x_q = quantize_fp8(x, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
|
||||
w1_q = quantize_fp8(w1, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
|
||||
w3_q = quantize_fp8(w3, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
|
||||
w2_q = quantize_fp8(w2, UB, mode = FfnQuantizeMode.FP8_ROWWISE)
|
||||
x_q = quantize_fp8(x, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
w1_q = quantize_fp8(w1, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
w3_q = quantize_fp8(w3, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
w2_q = quantize_fp8(w2, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
|
||||
def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
|
||||
(B, T, D) = x.shape
|
||||
(HD_L, D_) = w1.shape
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
(HD_L, D_) = w1.shape # noqa: N806
|
||||
assert D_ == D
|
||||
|
||||
x1 = x.view(B * T, D) @ w1.T
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
|
@ -1,2 +1,8 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .datatypes import * # noqa: F401 F403
|
||||
from .endpoints import * # noqa: F401 F403
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import List, Protocol
|
||||
|
||||
from pyopenapi import webmethod
|
||||
|
|
|
@ -1,9 +1,14 @@
|
|||
from typing import Protocol
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from pyopenapi import webmethod
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel # noqa: F401
|
||||
|
||||
from pyopenapi import webmethod # noqa: F401
|
||||
|
||||
|
||||
class Models(Protocol):
|
||||
...
|
||||
class Models(Protocol): ...
|
||||
|
|
|
@ -1,2 +1,8 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .datatypes import * # noqa: F401 F403
|
||||
from .endpoints import * # noqa: F401 F403
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
|
|
|
@ -1,2 +1,8 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .datatypes import * # noqa: F401 F403
|
||||
from .endpoints import * # noqa: F401 F403
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import List, Protocol, Union
|
||||
from .datatypes import * # noqa: F403
|
||||
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# supress warnings and spew of logs from hugging face
|
||||
import transformers
|
||||
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from codeshield.cs import CodeShield
|
||||
from termcolor import cprint
|
||||
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
|
@ -5,7 +11,11 @@ from llama_models.llama3_1.api.datatypes import Message
|
|||
|
||||
parent_dir = "../.."
|
||||
sys.path.append(parent_dir)
|
||||
from llama_toolchain.safety.shields.base import OnViolationAction, ShieldBase, ShieldResponse
|
||||
from llama_toolchain.safety.shields.base import (
|
||||
OnViolationAction,
|
||||
ShieldBase,
|
||||
ShieldResponse,
|
||||
)
|
||||
|
||||
_INSTANCE = None
|
||||
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import re
|
||||
|
||||
from string import Template
|
||||
|
@ -100,7 +106,7 @@ class LlamaGuardShield(ShieldBase):
|
|||
def instance(
|
||||
on_violation_action=OnViolationAction.RAISE,
|
||||
model_dir: str = None,
|
||||
excluded_categories: List[str] = [],
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
) -> "LlamaGuardShield":
|
||||
|
@ -119,7 +125,7 @@ class LlamaGuardShield(ShieldBase):
|
|||
self,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
model_dir: str = None,
|
||||
excluded_categories: List[str] = [],
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
):
|
||||
|
@ -129,6 +135,8 @@ class LlamaGuardShield(ShieldBase):
|
|||
|
||||
assert model_dir is not None, "Llama Guard model_dir is None"
|
||||
|
||||
if excluded_categories is None:
|
||||
excluded_categories = []
|
||||
assert len(excluded_categories) == 0 or all(
|
||||
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
|
||||
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import auto, Enum
|
||||
from typing import List
|
||||
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
|
@ -6,7 +12,7 @@ from llama_models.llama3_1.api.datatypes import Message, Role
|
|||
from .base import OnViolationAction, ShieldBase, ShieldResponse
|
||||
|
||||
|
||||
class SafetyException(Exception):
|
||||
class SafetyException(Exception): # noqa: N818
|
||||
def __init__(self, response: ShieldResponse):
|
||||
self.response = response
|
||||
super().__init__(response.violation_return_message)
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import yaml
|
||||
|
|
|
@ -1,5 +1,11 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
TMPDIR=$(mktemp -d)
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
|
|
@ -1,5 +1,11 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
set -x
|
||||
|
||||
PYTHONPATH=/data/users/rsm/llama-models:/data/users/rsm/llama-toolchain:/data/users/rsm/llama-agentic-system:../../../oss-ops:../.. python -m toolchain.spec.generate
|
||||
|
|
|
@ -1,2 +1,8 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .datatypes import * # noqa: F401 F403
|
||||
from .endpoints import * # noqa: F401 F403
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import getpass
|
||||
import os
|
||||
from typing import Optional
|
||||
|
|
16
setup.py
16
setup.py
|
@ -1,9 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
# Function to read the requirements.txt file
|
||||
def read_requirements():
|
||||
with open('requirements.txt') as req:
|
||||
with open("requirements.txt") as req:
|
||||
content = req.readlines()
|
||||
return [line.strip() for line in content]
|
||||
|
||||
|
@ -14,11 +20,7 @@ setup(
|
|||
author="Meta Llama",
|
||||
author_email="rsm@meta.com",
|
||||
description="Llama toolchain",
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
'llama = llama_toolchain.cli.llama:main'
|
||||
]
|
||||
},
|
||||
entry_points={"console_scripts": ["llama = llama_toolchain.cli.llama:main"]},
|
||||
long_description=open("README.md").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/meta-llama/llama-toolchain",
|
||||
|
@ -26,5 +28,5 @@ setup(
|
|||
classifiers=[],
|
||||
python_requires=">=3.10",
|
||||
install_requires=read_requirements(),
|
||||
include_package_data=True
|
||||
include_package_data=True,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue