mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
Add toolchain from agentic system here
This commit is contained in:
parent
f6b2b2fb39
commit
95781ec85d
71 changed files with 11899 additions and 0 deletions
0
toolchain/__init__.py
Normal file
0
toolchain/__init__.py
Normal file
0
toolchain/cli/__init__.py
Normal file
0
toolchain/cli/__init__.py
Normal file
93
toolchain/cli/download.py
Normal file
93
toolchain/cli/download.py
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import textwrap
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
||||||
|
|
||||||
|
from toolchain.cli.subcommand import Subcommand
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_OUTPUT_DIR = "/tmp/llama_toolchain_cache/"
|
||||||
|
|
||||||
|
|
||||||
|
class Download(Subcommand):
|
||||||
|
"""Llama cli for downloading llama toolchain assets"""
|
||||||
|
|
||||||
|
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||||
|
super().__init__()
|
||||||
|
self.parser = subparsers.add_parser(
|
||||||
|
"download",
|
||||||
|
prog="llama download",
|
||||||
|
description="Download a model from the Hugging Face Hub",
|
||||||
|
epilog=textwrap.dedent(
|
||||||
|
"""\
|
||||||
|
# Here are some examples on how to use this command:
|
||||||
|
|
||||||
|
llama download --repo-id meta-llama/Llama-2-7b-hf --hf-token <HF_TOKEN>
|
||||||
|
llama download --repo-id meta-llama/Llama-2-7b-hf --output-dir /data/my_custom_dir --hf-token <HF_TOKEN>
|
||||||
|
HF_TOKEN=<HF_TOKEN> llama download --repo-id meta-llama/Llama-2-7b-hf
|
||||||
|
|
||||||
|
The output directory will be used to load models and tokenizers for inference.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
|
)
|
||||||
|
self._add_arguments()
|
||||||
|
self.parser.set_defaults(func=self._run_download_cmd)
|
||||||
|
|
||||||
|
def _add_arguments(self):
|
||||||
|
self.parser.add_argument(
|
||||||
|
"repo_id",
|
||||||
|
type=str,
|
||||||
|
help="Name of the repository on Hugging Face Hub eg. llhf/Meta-Llama-3.1-70B-Instruct",
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
type=Path,
|
||||||
|
required=False,
|
||||||
|
default=None,
|
||||||
|
help=f"Directory in which to save the model. Defaults to `{DEFAULT_OUTPUT_DIR}<model_name>`.",
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--hf-token",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default=os.getenv("HF_TOKEN", None),
|
||||||
|
help="Hugging Face API token. Needed for gated models like Llama2. Will also try to read environment variable `HF_TOKEN` as default.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_download_cmd(self, args: argparse.Namespace):
|
||||||
|
model_name = args.repo_id.split("/")[-1]
|
||||||
|
|
||||||
|
os.makedirs(DEFAULT_OUTPUT_DIR, exist_ok=True)
|
||||||
|
output_dir = args.output_dir
|
||||||
|
if output_dir is None:
|
||||||
|
model_name = args.repo_id.split("/")[-1]
|
||||||
|
output_dir = Path(DEFAULT_OUTPUT_DIR) / model_name
|
||||||
|
|
||||||
|
try:
|
||||||
|
true_output_dir = snapshot_download(
|
||||||
|
args.repo_id,
|
||||||
|
local_dir=output_dir,
|
||||||
|
# "auto" will download to cache_dir and symlink files to local_dir
|
||||||
|
# avoiding unnecessary duplicate copies
|
||||||
|
local_dir_use_symlinks="auto",
|
||||||
|
token=args.hf_token,
|
||||||
|
)
|
||||||
|
except GatedRepoError:
|
||||||
|
self.parser.error(
|
||||||
|
"It looks like you are trying to access a gated repository. Please ensure you "
|
||||||
|
"have access to the repository and have provided the proper Hugging Face API token "
|
||||||
|
"using the option `--hf-token` or by running `huggingface-cli login`."
|
||||||
|
"You can find your token by visiting https://huggingface.co/settings/tokens"
|
||||||
|
)
|
||||||
|
except RepositoryNotFoundError:
|
||||||
|
self.parser.error(
|
||||||
|
f"Repository '{args.repo_id}' not found on the Hugging Face Hub."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.parser.error(e)
|
||||||
|
|
||||||
|
print(f"Successfully downloaded model to {true_output_dir}")
|
28
toolchain/cli/inference/inference.py
Normal file
28
toolchain/cli/inference/inference.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
import argparse
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
from toolchain.cli.inference.start import InferenceStart
|
||||||
|
from toolchain.cli.subcommand import Subcommand
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceParser(Subcommand):
|
||||||
|
"""Llama cli for inference apis"""
|
||||||
|
|
||||||
|
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||||
|
super().__init__()
|
||||||
|
self.parser = subparsers.add_parser(
|
||||||
|
"inference",
|
||||||
|
prog="llama inference",
|
||||||
|
description="Run inference on a llama model",
|
||||||
|
epilog=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Example:
|
||||||
|
llama inference start <options>
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
subparsers = self.parser.add_subparsers(title="inference_subcommands")
|
||||||
|
|
||||||
|
# Add sub-commandsa
|
||||||
|
InferenceStart.create(subparsers)
|
53
toolchain/cli/inference/start.py
Normal file
53
toolchain/cli/inference/start.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
import argparse
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
from toolchain.cli.subcommand import Subcommand
|
||||||
|
|
||||||
|
from toolchain.inference.server import main as inference_server_init
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceStart(Subcommand):
|
||||||
|
"""Llama Inference cli for starting inference server"""
|
||||||
|
|
||||||
|
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||||
|
super().__init__()
|
||||||
|
self.parser = subparsers.add_parser(
|
||||||
|
"start",
|
||||||
|
prog="llama inference start",
|
||||||
|
description="Start an inference server",
|
||||||
|
epilog=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Example:
|
||||||
|
llama inference start <options>
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
|
)
|
||||||
|
self._add_arguments()
|
||||||
|
self.parser.set_defaults(func=self._run_inference_start_cmd)
|
||||||
|
|
||||||
|
def _add_arguments(self):
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--port",
|
||||||
|
type=int,
|
||||||
|
help="Port to run the server on. Defaults to 5000",
|
||||||
|
default=5000,
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--disable-ipv6",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable IPv6 support",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--config",
|
||||||
|
type=str,
|
||||||
|
help="Path to config file",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_inference_start_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
inference_server_init(
|
||||||
|
config_path=args.config,
|
||||||
|
port=args.port,
|
||||||
|
disable_ipv6=args.disable_ipv6,
|
||||||
|
)
|
40
toolchain/cli/llama.py
Normal file
40
toolchain/cli/llama.py
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from toolchain.cli.download import Download
|
||||||
|
from toolchain.cli.inference.inference import InferenceParser
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaCLIParser:
|
||||||
|
"""Defines CLI parser for Llama CLI"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.parser = argparse.ArgumentParser(
|
||||||
|
prog="llama",
|
||||||
|
description="Welcome to the LLama toolchain cli",
|
||||||
|
add_help=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default command is to print help
|
||||||
|
self.parser.set_defaults(func=lambda args: self.parser.print_help())
|
||||||
|
|
||||||
|
subparsers = self.parser.add_subparsers(title="subcommands")
|
||||||
|
|
||||||
|
# Add sub-commands
|
||||||
|
Download.create(subparsers)
|
||||||
|
InferenceParser.create(subparsers)
|
||||||
|
|
||||||
|
def parse_args(self) -> argparse.Namespace:
|
||||||
|
return self.parser.parse_args()
|
||||||
|
|
||||||
|
def run(self, args: argparse.Namespace) -> None:
|
||||||
|
args.func(args)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = LlamaCLIParser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
parser.run(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
12
toolchain/cli/subcommand.py
Normal file
12
toolchain/cli/subcommand.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
class Subcommand:
|
||||||
|
"""All llama cli subcommands must inherit from this class"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, *args, **kwargs):
|
||||||
|
return cls(*args, **kwargs)
|
||||||
|
|
||||||
|
def _add_arguments(self):
|
||||||
|
pass
|
25
toolchain/common/deployment_types.py
Normal file
25
toolchain/common/deployment_types.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from models.llama3.datatypes import URL
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from strong_typing.schema import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class RestAPIMethod(Enum):
|
||||||
|
GET = "GET"
|
||||||
|
POST = "POST"
|
||||||
|
PUT = "PUT"
|
||||||
|
DELETE = "DELETE"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class RestAPIExecutionConfig(BaseModel):
|
||||||
|
url: URL
|
||||||
|
method: RestAPIMethod
|
||||||
|
params: Optional[Dict[str, str]] = None
|
||||||
|
headers: Optional[Dict[str, str]] = None
|
||||||
|
body: Optional[Dict[str, str]] = None
|
7
toolchain/common/training_types.py
Normal file
7
toolchain/common/training_types.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
from models.llama3.datatypes import URL
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class Checkpoint(BaseModel):
|
||||||
|
iters: int
|
||||||
|
path: URL
|
9
toolchain/configs/ashwin.yaml
Normal file
9
toolchain/configs/ashwin.yaml
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
model_inference_config:
|
||||||
|
impl_type: "inline"
|
||||||
|
inline_config:
|
||||||
|
checkpoint_type: "pytorch"
|
||||||
|
checkpoint_dir: /home/ashwin/local/checkpoints/Meta-Llama-3.1-8B-Instruct-20240710150000
|
||||||
|
tokenizer_path: /home/ashwin/local/checkpoints/Meta-Llama-3.1-8B-Instruct-20240710150000/tokenizer.model
|
||||||
|
model_parallel_size: 1
|
||||||
|
max_seq_len: 2048
|
||||||
|
max_batch_size: 1
|
9
toolchain/configs/chrisluc.yaml
Normal file
9
toolchain/configs/chrisluc.yaml
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
model_inference_config:
|
||||||
|
impl_type: "inline"
|
||||||
|
inline_config:
|
||||||
|
checkpoint_type: "pytorch"
|
||||||
|
checkpoint_dir: /home/chrisluc/models/Meta-Llama-3.1-8B-Instruct-20240710150000
|
||||||
|
tokenizer_path: /home/chrisluc/models/Meta-Llama-3.1-8B-Instruct-20240710150000/tokenizer.model
|
||||||
|
model_parallel_size: 1
|
||||||
|
max_seq_len: 2048
|
||||||
|
max_batch_size: 1
|
9
toolchain/configs/default.yaml
Normal file
9
toolchain/configs/default.yaml
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
model_inference_config:
|
||||||
|
impl_type: "inline"
|
||||||
|
inline_config:
|
||||||
|
checkpoint_type: "pytorch"
|
||||||
|
checkpoint_dir: /home/dalton/models/Meta-Llama-3.1-8B-Instruct-20240710150000
|
||||||
|
tokenizer_path: /home/dalton/models/Meta-Llama-3.1-8B-Instruct-20240710150000/tokenizer.model
|
||||||
|
model_parallel_size: 1
|
||||||
|
max_seq_len: 2048
|
||||||
|
max_batch_size: 1
|
9
toolchain/configs/hjshah.yaml
Normal file
9
toolchain/configs/hjshah.yaml
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
model_inference_config:
|
||||||
|
impl_type: "inline"
|
||||||
|
inline_config:
|
||||||
|
checkpoint_type: "pytorch"
|
||||||
|
checkpoint_dir: /home/hjshah/local/checkpoints/Meta-Llama-3.1-8B-Instruct-20240710150000
|
||||||
|
tokenizer_path: /home/hjshah/local/checkpoints/Meta-Llama-3.1-8B-Instruct-20240710150000/tokenizer.model
|
||||||
|
model_parallel_size: 1
|
||||||
|
max_seq_len: 2048
|
||||||
|
max_batch_size: 1
|
9
toolchain/configs/long_seqlen.yaml
Normal file
9
toolchain/configs/long_seqlen.yaml
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
model_inference_config:
|
||||||
|
impl_type: "inline"
|
||||||
|
inline_config:
|
||||||
|
checkpoint_type: "pytorch"
|
||||||
|
checkpoint_dir: /home/hjshah/local/checkpoints/Meta-Llama-3.1-8B-Instruct-20240710150000
|
||||||
|
tokenizer_path: /home/hjshah/local/checkpoints/Meta-Llama-3.1-8B-Instruct-20240710150000/tokenizer.model
|
||||||
|
model_parallel_size: 1
|
||||||
|
max_seq_len: 8192
|
||||||
|
max_batch_size: 1
|
2
toolchain/dataset/api/__init__.py
Normal file
2
toolchain/dataset/api/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
from .datatypes import * # noqa: F401 F403
|
||||||
|
from .endpoints import * # noqa: F401 F403
|
27
toolchain/dataset/api/datatypes.py
Normal file
27
toolchain/dataset/api/datatypes.py
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from strong_typing.schema import json_schema_type
|
||||||
|
from models.llama3.datatypes import URL
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class TrainEvalDatasetColumnType(Enum):
|
||||||
|
dialog = "dialog"
|
||||||
|
text = "text"
|
||||||
|
media = "media"
|
||||||
|
number = "number"
|
||||||
|
json = "json"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class TrainEvalDataset(BaseModel):
|
||||||
|
"""Dataset to be used for training or evaluating language models."""
|
||||||
|
|
||||||
|
# TODO(ashwin): figure out if we need to add an enum for a "dataset type"
|
||||||
|
|
||||||
|
columns: Dict[str, TrainEvalDatasetColumnType]
|
||||||
|
content_url: URL
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
36
toolchain/dataset/api/endpoints.py
Normal file
36
toolchain/dataset/api/endpoints.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from pyopenapi import webmethod
|
||||||
|
from strong_typing.schema import json_schema_type
|
||||||
|
|
||||||
|
from .datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class CreateDatasetRequest(BaseModel):
|
||||||
|
"""Request to create a dataset."""
|
||||||
|
|
||||||
|
uuid: str
|
||||||
|
dataset: TrainEvalDataset
|
||||||
|
|
||||||
|
|
||||||
|
class Datasets(Protocol):
|
||||||
|
@webmethod(route="/datasets/create")
|
||||||
|
def create_dataset(
|
||||||
|
self,
|
||||||
|
request: CreateDatasetRequest,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@webmethod(route="/datasets/get")
|
||||||
|
def get_dataset(
|
||||||
|
self,
|
||||||
|
dataset_uuid: str,
|
||||||
|
) -> TrainEvalDataset: ...
|
||||||
|
|
||||||
|
@webmethod(route="/datasets/delete")
|
||||||
|
def delete_dataset(
|
||||||
|
self,
|
||||||
|
dataset_uuid: str,
|
||||||
|
) -> None: ...
|
2
toolchain/evaluations/api/__init__.py
Normal file
2
toolchain/evaluations/api/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
from .datatypes import * # noqa: F401 F403
|
||||||
|
from .endpoints import * # noqa: F401 F403
|
29
toolchain/evaluations/api/datatypes.py
Normal file
29
toolchain/evaluations/api/datatypes.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class TextGenerationMetric(Enum):
|
||||||
|
perplexity = "perplexity"
|
||||||
|
rouge = "rouge"
|
||||||
|
bleu = "bleu"
|
||||||
|
|
||||||
|
|
||||||
|
class QuestionAnsweringMetric(Enum):
|
||||||
|
em = "em"
|
||||||
|
f1 = "f1"
|
||||||
|
|
||||||
|
|
||||||
|
class SummarizationMetric(Enum):
|
||||||
|
rouge = "rouge"
|
||||||
|
bleu = "bleu"
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationJob(BaseModel):
|
||||||
|
|
||||||
|
job_uuid: str
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationJobLogStream(BaseModel):
|
||||||
|
|
||||||
|
job_uuid: str
|
93
toolchain/evaluations/api/endpoints.py
Normal file
93
toolchain/evaluations/api/endpoints.py
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
from typing import List, Protocol
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from pyopenapi import webmethod
|
||||||
|
|
||||||
|
from models.llama3.datatypes import * # noqa: F403
|
||||||
|
from .datatypes import * # noqa: F403
|
||||||
|
from toolchain.dataset.api.datatypes import * # noqa: F403
|
||||||
|
from toolchain.common.training_types import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluateTaskRequestCommon(BaseModel):
|
||||||
|
job_uuid: str
|
||||||
|
dataset: TrainEvalDataset
|
||||||
|
|
||||||
|
checkpoint: Checkpoint
|
||||||
|
|
||||||
|
# generation params
|
||||||
|
sampling_params: SamplingParams = SamplingParams()
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon):
|
||||||
|
"""Request to evaluate text generation."""
|
||||||
|
|
||||||
|
metrics: List[TextGenerationMetric]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon):
|
||||||
|
"""Request to evaluate question answering."""
|
||||||
|
|
||||||
|
metrics: List[QuestionAnsweringMetric]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
|
||||||
|
"""Request to evaluate summarization."""
|
||||||
|
|
||||||
|
metrics: List[SummarizationMetric]
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationJobStatusResponse(BaseModel):
|
||||||
|
|
||||||
|
job_uuid: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class EvaluationJobArtifactsResponse(BaseModel):
|
||||||
|
"""Artifacts of a evaluation job."""
|
||||||
|
|
||||||
|
job_uuid: str
|
||||||
|
|
||||||
|
|
||||||
|
class Evaluations(Protocol):
|
||||||
|
@webmethod(route="/evaluate/text_generation/")
|
||||||
|
def post_evaluate_text_generation(
|
||||||
|
self,
|
||||||
|
request: EvaluateTextGenerationRequest,
|
||||||
|
) -> EvaluationJob: ...
|
||||||
|
|
||||||
|
@webmethod(route="/evaluate/question_answering/")
|
||||||
|
def post_evaluate_question_answering(
|
||||||
|
self,
|
||||||
|
request: EvaluateQuestionAnsweringRequest,
|
||||||
|
) -> EvaluationJob: ...
|
||||||
|
|
||||||
|
@webmethod(route="/evaluate/summarization/")
|
||||||
|
def post_evaluate_summarization(
|
||||||
|
self,
|
||||||
|
request: EvaluateSummarizationRequest,
|
||||||
|
) -> EvaluationJob: ...
|
||||||
|
|
||||||
|
@webmethod(route="/evaluate/jobs")
|
||||||
|
def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
|
||||||
|
|
||||||
|
@webmethod(route="/evaluate/job/status")
|
||||||
|
def get_evaluation_job_status(
|
||||||
|
self, job_uuid: str
|
||||||
|
) -> EvaluationJobStatusResponse: ...
|
||||||
|
|
||||||
|
# sends SSE stream of logs
|
||||||
|
@webmethod(route="/evaluate/job/logs")
|
||||||
|
def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
|
||||||
|
|
||||||
|
@webmethod(route="/evaluate/job/cancel")
|
||||||
|
def cancel_evaluation_job(self, job_uuid: str) -> None: ...
|
||||||
|
|
||||||
|
@webmethod(route="/evaluate/job/artifacts")
|
||||||
|
def get_evaluation_job_artifacts(
|
||||||
|
self, job_uuid: str
|
||||||
|
) -> EvaluationJobArtifactsResponse: ...
|
0
toolchain/inference/__init__.py
Normal file
0
toolchain/inference/__init__.py
Normal file
2
toolchain/inference/api/__init__.py
Normal file
2
toolchain/inference/api/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
from .datatypes import * # noqa: F401 F403
|
||||||
|
from .endpoints import * # noqa: F401 F403
|
146
toolchain/inference/api/config.py
Normal file
146
toolchain/inference/api/config.py
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
|
from hydra.core.config_store import ConfigStore
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GeneratorArgs:
|
||||||
|
ckpt_dir: str
|
||||||
|
tokenizer_path: str
|
||||||
|
model_parallel_size: Optional[int] = None
|
||||||
|
max_seq_len: int = 2048
|
||||||
|
max_batch_size: int = 4
|
||||||
|
|
||||||
|
|
||||||
|
class ImplType(Enum):
|
||||||
|
inline = "inline"
|
||||||
|
remote = "remote"
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointType(Enum):
|
||||||
|
pytorch = "pytorch"
|
||||||
|
huggingface = "huggingface"
|
||||||
|
|
||||||
|
|
||||||
|
class PytorchCheckpoint(BaseModel):
|
||||||
|
checkpoint_type: Literal[CheckpointType.pytorch.value] = (
|
||||||
|
CheckpointType.pytorch.value
|
||||||
|
)
|
||||||
|
checkpoint_dir: str
|
||||||
|
tokenizer_path: str
|
||||||
|
model_parallel_size: int
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceCheckpoint(BaseModel):
|
||||||
|
checkpoint_type: Literal[CheckpointType.huggingface.value] = (
|
||||||
|
CheckpointType.huggingface.value
|
||||||
|
)
|
||||||
|
repo_id: str # or model_name ?
|
||||||
|
model_parallel_size: int
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCheckpointConfig(BaseModel):
|
||||||
|
checkpoint: Annotated[
|
||||||
|
Union[PytorchCheckpoint, HuggingFaceCheckpoint],
|
||||||
|
Field(discriminator="checkpoint_type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: this same config will be used when instantiating an inference server naturally
|
||||||
|
class InlineImplConfig(BaseModel):
|
||||||
|
impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
|
||||||
|
checkpoint_config: ModelCheckpointConfig
|
||||||
|
max_seq_len: int
|
||||||
|
max_batch_size: int = 1
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteImplConfig(BaseModel):
|
||||||
|
impl_type: Literal[ImplType.remote.value] = ImplType.remote.value
|
||||||
|
url: str = Field(..., description="The URL of the remote module")
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInferenceConfig(BaseModel):
|
||||||
|
impl_config: Annotated[
|
||||||
|
Union[InlineImplConfig, RemoteImplConfig],
|
||||||
|
Field(discriminator="impl_type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Hydra does not like unions of containers and
|
||||||
|
# Pydantic does not like Literals
|
||||||
|
# Adding a simple dataclass with custom coversion
|
||||||
|
# to config classes
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InlineImplHydraConfig:
|
||||||
|
checkpoint_type: str # "pytorch" / "HF"
|
||||||
|
# pytorch checkpoint required args
|
||||||
|
checkpoint_dir: str
|
||||||
|
tokenizer_path: str
|
||||||
|
model_parallel_size: int
|
||||||
|
max_seq_len: int
|
||||||
|
max_batch_size: int = 1
|
||||||
|
# TODO: huggingface checkpoint required args
|
||||||
|
|
||||||
|
def convert_to_inline_impl_config(self):
|
||||||
|
if self.checkpoint_type == "pytorch":
|
||||||
|
return InlineImplConfig(
|
||||||
|
checkpoint_config=ModelCheckpointConfig(
|
||||||
|
checkpoint=PytorchCheckpoint(
|
||||||
|
checkpoint_type=CheckpointType.pytorch.value,
|
||||||
|
checkpoint_dir=self.checkpoint_dir,
|
||||||
|
tokenizer_path=self.tokenizer_path,
|
||||||
|
model_parallel_size=self.model_parallel_size,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
max_seq_len=self.max_seq_len,
|
||||||
|
max_batch_size=self.max_batch_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("HF Checkpoint not supported yet")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RemoteImplHydraConfig:
|
||||||
|
url: str
|
||||||
|
|
||||||
|
def convert_to_remote_impl_config(self):
|
||||||
|
return RemoteImplConfig(
|
||||||
|
url=self.url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelInferenceHydraConfig:
|
||||||
|
impl_type: str
|
||||||
|
inline_config: Optional[InlineImplHydraConfig] = None
|
||||||
|
remote_config: Optional[RemoteImplHydraConfig] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
assert self.impl_type in ["inline", "remote"]
|
||||||
|
if self.impl_type == "inline":
|
||||||
|
assert self.inline_config is not None
|
||||||
|
if self.impl_type == "remote":
|
||||||
|
assert self.remote_config is not None
|
||||||
|
|
||||||
|
def convert_to_model_inferene_config(self):
|
||||||
|
if self.impl_type == "inline":
|
||||||
|
inline_config = InlineImplHydraConfig(**self.inline_config)
|
||||||
|
return ModelInferenceConfig(
|
||||||
|
impl_config=inline_config.convert_to_inline_impl_config()
|
||||||
|
)
|
||||||
|
elif self.impl_type == "remote":
|
||||||
|
remote_config = RemoteImplHydraConfig(**self.remote_config)
|
||||||
|
return ModelInferenceConfig(
|
||||||
|
impl_config=remote_config.convert_to_remote_impl_config()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
cs = ConfigStore.instance()
|
||||||
|
cs.store(name="model_inference_config", node=ModelInferenceHydraConfig)
|
68
toolchain/inference/api/datatypes.py
Normal file
68
toolchain/inference/api/datatypes.py
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from strong_typing.schema import json_schema_type
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
from models.llama3.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
class LogProbConfig(BaseModel):
|
||||||
|
top_k: Optional[int] = 0
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class QuantizationType(Enum):
|
||||||
|
bf16 = "bf16"
|
||||||
|
fp8 = "fp8"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Fp8QuantizationConfig(BaseModel):
|
||||||
|
quantization_type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Bf16QuantizationConfig(BaseModel):
|
||||||
|
quantization_type: Literal[QuantizationType.bf16.value] = (
|
||||||
|
QuantizationType.bf16.value
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
QuantizationConfig = Annotated[
|
||||||
|
Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
|
||||||
|
Field(discriminator="quantization_type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ChatCompletionResponseEventType(Enum):
|
||||||
|
start = "start"
|
||||||
|
complete = "complete"
|
||||||
|
progress = "progress"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ToolCallParseStatus(Enum):
|
||||||
|
started = "started"
|
||||||
|
in_progress = "in_progress"
|
||||||
|
failure = "failure"
|
||||||
|
success = "success"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ToolCallDelta(BaseModel):
|
||||||
|
content: Union[str, ToolCall]
|
||||||
|
parse_status: ToolCallParseStatus
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ChatCompletionResponseEvent(BaseModel):
|
||||||
|
"""Chat completion response event."""
|
||||||
|
|
||||||
|
event_type: ChatCompletionResponseEventType
|
||||||
|
delta: Union[str, ToolCallDelta]
|
||||||
|
logprobs: Optional[List[TokenLogProbs]] = None
|
||||||
|
stop_reason: Optional[StopReason] = None
|
117
toolchain/inference/api/endpoints.py
Normal file
117
toolchain/inference/api/endpoints.py
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
from .datatypes import * # noqa: F403
|
||||||
|
from typing import Optional, Protocol
|
||||||
|
|
||||||
|
# this dependency is annoying and we need a forked up version anyway
|
||||||
|
from pyopenapi import webmethod
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class CompletionRequest(BaseModel):
|
||||||
|
model: PretrainedModel
|
||||||
|
content: InterleavedTextAttachment
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||||
|
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
logprobs: Optional[LogProbConfig] = None
|
||||||
|
quantization_config: Optional[QuantizationConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class CompletionResponse(BaseModel):
|
||||||
|
completion_message: CompletionMessage
|
||||||
|
logprobs: Optional[List[TokenLogProbs]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class CompletionResponseStreamChunk(BaseModel):
|
||||||
|
"""streamed completion response."""
|
||||||
|
|
||||||
|
delta: str
|
||||||
|
stop_reason: Optional[StopReason] = None
|
||||||
|
logprobs: Optional[List[TokenLogProbs]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class BatchCompletionRequest(BaseModel):
|
||||||
|
model: PretrainedModel
|
||||||
|
content_batch: List[InterleavedTextAttachment]
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||||
|
logprobs: Optional[LogProbConfig] = None
|
||||||
|
quantization_config: Optional[QuantizationConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class BatchCompletionResponse(BaseModel):
|
||||||
|
completion_message_batch: List[CompletionMessage]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ChatCompletionRequest(BaseModel):
|
||||||
|
model: InstructModel
|
||||||
|
messages: List[Message]
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||||
|
|
||||||
|
# zero-shot tool definitions as input to the model
|
||||||
|
available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
logprobs: Optional[LogProbConfig] = None
|
||||||
|
quantization_config: Optional[QuantizationConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ChatCompletionResponseStreamChunk(BaseModel):
|
||||||
|
"""SSE-stream of these events."""
|
||||||
|
|
||||||
|
event: ChatCompletionResponseEvent
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ChatCompletionResponse(BaseModel):
|
||||||
|
completion_message: CompletionMessage
|
||||||
|
logprobs: Optional[List[TokenLogProbs]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class BatchChatCompletionRequest(BaseModel):
|
||||||
|
model: InstructModel
|
||||||
|
messages_batch: List[List[Message]]
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||||
|
|
||||||
|
# zero-shot tool definitions as input to the model
|
||||||
|
available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
logprobs: Optional[LogProbConfig] = None
|
||||||
|
quantization_config: Optional[QuantizationConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class BatchChatCompletionResponse(BaseModel):
|
||||||
|
completion_message_batch: List[CompletionMessage]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInference(Protocol):
|
||||||
|
|
||||||
|
@webmethod(route="/inference/completion")
|
||||||
|
async def completion(
|
||||||
|
self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
|
||||||
|
|
||||||
|
@webmethod(route="/inference/chat_completion")
|
||||||
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
|
||||||
|
|
||||||
|
@webmethod(route="/inference/batch_completion")
|
||||||
|
async def batch_completion(
|
||||||
|
self,
|
||||||
|
request: BatchCompletionRequest,
|
||||||
|
) -> List[CompletionResponse]: ...
|
||||||
|
|
||||||
|
@webmethod(route="/inference/batch_chat_completion")
|
||||||
|
async def batch_chat_completion(
|
||||||
|
self,
|
||||||
|
request: BatchChatCompletionRequest,
|
||||||
|
) -> List[ChatCompletionResponse]: ...
|
12
toolchain/inference/api_instance.py
Normal file
12
toolchain/inference/api_instance.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
from .api.config import ImplType, ModelInferenceConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_inference_api_instance(config: ModelInferenceConfig):
|
||||||
|
if config.impl_config.impl_type == ImplType.inline.value:
|
||||||
|
from .inference import ModelInferenceImpl
|
||||||
|
|
||||||
|
return ModelInferenceImpl(config.impl_config)
|
||||||
|
|
||||||
|
from .client import ModelInferenceClient
|
||||||
|
|
||||||
|
return ModelInferenceClient(config.impl_config.url)
|
73
toolchain/inference/client.py
Normal file
73
toolchain/inference/client.py
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .api.endpoints import (
|
||||||
|
ChatCompletionRequest,
|
||||||
|
ChatCompletionResponseStreamChunk,
|
||||||
|
CompletionRequest,
|
||||||
|
InstructModel,
|
||||||
|
ModelInference,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInferenceClient(ModelInference):
|
||||||
|
def __init__(self, base_url: str):
|
||||||
|
self.base_url = base_url
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
f"{self.base_url}/inference/chat_completion",
|
||||||
|
data=request.json(),
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
timeout=20,
|
||||||
|
) as response:
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data:"):
|
||||||
|
data = line[len("data: ") :]
|
||||||
|
try:
|
||||||
|
yield ChatCompletionResponseStreamChunk(**json.loads(data))
|
||||||
|
except Exception as e:
|
||||||
|
print(data)
|
||||||
|
print(f"Error with parsing or validation: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def run_main(host: str, port: int):
|
||||||
|
client = ModelInferenceClient(f"http://{host}:{port}")
|
||||||
|
|
||||||
|
message = UserMessage(content="hello world, help me out here")
|
||||||
|
req = ChatCompletionRequest(
|
||||||
|
model=InstructModel.llama3_70b_chat,
|
||||||
|
messages=[message],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
async for event in client.chat_completion(
|
||||||
|
ChatCompletionRequest(
|
||||||
|
model=InstructModel.llama3_70b_chat,
|
||||||
|
messages=[message],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
print(event)
|
||||||
|
|
||||||
|
|
||||||
|
def main(host: str, port: int):
|
||||||
|
asyncio.run(run_main(host, port))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(main)
|
298
toolchain/inference/generation.py
Normal file
298
toolchain/inference/generation.py
Normal file
|
@ -0,0 +1,298 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Generator, List, Optional, TypedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from fairscale.nn.model_parallel.initialize import (
|
||||||
|
get_model_parallel_rank,
|
||||||
|
initialize_model_parallel,
|
||||||
|
model_parallel_is_initialized,
|
||||||
|
)
|
||||||
|
from models.llama3.args import ModelArgs
|
||||||
|
from models.llama3.chat_format import ChatFormat, ModelInput
|
||||||
|
from models.llama3.datatypes import Message
|
||||||
|
from models.llama3.model import Transformer
|
||||||
|
from models.llama3.tokenizer import Tokenizer
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenResult:
|
||||||
|
token: int
|
||||||
|
text: str
|
||||||
|
logprobs: Optional[List[float]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionPrediction(TypedDict, total=False):
|
||||||
|
generation: str
|
||||||
|
tokens: List[str] # not required
|
||||||
|
logprobs: List[float] # not required
|
||||||
|
|
||||||
|
|
||||||
|
class Llama:
|
||||||
|
@staticmethod
|
||||||
|
def build(
|
||||||
|
ckpt_dir: str,
|
||||||
|
tokenizer_path: str,
|
||||||
|
max_seq_len: int,
|
||||||
|
max_batch_size: int,
|
||||||
|
model_parallel_size: Optional[int] = None,
|
||||||
|
seed: int = 1,
|
||||||
|
) -> "Llama":
|
||||||
|
"""
|
||||||
|
Build a Llama instance by initializing and loading a model checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ckpt_dir (str): Path to the directory containing checkpoint files.
|
||||||
|
tokenizer_path (str): Path to the tokenizer file.
|
||||||
|
max_seq_len (int): Maximum sequence length for input text.
|
||||||
|
max_batch_size (int): Maximum batch size for inference.
|
||||||
|
model_parallel_size (Optional[int], optional): Number of model parallel processes.
|
||||||
|
If not provided, it's determined from the environment. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Llama: An instance of the Llama class with the loaded model and tokenizer.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If there are no checkpoint files in the specified directory,
|
||||||
|
or if the model parallel size does not match the number of checkpoint files.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This method initializes the distributed process group, sets the device to CUDA,
|
||||||
|
and loads the pre-trained model and tokenizer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not torch.distributed.is_initialized():
|
||||||
|
torch.distributed.init_process_group("nccl")
|
||||||
|
if not model_parallel_is_initialized():
|
||||||
|
if model_parallel_size is None:
|
||||||
|
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
initialize_model_parallel(model_parallel_size)
|
||||||
|
|
||||||
|
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
torch.cuda.set_device(local_rank)
|
||||||
|
|
||||||
|
# seed must be the same in all processes
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
if local_rank > 0:
|
||||||
|
sys.stdout = open(os.devnull, "w")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||||
|
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||||
|
assert model_parallel_size == len(
|
||||||
|
checkpoints
|
||||||
|
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
||||||
|
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||||
|
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
||||||
|
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||||
|
params = json.loads(f.read())
|
||||||
|
|
||||||
|
# TODO(ashwin): this block is so we can load internal checkpoints without additional
|
||||||
|
# fuss. the final code should _not_ have this blurb
|
||||||
|
if "model" in params:
|
||||||
|
params = params["model"]
|
||||||
|
|
||||||
|
model_args: ModelArgs = ModelArgs(
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
**params,
|
||||||
|
)
|
||||||
|
tokenizer = Tokenizer(model_path=tokenizer_path)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
model_args.vocab_size == tokenizer.n_words
|
||||||
|
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||||
|
if torch.cuda.is_bf16_supported():
|
||||||
|
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||||
|
else:
|
||||||
|
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||||
|
|
||||||
|
model = Transformer(model_args)
|
||||||
|
model.load_state_dict(checkpoint, strict=False)
|
||||||
|
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||||
|
|
||||||
|
return Llama(model, tokenizer, model_args)
|
||||||
|
|
||||||
|
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
|
||||||
|
self.args = args
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.formatter = ChatFormat(tokenizer)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
model_input: ModelInput,
|
||||||
|
max_gen_len: int,
|
||||||
|
temperature: float = 0.6,
|
||||||
|
top_p: float = 0.9,
|
||||||
|
logprobs: bool = False,
|
||||||
|
echo: bool = False,
|
||||||
|
include_stop_token: bool = False,
|
||||||
|
) -> Generator:
|
||||||
|
params = self.model.params
|
||||||
|
|
||||||
|
# cprint("Input to model -> " + self.tokenizer.decode(model_input.tokens), "red")
|
||||||
|
prompt_tokens = [model_input.tokens]
|
||||||
|
|
||||||
|
bsz = 1
|
||||||
|
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||||
|
|
||||||
|
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||||
|
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||||
|
|
||||||
|
if max_prompt_len >= params.max_seq_len:
|
||||||
|
cprint(
|
||||||
|
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||||
|
pad_id = self.tokenizer.pad_id
|
||||||
|
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
|
||||||
|
for k, t in enumerate(prompt_tokens):
|
||||||
|
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
|
||||||
|
if logprobs:
|
||||||
|
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
|
||||||
|
|
||||||
|
prev_pos = 0
|
||||||
|
eos_reached = torch.tensor([False] * bsz, device="cuda")
|
||||||
|
input_text_mask = tokens != pad_id
|
||||||
|
if min_prompt_len == total_len:
|
||||||
|
# TODO(ashwin): unify this branch with the one below and figure out multimodal crap
|
||||||
|
logits = self.model.forward(tokens, prev_pos)
|
||||||
|
token_logprobs = -F.cross_entropy(
|
||||||
|
input=logits.transpose(1, 2),
|
||||||
|
target=tokens,
|
||||||
|
reduction="none",
|
||||||
|
ignore_index=pad_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
|
||||||
|
|
||||||
|
for cur_pos in range(min_prompt_len, total_len):
|
||||||
|
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
||||||
|
|
||||||
|
if temperature > 0:
|
||||||
|
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||||
|
next_token = sample_top_p(probs, top_p)
|
||||||
|
else:
|
||||||
|
next_token = torch.argmax(logits[:, -1], dim=-1)
|
||||||
|
|
||||||
|
next_token = next_token.reshape(-1)
|
||||||
|
# only replace token if prompt has already been generated
|
||||||
|
next_token = torch.where(
|
||||||
|
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
|
||||||
|
)
|
||||||
|
tokens[:, cur_pos] = next_token
|
||||||
|
|
||||||
|
target = tokens[:, prev_pos + 1 : cur_pos + 1]
|
||||||
|
if logprobs:
|
||||||
|
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
||||||
|
input=logits.transpose(1, 2),
|
||||||
|
target=tokens[:, prev_pos + 1 : cur_pos + 1],
|
||||||
|
reduction="none",
|
||||||
|
ignore_index=pad_id,
|
||||||
|
)
|
||||||
|
eos_reached |= (~input_text_mask[:, cur_pos]) & (
|
||||||
|
torch.isin(next_token, stop_tokens)
|
||||||
|
)
|
||||||
|
yield TokenResult(
|
||||||
|
token=next_token[0].item(),
|
||||||
|
text=self.tokenizer.decode(next_token.tolist()),
|
||||||
|
logprobs=(
|
||||||
|
token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
|
||||||
|
if logprobs
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
prev_pos = cur_pos
|
||||||
|
if all(eos_reached):
|
||||||
|
break
|
||||||
|
|
||||||
|
def text_completion(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
temperature: float = 0.6,
|
||||||
|
top_p: float = 0.9,
|
||||||
|
max_gen_len: Optional[int] = None,
|
||||||
|
logprobs: bool = False,
|
||||||
|
echo: bool = False,
|
||||||
|
) -> Generator:
|
||||||
|
if (
|
||||||
|
max_gen_len is None
|
||||||
|
or max_gen_len == 0
|
||||||
|
or max_gen_len >= self.model.params.max_seq_len
|
||||||
|
):
|
||||||
|
max_gen_len = self.model.params.max_seq_len - 1
|
||||||
|
|
||||||
|
prompt_tokens = self.tokenizer.encode(x, bos=True, eos=False)
|
||||||
|
|
||||||
|
yield from self.generate(
|
||||||
|
model_input=ModelInput(tokens=prompt_tokens),
|
||||||
|
max_gen_len=max_gen_len,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
logprobs=logprobs,
|
||||||
|
echo=echo,
|
||||||
|
)
|
||||||
|
|
||||||
|
def chat_completion(
|
||||||
|
self,
|
||||||
|
messages: List[Message],
|
||||||
|
temperature: float = 0.6,
|
||||||
|
top_p: float = 0.9,
|
||||||
|
max_gen_len: Optional[int] = None,
|
||||||
|
logprobs: bool = False,
|
||||||
|
) -> Generator:
|
||||||
|
if (
|
||||||
|
max_gen_len is None
|
||||||
|
or max_gen_len == 0
|
||||||
|
or max_gen_len >= self.model.params.max_seq_len
|
||||||
|
):
|
||||||
|
max_gen_len = self.model.params.max_seq_len - 1
|
||||||
|
|
||||||
|
yield from self.generate(
|
||||||
|
model_input=self.formatter.encode_dialog_prompt(messages),
|
||||||
|
max_gen_len=max_gen_len,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
logprobs=logprobs,
|
||||||
|
include_stop_token=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_top_p(probs, p):
|
||||||
|
"""
|
||||||
|
Perform top-p (nucleus) sampling on a probability distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
probs (torch.Tensor): Probability distribution tensor.
|
||||||
|
p (float): Probability threshold for top-p sampling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Sampled token indices.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
|
||||||
|
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
|
||||||
|
"""
|
||||||
|
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||||
|
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||||
|
mask = probs_sum - probs_sort > p
|
||||||
|
probs_sort[mask] = 0.0
|
||||||
|
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||||
|
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||||
|
next_token = torch.gather(probs_idx, -1, next_token)
|
||||||
|
return next_token
|
173
toolchain/inference/inference.py
Normal file
173
toolchain/inference/inference.py
Normal file
|
@ -0,0 +1,173 @@
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from models.llama3.datatypes import StopReason
|
||||||
|
|
||||||
|
from .api.config import CheckpointType, GeneratorArgs, InlineImplConfig
|
||||||
|
from .api.datatypes import (
|
||||||
|
ChatCompletionResponseEvent,
|
||||||
|
ChatCompletionResponseEventType,
|
||||||
|
ToolCallDelta,
|
||||||
|
ToolCallParseStatus,
|
||||||
|
)
|
||||||
|
from .api.endpoints import (
|
||||||
|
ChatCompletionRequest,
|
||||||
|
ChatCompletionResponseStreamChunk,
|
||||||
|
CompletionRequest,
|
||||||
|
ModelInference,
|
||||||
|
)
|
||||||
|
from .model_parallel import LlamaModelParallelGenerator
|
||||||
|
|
||||||
|
|
||||||
|
def generator_args_from_config(config: InlineImplConfig) -> GeneratorArgs:
|
||||||
|
if (
|
||||||
|
config.checkpoint_config.checkpoint.checkpoint_type
|
||||||
|
== CheckpointType.pytorch.value
|
||||||
|
):
|
||||||
|
pt_checkpoint = config.checkpoint_config.checkpoint
|
||||||
|
return GeneratorArgs(
|
||||||
|
ckpt_dir=pt_checkpoint.checkpoint_dir,
|
||||||
|
tokenizer_path=pt_checkpoint.tokenizer_path,
|
||||||
|
model_parallel_size=pt_checkpoint.model_parallel_size,
|
||||||
|
max_seq_len=config.max_seq_len,
|
||||||
|
max_batch_size=config.max_batch_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("HF Checkpoint not supported yet")
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInferenceImpl(ModelInference):
|
||||||
|
|
||||||
|
def __init__(self, config: InlineImplConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
generator_args = generator_args_from_config(self.config)
|
||||||
|
self.generator = LlamaModelParallelGenerator(
|
||||||
|
args=generator_args,
|
||||||
|
)
|
||||||
|
self.generator.start()
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
self.generator.stop()
|
||||||
|
|
||||||
|
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||||
|
yield ChatCompletionResponseStreamChunk(
|
||||||
|
event=ChatCompletionResponseEvent(
|
||||||
|
event_type=ChatCompletionResponseEventType.start,
|
||||||
|
delta="",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
tokens = []
|
||||||
|
logprobs = []
|
||||||
|
|
||||||
|
stop_reason = None
|
||||||
|
|
||||||
|
buffer = ""
|
||||||
|
ipython = False
|
||||||
|
|
||||||
|
for token_result in self.generator.chat_completion(
|
||||||
|
messages=request.messages,
|
||||||
|
temperature=request.sampling_params.temperature,
|
||||||
|
top_p=request.sampling_params.top_p,
|
||||||
|
max_gen_len=request.sampling_params.max_tokens,
|
||||||
|
logprobs=request.logprobs,
|
||||||
|
):
|
||||||
|
buffer += token_result.text
|
||||||
|
tokens.append(token_result.token)
|
||||||
|
|
||||||
|
if not ipython and buffer.startswith("<|python_tag|>"):
|
||||||
|
ipython = True
|
||||||
|
yield ChatCompletionResponseStreamChunk(
|
||||||
|
event=ChatCompletionResponseEvent(
|
||||||
|
event_type=ChatCompletionResponseEventType.progress,
|
||||||
|
delta=ToolCallDelta(
|
||||||
|
content="",
|
||||||
|
parse_status=ToolCallParseStatus.started,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
buffer = buffer[len("<|python_tag|>") :]
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not request.stream:
|
||||||
|
if request.logprobs:
|
||||||
|
logprobs.append(token_result.logprob)
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
|
if token_result.text == "<|eot_id|>":
|
||||||
|
stop_reason = StopReason.end_of_turn
|
||||||
|
text = ""
|
||||||
|
elif token_result.text == "<|eom_id|>":
|
||||||
|
stop_reason = StopReason.end_of_message
|
||||||
|
text = ""
|
||||||
|
else:
|
||||||
|
text = token_result.text
|
||||||
|
|
||||||
|
if ipython:
|
||||||
|
delta = ToolCallDelta(
|
||||||
|
content=text,
|
||||||
|
parse_status=ToolCallParseStatus.in_progress,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
delta = text
|
||||||
|
yield ChatCompletionResponseStreamChunk(
|
||||||
|
event=ChatCompletionResponseEvent(
|
||||||
|
event_type=ChatCompletionResponseEventType.progress,
|
||||||
|
delta=delta,
|
||||||
|
stop_reason=stop_reason,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if stop_reason is None:
|
||||||
|
stop_reason = StopReason.out_of_tokens
|
||||||
|
|
||||||
|
# TODO(ashwin): parse tool calls separately here and report errors?
|
||||||
|
# if someone breaks the iteration before coming here we are toast
|
||||||
|
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
|
||||||
|
if request.stream:
|
||||||
|
parsed_tool_calls = len(message.tool_calls) > 0
|
||||||
|
if ipython and not parsed_tool_calls:
|
||||||
|
yield ChatCompletionResponseStreamChunk(
|
||||||
|
event=ChatCompletionResponseEvent(
|
||||||
|
event_type=ChatCompletionResponseEventType.progress,
|
||||||
|
delta=ToolCallDelta(
|
||||||
|
content="",
|
||||||
|
parse_status=ToolCallParseStatus.failure,
|
||||||
|
),
|
||||||
|
stop_reason=stop_reason,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for tool_call in message.tool_calls:
|
||||||
|
yield ChatCompletionResponseStreamChunk(
|
||||||
|
event=ChatCompletionResponseEvent(
|
||||||
|
event_type=ChatCompletionResponseEventType.progress,
|
||||||
|
delta=ToolCallDelta(
|
||||||
|
content=tool_call,
|
||||||
|
parse_status=ToolCallParseStatus.success,
|
||||||
|
),
|
||||||
|
stop_reason=stop_reason,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield ChatCompletionResponseStreamChunk(
|
||||||
|
event=ChatCompletionResponseEvent(
|
||||||
|
event_type=ChatCompletionResponseEventType.complete,
|
||||||
|
delta="",
|
||||||
|
stop_reason=stop_reason,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO(ashwin): what else do we need to send out here when everything finishes?
|
||||||
|
else:
|
||||||
|
yield ChatCompletionResponse(
|
||||||
|
content=message.content,
|
||||||
|
tool_calls=message.tool_calls,
|
||||||
|
stop_reason=stop_reason,
|
||||||
|
logprobs=logprobs if request.logprobs else None,
|
||||||
|
)
|
100
toolchain/inference/model_parallel.py
Normal file
100
toolchain/inference/model_parallel.py
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
|
from typing import Generator, List, Optional
|
||||||
|
|
||||||
|
from models.llama3.chat_format import ChatFormat
|
||||||
|
from models.llama3.datatypes import Message
|
||||||
|
from models.llama3.tokenizer import Tokenizer
|
||||||
|
|
||||||
|
from .api.config import GeneratorArgs
|
||||||
|
from .generation import Llama
|
||||||
|
from .parallel_utils import ModelParallelProcessGroup
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InferenceArgs:
|
||||||
|
messages: List[Message]
|
||||||
|
temperature: float
|
||||||
|
top_p: float
|
||||||
|
max_gen_len: int
|
||||||
|
logprobs: bool
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRunner:
|
||||||
|
def __init__(self, llama):
|
||||||
|
self.llama = llama
|
||||||
|
|
||||||
|
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
|
||||||
|
def __call__(self, task: InferenceArgs):
|
||||||
|
return self.llama.chat_completion(
|
||||||
|
task.messages,
|
||||||
|
task.temperature,
|
||||||
|
task.top_p,
|
||||||
|
task.max_gen_len,
|
||||||
|
task.logprobs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def init_model_cb(args: GeneratorArgs):
|
||||||
|
llama = Llama.build(
|
||||||
|
args.ckpt_dir,
|
||||||
|
args.tokenizer_path,
|
||||||
|
args.max_seq_len,
|
||||||
|
args.max_batch_size,
|
||||||
|
)
|
||||||
|
return ModelRunner(llama)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaModelParallelGenerator:
|
||||||
|
"""
|
||||||
|
This abstraction exists so
|
||||||
|
- we can run model parallel code without needing to run the CLIs via torchrun
|
||||||
|
- this also enables use model parallel code within a notebook context.
|
||||||
|
|
||||||
|
A Context Manager is used to ensure that the model parallel process is started and stopped
|
||||||
|
correctly. This does make the ergonomics a little awkward, because it isn't immediately
|
||||||
|
clear at the callsite why we need to use a context manager.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args: GeneratorArgs):
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
|
||||||
|
# while the tool-use loop is going
|
||||||
|
self.formatter = ChatFormat(Tokenizer(self.args.tokenizer_path))
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self.__enter__()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self.__exit__(None, None, None)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.group = ModelParallelProcessGroup(
|
||||||
|
self.args.model_parallel_size,
|
||||||
|
init_model_cb=partial(init_model_cb, self.args),
|
||||||
|
)
|
||||||
|
self.group.start()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||||
|
self.group.stop()
|
||||||
|
|
||||||
|
def chat_completion(
|
||||||
|
self,
|
||||||
|
messages: List[Message],
|
||||||
|
temperature: float = 0.6,
|
||||||
|
top_p: float = 0.9,
|
||||||
|
max_gen_len: Optional[int] = None,
|
||||||
|
logprobs: bool = False,
|
||||||
|
) -> Generator:
|
||||||
|
req_obj = InferenceArgs(
|
||||||
|
messages=messages,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
max_gen_len=max_gen_len,
|
||||||
|
logprobs=logprobs,
|
||||||
|
)
|
||||||
|
|
||||||
|
gen = self.group.run_inference(req_obj)
|
||||||
|
yield from gen
|
259
toolchain/inference/parallel_utils.py
Normal file
259
toolchain/inference/parallel_utils.py
Normal file
|
@ -0,0 +1,259 @@
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from typing import Callable, Generator
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from fairscale.nn.model_parallel.initialize import (
|
||||||
|
get_model_parallel_group,
|
||||||
|
get_model_parallel_rank,
|
||||||
|
get_model_parallel_src_rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
|
||||||
|
|
||||||
|
|
||||||
|
_END_SENTINEL = "__end_sentinel__"
|
||||||
|
_CANCEL_SENTINEL = "__cancel_sentinel__"
|
||||||
|
|
||||||
|
|
||||||
|
def mp_rank_0() -> bool:
|
||||||
|
return get_model_parallel_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def retrieve_requests(reply_socket_url: str):
|
||||||
|
if mp_rank_0():
|
||||||
|
context = zmq.Context()
|
||||||
|
reply_socket = context.socket(zmq.ROUTER)
|
||||||
|
reply_socket.connect(reply_socket_url)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
client_id, obj = maybe_get_work(reply_socket)
|
||||||
|
if obj is None:
|
||||||
|
time.sleep(0.01)
|
||||||
|
continue
|
||||||
|
|
||||||
|
reply_socket.send_multipart([client_id, pickle.dumps("YES READY")])
|
||||||
|
break
|
||||||
|
|
||||||
|
def send_obj(obj):
|
||||||
|
reply_socket.send_multipart([client_id, pickle.dumps(obj)])
|
||||||
|
|
||||||
|
while True:
|
||||||
|
tasks = [None]
|
||||||
|
if mp_rank_0():
|
||||||
|
client_id, task = maybe_get_work(reply_socket)
|
||||||
|
# there is still an unknown unclean GeneratorExit happening resulting in a
|
||||||
|
# cancel sentinel getting queued _after_ we have finished sending everything :/
|
||||||
|
# kind of a hack this is :/
|
||||||
|
if task != _CANCEL_SENTINEL:
|
||||||
|
tasks = [task]
|
||||||
|
|
||||||
|
torch.distributed.broadcast_object_list(
|
||||||
|
tasks,
|
||||||
|
src=get_model_parallel_src_rank(),
|
||||||
|
group=get_model_parallel_group(),
|
||||||
|
)
|
||||||
|
|
||||||
|
task = tasks[0]
|
||||||
|
if task is None:
|
||||||
|
time.sleep(0.1)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
out = yield task
|
||||||
|
if out is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
for obj in out:
|
||||||
|
updates = [None]
|
||||||
|
if mp_rank_0():
|
||||||
|
_, update = maybe_get_work(reply_socket)
|
||||||
|
if update == _CANCEL_SENTINEL:
|
||||||
|
updates = [update]
|
||||||
|
else:
|
||||||
|
# only send the update if it's not cancelled otherwise the object sits in the socket
|
||||||
|
# and gets pulled in the next request lol
|
||||||
|
send_obj(obj)
|
||||||
|
|
||||||
|
torch.distributed.broadcast_object_list(
|
||||||
|
updates,
|
||||||
|
src=get_model_parallel_src_rank(),
|
||||||
|
group=get_model_parallel_group(),
|
||||||
|
)
|
||||||
|
if updates[0] == _CANCEL_SENTINEL:
|
||||||
|
print("quitting generation loop because request was cancelled")
|
||||||
|
break
|
||||||
|
|
||||||
|
if mp_rank_0():
|
||||||
|
send_obj(_END_SENTINEL)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[debug] got exception {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
if mp_rank_0():
|
||||||
|
send_obj(e)
|
||||||
|
|
||||||
|
if mp_rank_0():
|
||||||
|
send_obj("DONE")
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_get_work(sock: zmq.Socket):
|
||||||
|
message = None
|
||||||
|
client_id = None
|
||||||
|
try:
|
||||||
|
client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
|
||||||
|
message = pickle.loads(obj)
|
||||||
|
except zmq.ZMQError as e:
|
||||||
|
if e.errno != zmq.EAGAIN:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return client_id, message
|
||||||
|
|
||||||
|
|
||||||
|
def worker_process_entrypoint(
|
||||||
|
reply_socket_url: str,
|
||||||
|
init_model_cb: Callable,
|
||||||
|
) -> None:
|
||||||
|
model = init_model_cb()
|
||||||
|
torch.distributed.barrier()
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# run the requests co-routine which retrieves requests from the socket
|
||||||
|
# and sends responses (we provide) back to the caller
|
||||||
|
req_gen = retrieve_requests(reply_socket_url)
|
||||||
|
result = None
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
task = req_gen.send(result)
|
||||||
|
if isinstance(task, str) and task == _END_SENTINEL:
|
||||||
|
break
|
||||||
|
|
||||||
|
result = model(task)
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
print("[debug] worker process done")
|
||||||
|
|
||||||
|
|
||||||
|
def launch_dist_group(
|
||||||
|
reply_socket_url: str,
|
||||||
|
model_parallel_size: int,
|
||||||
|
init_model_cb: Callable,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
id = uuid.uuid4().hex
|
||||||
|
dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
|
||||||
|
launch_config = LaunchConfig(
|
||||||
|
max_nodes=1,
|
||||||
|
min_nodes=1,
|
||||||
|
nproc_per_node=model_parallel_size,
|
||||||
|
start_method="fork",
|
||||||
|
rdzv_backend="c10d",
|
||||||
|
rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
|
||||||
|
rdzv_configs={"store_type": "file", "timeout": 90},
|
||||||
|
max_restarts=0,
|
||||||
|
monitor_interval=1,
|
||||||
|
run_id=str(uuid.uuid4()),
|
||||||
|
)
|
||||||
|
elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
|
||||||
|
reply_socket_url,
|
||||||
|
init_model_cb,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def start_model_parallel_process(
|
||||||
|
model_parallel_size: int,
|
||||||
|
init_model_cb: Callable,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
context = zmq.Context()
|
||||||
|
request_socket = context.socket(zmq.DEALER)
|
||||||
|
|
||||||
|
# Binding the request socket to a random port
|
||||||
|
request_socket.bind("tcp://127.0.0.1:0")
|
||||||
|
|
||||||
|
main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)
|
||||||
|
|
||||||
|
ctx = multiprocessing.get_context("fork")
|
||||||
|
process = ctx.Process(
|
||||||
|
target=launch_dist_group,
|
||||||
|
args=(
|
||||||
|
main_process_url,
|
||||||
|
model_parallel_size,
|
||||||
|
init_model_cb,
|
||||||
|
),
|
||||||
|
kwargs=kwargs,
|
||||||
|
)
|
||||||
|
process.start()
|
||||||
|
|
||||||
|
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
|
||||||
|
|
||||||
|
request_socket.send_pyobj("READY?")
|
||||||
|
response = request_socket.recv_pyobj()
|
||||||
|
print(f"Finished model load {response}")
|
||||||
|
|
||||||
|
return request_socket, process
|
||||||
|
|
||||||
|
|
||||||
|
class ModelParallelProcessGroup:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_parallel_size: int,
|
||||||
|
init_model_cb: Callable,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.model_parallel_size = model_parallel_size
|
||||||
|
self.init_model_cb = init_model_cb
|
||||||
|
self.started = False
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
assert not self.started, "process group already started"
|
||||||
|
self.request_socket, self.process = start_model_parallel_process(
|
||||||
|
self.model_parallel_size,
|
||||||
|
self.init_model_cb,
|
||||||
|
)
|
||||||
|
self.started = True
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
assert self.started, "process group not started"
|
||||||
|
if self.process.is_alive():
|
||||||
|
self.request_socket.send_pyobj(_END_SENTINEL, zmq.NOBLOCK)
|
||||||
|
self.process.join()
|
||||||
|
self.started = False
|
||||||
|
|
||||||
|
def run_inference(self, request) -> Generator:
|
||||||
|
assert not self.running, "inference already running"
|
||||||
|
|
||||||
|
self.running = True
|
||||||
|
self.request_socket.send_pyobj(request)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
obj = self.request_socket.recv_pyobj()
|
||||||
|
if obj == _END_SENTINEL:
|
||||||
|
break
|
||||||
|
|
||||||
|
if isinstance(obj, Exception):
|
||||||
|
print(f"[debug] got exception {obj}")
|
||||||
|
raise obj
|
||||||
|
|
||||||
|
yield obj
|
||||||
|
except GeneratorExit as e:
|
||||||
|
self.request_socket.send_pyobj(_CANCEL_SENTINEL)
|
||||||
|
while True:
|
||||||
|
obj = self.request_socket.recv_pyobj()
|
||||||
|
if obj == _END_SENTINEL:
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
self.running = False
|
45
toolchain/inference/quantization/build_conda.sh
Normal file
45
toolchain/inference/quantization/build_conda.sh
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
if [[ $# -ne 1 ]]; then
|
||||||
|
echo "Error: Please provide the name of CONDA environment you wish to create"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
ENV_NAME=$1
|
||||||
|
|
||||||
|
set -eu
|
||||||
|
eval "$(conda shell.bash hook)"
|
||||||
|
|
||||||
|
echo "Will build env (or overwrite) named '$ENV_NAME'"
|
||||||
|
|
||||||
|
set -x
|
||||||
|
|
||||||
|
run_build() {
|
||||||
|
# Set CUDA 9.0a targets
|
||||||
|
export CUDA_ARCH_LIST="8.0;9.0a"
|
||||||
|
export NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a"
|
||||||
|
export TORCH_CUDA_ARCH_LIST=$CUDA_ARCH_LIST
|
||||||
|
|
||||||
|
# Set up the conda environment
|
||||||
|
yes | conda remove --name $ENV_NAME --all
|
||||||
|
yes | conda create -n $ENV_NAME python=3.10
|
||||||
|
conda activate $ENV_NAME
|
||||||
|
yes | conda install --channel "nvidia/label/cuda-12.1.0" cuda
|
||||||
|
yes | conda install cuda-nvtx cuda-nvtx-dev conda-forge::nccl
|
||||||
|
|
||||||
|
|
||||||
|
# ############# Hack to get CUDA path #############
|
||||||
|
ln -s $CONDA_PREFIX/targets/x86_64-linux/include/* $CONDA_PREFIX/include/ || true
|
||||||
|
export CUDA_HOME=$CONDA_PREFIX
|
||||||
|
export CUDA_BIN_PATH=$CUDA_HOME
|
||||||
|
# #################################################
|
||||||
|
|
||||||
|
# PT nightly
|
||||||
|
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
|
||||||
|
pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu121
|
||||||
|
|
||||||
|
# install dependencies for `llama-agentic-system`
|
||||||
|
pip install -r fp8_requirements.txt
|
||||||
|
}
|
||||||
|
|
||||||
|
run_build
|
165
toolchain/inference/quantization/fp8_impls.py
Normal file
165
toolchain/inference/quantization/fp8_impls.py
Normal file
|
@ -0,0 +1,165 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
import collections
|
||||||
|
from enum import Enum, unique
|
||||||
|
from typing import Optional, Type
|
||||||
|
|
||||||
|
try:
|
||||||
|
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||||
|
|
||||||
|
print("Using efficient FP8 operators in FBGEMM.")
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn, Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@unique
|
||||||
|
class FfnQuantizeMode(Enum):
|
||||||
|
FP8_ROWWISE = "fp8_rowwise"
|
||||||
|
NONE = "none"
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
class Fp8ScaledWeights:
|
||||||
|
# TODO: Ugly trick so torch allows us to replace parameters
|
||||||
|
# with our custom Fp8Weights instance. Do this properly.
|
||||||
|
@property
|
||||||
|
def __class__(self) -> Type[nn.parameter.Parameter]:
|
||||||
|
return nn.Parameter
|
||||||
|
|
||||||
|
@property
|
||||||
|
def grad_fn(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
|
||||||
|
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
|
||||||
|
class Fp8RowwiseWeights(
|
||||||
|
Fp8ScaledWeights,
|
||||||
|
collections.namedtuple(
|
||||||
|
"Fp8RowwiseWeights",
|
||||||
|
["weight", "scale", "shape", "activation_scale_ub"],
|
||||||
|
),
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def ffn_swiglu(
|
||||||
|
x: Tensor,
|
||||||
|
w1: Fp8RowwiseWeights,
|
||||||
|
w3: Fp8RowwiseWeights,
|
||||||
|
w2: Fp8RowwiseWeights,
|
||||||
|
num_tokens: Optional[Tensor] = None,
|
||||||
|
is_memory_bounded: bool = False,
|
||||||
|
) -> Tensor:
|
||||||
|
if (
|
||||||
|
isinstance(w1, Fp8ScaledWeights)
|
||||||
|
and isinstance(w3, Fp8ScaledWeights)
|
||||||
|
and isinstance(w2, Fp8ScaledWeights)
|
||||||
|
):
|
||||||
|
return ffn_swiglu_fp8_dynamic(
|
||||||
|
x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
|
||||||
|
)
|
||||||
|
|
||||||
|
(B, T, D) = x.shape
|
||||||
|
(HD_L, D_) = w1.shape
|
||||||
|
assert D_ == D
|
||||||
|
|
||||||
|
assert isinstance(w1, Tensor)
|
||||||
|
assert isinstance(w3, Tensor)
|
||||||
|
x1 = x.view(B * T, D) @ w1.T
|
||||||
|
x2 = x.view(B * T, D) @ w3.T
|
||||||
|
z = torch.nn.functional.silu(x1) * x2
|
||||||
|
del x1, x2
|
||||||
|
assert isinstance(w2, Tensor)
|
||||||
|
return (z @ w2.T).view(B, T, D)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def quantize_fp8(
|
||||||
|
w: Tensor,
|
||||||
|
fp8_activation_scale_ub: float,
|
||||||
|
mode: Optional[FfnQuantizeMode] = None,
|
||||||
|
output_device: Optional[torch.device] = None,
|
||||||
|
) -> Fp8RowwiseWeights:
|
||||||
|
"""Quantize [n, k] weight tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
w (Tensor): [n, k] input high precision tensor to quantize.
|
||||||
|
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||||
|
mode (FfnQuantizeMode): Quantization mode.
|
||||||
|
"""
|
||||||
|
activation_scale_ub = torch.tensor(
|
||||||
|
[fp8_activation_scale_ub],
|
||||||
|
dtype=torch.float,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
if mode is not None and mode == FfnQuantizeMode.FP8_ROWWISE: # rowwise
|
||||||
|
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
|
||||||
|
del w
|
||||||
|
return Fp8RowwiseWeights(
|
||||||
|
weight=wq,
|
||||||
|
scale=w_scale,
|
||||||
|
shape=wq.shape,
|
||||||
|
activation_scale_ub=activation_scale_ub,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fc_fp8_dynamic(
|
||||||
|
x: Tensor,
|
||||||
|
w: Fp8RowwiseWeights,
|
||||||
|
activation_scale_ub: Optional[Tensor] = None,
|
||||||
|
num_tokens: Optional[Tensor] = None,
|
||||||
|
is_memory_bounded: bool = False,
|
||||||
|
) -> Tensor:
|
||||||
|
"""
|
||||||
|
Single w8a8 fc layer with dynamic row-wise scaling.
|
||||||
|
"""
|
||||||
|
if isinstance(w, Fp8RowwiseWeights):
|
||||||
|
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||||
|
x, num_tokens, activation_scale_ub
|
||||||
|
)
|
||||||
|
y = torch.ops.fbgemm.f8f8bf16_rowwise(
|
||||||
|
xq, w.weight, x_scale, w.scale, use_fast_accum=True
|
||||||
|
)
|
||||||
|
del xq
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
def ffn_swiglu_fp8_dynamic(
|
||||||
|
x: Tensor,
|
||||||
|
w1: Fp8RowwiseWeights,
|
||||||
|
w3: Fp8RowwiseWeights,
|
||||||
|
w2: Fp8RowwiseWeights,
|
||||||
|
activation_scale_ub: Optional[Tensor] = None,
|
||||||
|
num_tokens: Optional[Tensor] = None,
|
||||||
|
is_memory_bounded: bool = False,
|
||||||
|
) -> Tensor:
|
||||||
|
(B, T, D) = x.shape
|
||||||
|
HD_L = w1.shape[0]
|
||||||
|
assert HD_L == w3.shape[0]
|
||||||
|
x1 = fc_fp8_dynamic(
|
||||||
|
x.view(B * T, D),
|
||||||
|
w1,
|
||||||
|
activation_scale_ub,
|
||||||
|
num_tokens,
|
||||||
|
is_memory_bounded,
|
||||||
|
)
|
||||||
|
x2 = fc_fp8_dynamic(
|
||||||
|
x.view(B * T, D),
|
||||||
|
w3,
|
||||||
|
activation_scale_ub,
|
||||||
|
num_tokens,
|
||||||
|
is_memory_bounded,
|
||||||
|
)
|
||||||
|
z = torch.nn.functional.silu(x1) * x2
|
||||||
|
del x1, x2
|
||||||
|
|
||||||
|
z_ = fc_fp8_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
|
||||||
|
|
||||||
|
return z_.view(B, T, D)
|
5
toolchain/inference/quantization/fp8_requirements.txt
Normal file
5
toolchain/inference/quantization/fp8_requirements.txt
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
fairscale
|
||||||
|
fire
|
||||||
|
tiktoken
|
||||||
|
blobfile
|
||||||
|
fbgemm-gpu==0.8.0rc4
|
455
toolchain/inference/quantization/generation.py
Normal file
455
toolchain/inference/quantization/generation.py
Normal file
|
@ -0,0 +1,455 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple, TypedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from fairscale.nn.model_parallel.initialize import (
|
||||||
|
get_model_parallel_rank,
|
||||||
|
initialize_model_parallel,
|
||||||
|
model_parallel_is_initialized,
|
||||||
|
)
|
||||||
|
from fp8.fp8_impls import (
|
||||||
|
FfnQuantizeMode,
|
||||||
|
Fp8ScaledWeights,
|
||||||
|
load_fp8,
|
||||||
|
ModelLoadMode,
|
||||||
|
quantize_fp8,
|
||||||
|
)
|
||||||
|
|
||||||
|
from llama.model import ModelArgs, Transformer, TransformerBlock
|
||||||
|
from llama.tokenizer import ChatFormat, Dialog, Message, ModelInput, Tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionPrediction(TypedDict, total=False):
|
||||||
|
generation: str
|
||||||
|
tokens: List[str] # not required
|
||||||
|
logprobs: List[float] # not required
|
||||||
|
|
||||||
|
|
||||||
|
class ChatPrediction(TypedDict, total=False):
|
||||||
|
generation: Message
|
||||||
|
tokens: List[str] # not required
|
||||||
|
logprobs: List[float] # not required
|
||||||
|
|
||||||
|
|
||||||
|
class Llama:
|
||||||
|
@staticmethod
|
||||||
|
def build(
|
||||||
|
ckpt_dir: str,
|
||||||
|
tokenizer_path: str,
|
||||||
|
max_seq_len: int,
|
||||||
|
max_batch_size: int,
|
||||||
|
model_parallel_size: Optional[int] = None,
|
||||||
|
ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.NONE,
|
||||||
|
model_load_mode: Optional[ModelLoadMode] = ModelLoadMode.BF16,
|
||||||
|
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||||
|
seed: int = 1,
|
||||||
|
) -> "Llama":
|
||||||
|
"""
|
||||||
|
Build a Llama instance by initializing and loading a model checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ckpt_dir (str): Path to the directory containing checkpoint files.
|
||||||
|
tokenizer_path (str): Path to the tokenizer file.
|
||||||
|
max_seq_len (int): Maximum sequence length for input text.
|
||||||
|
max_batch_size (int): Maximum batch size for inference.
|
||||||
|
model_parallel_size (Optional[int], optional): Number of model parallel processes.
|
||||||
|
If not provided, it's determined from the environment. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Llama: An instance of the Llama class with the loaded model and tokenizer.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If there are no checkpoint files in the specified directory,
|
||||||
|
or if the model parallel size does not match the number of checkpoint files.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This method initializes the distributed process group, sets the device to CUDA,
|
||||||
|
and loads the pre-trained model and tokenizer.
|
||||||
|
"""
|
||||||
|
if not torch.distributed.is_initialized():
|
||||||
|
torch.distributed.init_process_group("nccl")
|
||||||
|
if not model_parallel_is_initialized():
|
||||||
|
if model_parallel_size is None:
|
||||||
|
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
initialize_model_parallel(model_parallel_size)
|
||||||
|
|
||||||
|
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
torch.cuda.set_device(local_rank)
|
||||||
|
|
||||||
|
# seed must be the same in all processes
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
if local_rank > 0:
|
||||||
|
sys.stdout = open(os.devnull, "w")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||||
|
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||||
|
assert model_parallel_size == len(
|
||||||
|
checkpoints
|
||||||
|
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
||||||
|
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||||
|
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||||
|
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||||
|
params = json.loads(f.read())
|
||||||
|
|
||||||
|
model_args: ModelArgs = ModelArgs(
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
**params,
|
||||||
|
)
|
||||||
|
tokenizer = Tokenizer(model_path=tokenizer_path)
|
||||||
|
assert (
|
||||||
|
model_args.vocab_size == tokenizer.n_words
|
||||||
|
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||||
|
|
||||||
|
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
|
||||||
|
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||||
|
|
||||||
|
model = Transformer(model_args)
|
||||||
|
model.load_state_dict(checkpoint, strict=False)
|
||||||
|
|
||||||
|
if torch.cuda.is_bf16_supported():
|
||||||
|
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||||
|
else:
|
||||||
|
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||||
|
|
||||||
|
print("ffn_quantize_mode: ", ffn_quantize_mode)
|
||||||
|
if ffn_quantize_mode == FfnQuantizeMode.FP8_ROWWISE:
|
||||||
|
# Move weights to GPU with quantization
|
||||||
|
if model_load_mode == ModelLoadMode.FP8:
|
||||||
|
fp8_scales_path = os.path.join(
|
||||||
|
ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
|
||||||
|
)
|
||||||
|
assert os.path.isfile(
|
||||||
|
fp8_scales_path
|
||||||
|
), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
|
||||||
|
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
|
||||||
|
|
||||||
|
for block in model.layers:
|
||||||
|
if isinstance(block, TransformerBlock):
|
||||||
|
if block.layer_id == 0 or block.layer_id == (
|
||||||
|
model.n_layers - 1
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
block.feed_forward.w1.weight = load_fp8(
|
||||||
|
block.feed_forward.w1.weight,
|
||||||
|
fp8_scales[
|
||||||
|
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
|
||||||
|
],
|
||||||
|
fp8_activation_scale_ub,
|
||||||
|
)
|
||||||
|
block.feed_forward.w3.weight = load_fp8(
|
||||||
|
block.feed_forward.w3.weight,
|
||||||
|
fp8_scales[
|
||||||
|
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
|
||||||
|
],
|
||||||
|
fp8_activation_scale_ub,
|
||||||
|
)
|
||||||
|
block.feed_forward.w2.weight = load_fp8(
|
||||||
|
block.feed_forward.w2.weight,
|
||||||
|
fp8_scales[
|
||||||
|
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
|
||||||
|
],
|
||||||
|
fp8_activation_scale_ub,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for block in model.layers:
|
||||||
|
if isinstance(block, TransformerBlock):
|
||||||
|
if block.layer_id == 0 or block.layer_id == (
|
||||||
|
model.n_layers - 1
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
block.feed_forward.w1.weight = quantize_fp8(
|
||||||
|
block.feed_forward.w1.weight,
|
||||||
|
fp8_activation_scale_ub,
|
||||||
|
ffn_quantize_mode,
|
||||||
|
output_device=torch.device("cuda"),
|
||||||
|
)
|
||||||
|
block.feed_forward.w3.weight = quantize_fp8(
|
||||||
|
block.feed_forward.w3.weight,
|
||||||
|
fp8_activation_scale_ub,
|
||||||
|
ffn_quantize_mode,
|
||||||
|
output_device=torch.device("cuda"),
|
||||||
|
)
|
||||||
|
block.feed_forward.w2.weight = quantize_fp8(
|
||||||
|
block.feed_forward.w2.weight,
|
||||||
|
fp8_activation_scale_ub,
|
||||||
|
ffn_quantize_mode,
|
||||||
|
output_device=torch.device("cuda"),
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, parameter in model.named_parameters():
|
||||||
|
if not isinstance(parameter, Fp8ScaledWeights):
|
||||||
|
parameter.data = parameter.to(device="cuda")
|
||||||
|
else:
|
||||||
|
for _, parameter in model.named_parameters():
|
||||||
|
parameter.data = parameter.to(device="cuda")
|
||||||
|
|
||||||
|
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||||
|
|
||||||
|
return Llama(model, tokenizer, model_args)
|
||||||
|
|
||||||
|
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
|
||||||
|
self.args = args
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.formatter = ChatFormat(tokenizer)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
model_inputs: List[ModelInput],
|
||||||
|
max_gen_len: int,
|
||||||
|
temperature: float = 0.6,
|
||||||
|
top_p: float = 0.9,
|
||||||
|
logprobs: bool = False,
|
||||||
|
echo: bool = False,
|
||||||
|
include_stop_token: bool = False,
|
||||||
|
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
|
||||||
|
"""
|
||||||
|
Generate text sequences based on provided prompts using the language generation model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
|
||||||
|
max_gen_len (int): Maximum length of the generated text sequence.
|
||||||
|
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
|
||||||
|
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
|
||||||
|
logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
|
||||||
|
echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
|
||||||
|
If logprobs is True, token log probabilities are computed for each generated token.
|
||||||
|
|
||||||
|
"""
|
||||||
|
params = self.model.params
|
||||||
|
prompt_tokens = [m.tokens for m in model_inputs]
|
||||||
|
bsz = len(prompt_tokens)
|
||||||
|
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||||
|
|
||||||
|
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||||
|
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||||
|
assert max_prompt_len <= params.max_seq_len
|
||||||
|
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
|
||||||
|
|
||||||
|
pad_id = self.tokenizer.pad_id
|
||||||
|
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
|
||||||
|
for k, t in enumerate(prompt_tokens):
|
||||||
|
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
|
||||||
|
if logprobs:
|
||||||
|
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
|
||||||
|
|
||||||
|
prev_pos = 0
|
||||||
|
eos_reached = torch.tensor([False] * bsz, device="cuda")
|
||||||
|
input_text_mask = tokens != pad_id
|
||||||
|
if min_prompt_len == total_len:
|
||||||
|
logits = self.model.forward(tokens, prev_pos)
|
||||||
|
token_logprobs = -F.cross_entropy(
|
||||||
|
input=logits.transpose(1, 2),
|
||||||
|
target=tokens,
|
||||||
|
reduction="none",
|
||||||
|
ignore_index=pad_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))
|
||||||
|
|
||||||
|
for cur_pos in range(min_prompt_len, total_len):
|
||||||
|
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
||||||
|
|
||||||
|
if temperature > 0:
|
||||||
|
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||||
|
next_token = sample_top_p(probs, top_p)
|
||||||
|
else:
|
||||||
|
next_token = torch.argmax(logits[:, -1], dim=-1)
|
||||||
|
|
||||||
|
next_token = next_token.reshape(-1)
|
||||||
|
# only replace token if prompt has already been generated
|
||||||
|
next_token = torch.where(
|
||||||
|
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
|
||||||
|
)
|
||||||
|
tokens[:, cur_pos] = next_token
|
||||||
|
|
||||||
|
target = tokens[:, prev_pos + 1 : cur_pos + 1]
|
||||||
|
|
||||||
|
if logprobs:
|
||||||
|
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
||||||
|
input=logits.transpose(1, 2),
|
||||||
|
target=tokens[:, prev_pos + 1 : cur_pos + 1],
|
||||||
|
reduction="none",
|
||||||
|
ignore_index=pad_id,
|
||||||
|
)
|
||||||
|
eos_reached |= (~input_text_mask[:, cur_pos]) & (
|
||||||
|
torch.isin(next_token, stop_tokens)
|
||||||
|
)
|
||||||
|
prev_pos = cur_pos
|
||||||
|
if all(eos_reached):
|
||||||
|
break
|
||||||
|
|
||||||
|
if logprobs:
|
||||||
|
token_logprobs = token_logprobs.tolist()
|
||||||
|
out_tokens, out_logprobs = [], []
|
||||||
|
for i, toks in enumerate(tokens.tolist()):
|
||||||
|
# cut to max gen len
|
||||||
|
start = 0 if echo else len(prompt_tokens[i])
|
||||||
|
toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
|
||||||
|
probs = None
|
||||||
|
if logprobs:
|
||||||
|
probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
|
||||||
|
# cut to after eos tok if any
|
||||||
|
for stop_token in self.tokenizer.stop_tokens:
|
||||||
|
try:
|
||||||
|
eos_idx = toks.index(stop_token)
|
||||||
|
if include_stop_token:
|
||||||
|
eos_idx += 1
|
||||||
|
toks = toks[:eos_idx]
|
||||||
|
probs = probs[:eos_idx] if logprobs else None
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
out_tokens.append(toks)
|
||||||
|
out_logprobs.append(probs)
|
||||||
|
return (out_tokens, out_logprobs if logprobs else None)
|
||||||
|
|
||||||
|
def text_completion(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
temperature: float = 0.6,
|
||||||
|
top_p: float = 0.9,
|
||||||
|
max_gen_len: Optional[int] = None,
|
||||||
|
logprobs: bool = False,
|
||||||
|
echo: bool = False,
|
||||||
|
) -> List[CompletionPrediction]:
|
||||||
|
"""
|
||||||
|
Perform text completion for a list of prompts using the language generation model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts (List[str]): List of text prompts for completion.
|
||||||
|
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
|
||||||
|
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
|
||||||
|
max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
|
||||||
|
If not provided, it's set to the model's maximum sequence length minus 1.
|
||||||
|
logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
|
||||||
|
echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[CompletionPrediction]: List of completion predictions, each containing the generated text completion.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
|
||||||
|
If logprobs is True, token log probabilities are computed for each generated token.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if max_gen_len is None:
|
||||||
|
max_gen_len = self.model.params.max_seq_len - 1
|
||||||
|
prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
|
||||||
|
generation_tokens, generation_logprobs = self.generate(
|
||||||
|
model_inputs=[ModelInput(tokens=pt) for pt in prompt_tokens],
|
||||||
|
max_gen_len=max_gen_len,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
logprobs=logprobs,
|
||||||
|
echo=echo,
|
||||||
|
)
|
||||||
|
if logprobs:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"generation": self.tokenizer.decode(t),
|
||||||
|
"tokens": [self.tokenizer.decode([x]) for x in t],
|
||||||
|
"logprobs": logprobs_i,
|
||||||
|
}
|
||||||
|
for t, logprobs_i in zip(generation_tokens, generation_logprobs)
|
||||||
|
]
|
||||||
|
return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]
|
||||||
|
|
||||||
|
def chat_completion(
|
||||||
|
self,
|
||||||
|
dialogs: List[Dialog],
|
||||||
|
temperature: float = 0.6,
|
||||||
|
top_p: float = 0.9,
|
||||||
|
max_gen_len: Optional[int] = None,
|
||||||
|
logprobs: bool = False,
|
||||||
|
) -> List[ChatPrediction]:
|
||||||
|
"""
|
||||||
|
Generate assistant responses for a list of conversational dialogs using the language generation model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
|
||||||
|
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
|
||||||
|
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
|
||||||
|
max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
|
||||||
|
If not provided, it's set to the model's maximum sequence length minus 1.
|
||||||
|
logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This method generates assistant responses for the provided conversational dialogs.
|
||||||
|
It employs nucleus sampling to introduce controlled randomness in text generation.
|
||||||
|
If logprobs is True, token log probabilities are computed for each generated token.
|
||||||
|
"""
|
||||||
|
if max_gen_len is None:
|
||||||
|
max_gen_len = self.model.params.max_seq_len - 1
|
||||||
|
|
||||||
|
model_inputs = [
|
||||||
|
self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
|
||||||
|
]
|
||||||
|
generation_tokens, generation_logprobs = self.generate(
|
||||||
|
model_inputs=model_inputs,
|
||||||
|
max_gen_len=max_gen_len,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
logprobs=logprobs,
|
||||||
|
include_stop_token=True,
|
||||||
|
)
|
||||||
|
if logprobs:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"generation": self.formatter.decode_assistant_message(t),
|
||||||
|
"tokens": [self.tokenizer.decode([x]) for x in t],
|
||||||
|
"logprobs": logprobs_i,
|
||||||
|
}
|
||||||
|
for t, logprobs_i in zip(generation_tokens, generation_logprobs)
|
||||||
|
]
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"generation": self.formatter.decode_assistant_message(t),
|
||||||
|
}
|
||||||
|
for t in generation_tokens
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def sample_top_p(probs, p):
|
||||||
|
"""
|
||||||
|
Perform top-p (nucleus) sampling on a probability distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
probs (torch.Tensor): Probability distribution tensor.
|
||||||
|
p (float): Probability threshold for top-p sampling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Sampled token indices.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
|
||||||
|
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
|
||||||
|
"""
|
||||||
|
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||||
|
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||||
|
mask = probs_sum - probs_sort > p
|
||||||
|
probs_sort[mask] = 0.0
|
||||||
|
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||||
|
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||||
|
next_token = torch.gather(probs_idx, -1, next_token)
|
||||||
|
return next_token
|
355
toolchain/inference/quantization/model.py
Normal file
355
toolchain/inference/quantization/model.py
Normal file
|
@ -0,0 +1,355 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import fairscale.nn.model_parallel.initialize as fs_init
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from fairscale.nn.model_parallel.layers import (
|
||||||
|
ColumnParallelLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
VocabParallelEmbedding,
|
||||||
|
)
|
||||||
|
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||||
|
from fp8.fp8_impls import ffn_swiglu
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs:
|
||||||
|
dim: int = 4096
|
||||||
|
n_layers: int = 32
|
||||||
|
n_heads: int = 32
|
||||||
|
n_kv_heads: Optional[int] = None
|
||||||
|
vocab_size: int = -1
|
||||||
|
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
||||||
|
ffn_dim_multiplier: Optional[float] = None
|
||||||
|
norm_eps: float = 1e-5
|
||||||
|
rope_theta: float = 500000
|
||||||
|
use_scaled_rope: bool = False
|
||||||
|
|
||||||
|
max_batch_size: int = 32
|
||||||
|
max_seq_len: int = 2048
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if hasattr(self, k):
|
||||||
|
setattr(self, k, v)
|
||||||
|
|
||||||
|
if self.n_kv_heads is None:
|
||||||
|
self.n_kv_heads = self.n_heads
|
||||||
|
assert self.n_kv_heads <= self.n_heads
|
||||||
|
assert self.n_heads % self.n_kv_heads == 0
|
||||||
|
assert self.dim % self.n_heads == 0
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim: int, eps: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
def _norm(self, x):
|
||||||
|
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
output = self._norm(x.float()).type_as(x)
|
||||||
|
return output * self.weight
|
||||||
|
|
||||||
|
|
||||||
|
def apply_scaling(freqs: torch.Tensor):
|
||||||
|
# Values obtained from grid search
|
||||||
|
scale_factor = 8
|
||||||
|
low_freq_factor = 1
|
||||||
|
high_freq_factor = 4
|
||||||
|
old_context_len = 8192 # original llama3 length
|
||||||
|
|
||||||
|
low_freq_wavelen = old_context_len / low_freq_factor
|
||||||
|
high_freq_wavelen = old_context_len / high_freq_factor
|
||||||
|
new_freqs = []
|
||||||
|
for freq in freqs:
|
||||||
|
wavelen = 2 * math.pi / freq
|
||||||
|
if wavelen < high_freq_wavelen:
|
||||||
|
new_freqs.append(freq)
|
||||||
|
elif wavelen > low_freq_wavelen:
|
||||||
|
new_freqs.append(freq / scale_factor)
|
||||||
|
else:
|
||||||
|
assert low_freq_wavelen != high_freq_wavelen
|
||||||
|
smooth = (old_context_len / wavelen - low_freq_factor) / (
|
||||||
|
high_freq_factor - low_freq_factor
|
||||||
|
)
|
||||||
|
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
|
||||||
|
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_freqs_cis(
|
||||||
|
dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
|
||||||
|
):
|
||||||
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||||
|
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
|
||||||
|
if use_scaled:
|
||||||
|
freqs = apply_scaling(freqs)
|
||||||
|
freqs = torch.outer(t, freqs)
|
||||||
|
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||||
|
return freqs_cis
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||||
|
ndim = x.ndim
|
||||||
|
assert 0 <= 1 < ndim
|
||||||
|
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||||
|
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||||
|
return freqs_cis.view(*shape)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(
|
||||||
|
xq: torch.Tensor,
|
||||||
|
xk: torch.Tensor,
|
||||||
|
freqs_cis: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||||
|
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||||
|
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
||||||
|
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
||||||
|
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||||
|
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
||||||
|
bs, slen, n_kv_heads, head_dim = x.shape
|
||||||
|
if n_rep == 1:
|
||||||
|
return x
|
||||||
|
return (
|
||||||
|
x[:, :, :, None, :]
|
||||||
|
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
||||||
|
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||||
|
model_parallel_size = fs_init.get_model_parallel_world_size()
|
||||||
|
self.n_local_heads = args.n_heads // model_parallel_size
|
||||||
|
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
||||||
|
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||||
|
self.head_dim = args.dim // args.n_heads
|
||||||
|
|
||||||
|
self.wq = ColumnParallelLinear(
|
||||||
|
args.dim,
|
||||||
|
args.n_heads * self.head_dim,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
init_method=lambda x: x,
|
||||||
|
)
|
||||||
|
self.wk = ColumnParallelLinear(
|
||||||
|
args.dim,
|
||||||
|
self.n_kv_heads * self.head_dim,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
init_method=lambda x: x,
|
||||||
|
)
|
||||||
|
self.wv = ColumnParallelLinear(
|
||||||
|
args.dim,
|
||||||
|
self.n_kv_heads * self.head_dim,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
init_method=lambda x: x,
|
||||||
|
)
|
||||||
|
self.wo = RowParallelLinear(
|
||||||
|
args.n_heads * self.head_dim,
|
||||||
|
args.dim,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
init_method=lambda x: x,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.cache_k = torch.zeros(
|
||||||
|
(
|
||||||
|
args.max_batch_size,
|
||||||
|
args.max_seq_len,
|
||||||
|
self.n_local_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
)
|
||||||
|
).cuda()
|
||||||
|
self.cache_v = torch.zeros(
|
||||||
|
(
|
||||||
|
args.max_batch_size,
|
||||||
|
args.max_seq_len,
|
||||||
|
self.n_local_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
)
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
start_pos: int,
|
||||||
|
freqs_cis: torch.Tensor,
|
||||||
|
mask: Optional[torch.Tensor],
|
||||||
|
):
|
||||||
|
bsz, seqlen, _ = x.shape
|
||||||
|
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||||
|
|
||||||
|
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||||
|
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||||
|
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||||
|
|
||||||
|
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||||
|
|
||||||
|
self.cache_k = self.cache_k.to(xq)
|
||||||
|
self.cache_v = self.cache_v.to(xq)
|
||||||
|
|
||||||
|
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
|
||||||
|
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
|
||||||
|
|
||||||
|
keys = self.cache_k[:bsz, : start_pos + seqlen]
|
||||||
|
values = self.cache_v[:bsz, : start_pos + seqlen]
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
keys = repeat_kv(
|
||||||
|
keys, self.n_rep
|
||||||
|
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
|
||||||
|
values = repeat_kv(
|
||||||
|
values, self.n_rep
|
||||||
|
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
|
||||||
|
|
||||||
|
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
|
||||||
|
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
|
||||||
|
values = values.transpose(
|
||||||
|
1, 2
|
||||||
|
) # (bs, n_local_heads, cache_len + seqlen, head_dim)
|
||||||
|
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
|
||||||
|
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||||||
|
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
|
||||||
|
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||||
|
return self.wo(output)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
multiple_of: int,
|
||||||
|
ffn_dim_multiplier: Optional[float],
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
hidden_dim = int(2 * hidden_dim / 3)
|
||||||
|
# custom dim factor multiplier
|
||||||
|
if ffn_dim_multiplier is not None:
|
||||||
|
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||||
|
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||||
|
|
||||||
|
self.w1 = ColumnParallelLinear(
|
||||||
|
dim,
|
||||||
|
hidden_dim,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
init_method=lambda x: x,
|
||||||
|
)
|
||||||
|
self.w3 = ColumnParallelLinear(
|
||||||
|
dim,
|
||||||
|
hidden_dim,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
init_method=lambda x: x,
|
||||||
|
)
|
||||||
|
self.w2 = RowParallelLinear(
|
||||||
|
hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
|
||||||
|
return reduce_from_model_parallel_region(out)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(self, layer_id: int, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.n_heads = args.n_heads
|
||||||
|
self.dim = args.dim
|
||||||
|
self.head_dim = args.dim // args.n_heads
|
||||||
|
self.attention = Attention(args)
|
||||||
|
self.feed_forward = FeedForward(
|
||||||
|
dim=args.dim,
|
||||||
|
hidden_dim=4 * args.dim,
|
||||||
|
multiple_of=args.multiple_of,
|
||||||
|
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
||||||
|
)
|
||||||
|
self.layer_id = layer_id
|
||||||
|
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||||
|
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
start_pos: int,
|
||||||
|
freqs_cis: torch.Tensor,
|
||||||
|
mask: Optional[torch.Tensor],
|
||||||
|
):
|
||||||
|
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||||
|
out = h + self.feed_forward(self.ffn_norm(h))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
def __init__(self, params: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.params = params
|
||||||
|
self.vocab_size = params.vocab_size
|
||||||
|
self.n_layers = params.n_layers
|
||||||
|
|
||||||
|
self.tok_embeddings = VocabParallelEmbedding(
|
||||||
|
params.vocab_size, params.dim, init_method=lambda x: x
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = torch.nn.ModuleList()
|
||||||
|
for layer_id in range(params.n_layers):
|
||||||
|
self.layers.append(TransformerBlock(layer_id, params))
|
||||||
|
|
||||||
|
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
||||||
|
self.output = ColumnParallelLinear(
|
||||||
|
params.dim, params.vocab_size, bias=False, init_method=lambda x: x
|
||||||
|
)
|
||||||
|
|
||||||
|
self.freqs_cis = precompute_freqs_cis(
|
||||||
|
params.dim // params.n_heads,
|
||||||
|
params.max_seq_len * 2,
|
||||||
|
params.rope_theta,
|
||||||
|
params.use_scaled_rope,
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def forward(self, tokens: torch.Tensor, start_pos: int):
|
||||||
|
_bsz, seqlen = tokens.shape
|
||||||
|
h = self.tok_embeddings(tokens)
|
||||||
|
self.freqs_cis = self.freqs_cis.to(h.device)
|
||||||
|
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if seqlen > 1:
|
||||||
|
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
|
||||||
|
|
||||||
|
mask = torch.triu(mask, diagonal=1)
|
||||||
|
|
||||||
|
# When performing key-value caching, we compute the attention scores
|
||||||
|
# only for the new sequence. Thus, the matrix of scores is of size
|
||||||
|
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
|
||||||
|
# j > cache_len + i, since row i corresponds to token cache_len + i.
|
||||||
|
mask = torch.hstack(
|
||||||
|
[torch.zeros((seqlen, start_pos), device=tokens.device), mask]
|
||||||
|
).type_as(h)
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
h = layer(h, start_pos, freqs_cis, mask)
|
||||||
|
h = self.norm(h)
|
||||||
|
output = self.output(h).float()
|
||||||
|
return output
|
155
toolchain/inference/quantization/quantize_checkpoint.py
Normal file
155
toolchain/inference/quantization/quantize_checkpoint.py
Normal file
|
@ -0,0 +1,155 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import fire
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from fairscale.nn.model_parallel.initialize import (
|
||||||
|
get_model_parallel_rank,
|
||||||
|
initialize_model_parallel,
|
||||||
|
model_parallel_is_initialized,
|
||||||
|
)
|
||||||
|
from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8
|
||||||
|
|
||||||
|
from llama.model import ModelArgs, Transformer, TransformerBlock
|
||||||
|
from llama.tokenizer import Tokenizer
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
ckpt_dir: str,
|
||||||
|
tokenizer_path: str,
|
||||||
|
quantized_ckpt_dir: str,
|
||||||
|
max_seq_len: Optional[int] = 512,
|
||||||
|
max_batch_size: Optional[int] = 4,
|
||||||
|
model_parallel_size: Optional[int] = None,
|
||||||
|
ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE,
|
||||||
|
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||||
|
seed: int = 1,
|
||||||
|
):
|
||||||
|
""" """
|
||||||
|
if not os.path.exists(quantized_ckpt_dir):
|
||||||
|
os.makedirs(quantized_ckpt_dir)
|
||||||
|
shutil.copy(
|
||||||
|
os.path.join(ckpt_dir, "params.json"),
|
||||||
|
os.path.join(quantized_ckpt_dir, "params.json"),
|
||||||
|
)
|
||||||
|
shutil.copy(
|
||||||
|
os.path.join(ckpt_dir, "tokenizer.model"),
|
||||||
|
os.path.join(quantized_ckpt_dir, "tokenizer.model"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not torch.distributed.is_initialized():
|
||||||
|
torch.distributed.init_process_group("nccl")
|
||||||
|
if not model_parallel_is_initialized():
|
||||||
|
if model_parallel_size is None:
|
||||||
|
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
initialize_model_parallel(model_parallel_size)
|
||||||
|
|
||||||
|
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
torch.cuda.set_device(local_rank)
|
||||||
|
|
||||||
|
# seed must be the same in all processes
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
if local_rank > 0:
|
||||||
|
sys.stdout = open(os.devnull, "w")
|
||||||
|
|
||||||
|
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||||
|
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||||
|
assert model_parallel_size == len(
|
||||||
|
checkpoints
|
||||||
|
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
||||||
|
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||||
|
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||||
|
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||||
|
params = json.loads(f.read())
|
||||||
|
|
||||||
|
model_args: ModelArgs = ModelArgs(
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
**params,
|
||||||
|
)
|
||||||
|
tokenizer = Tokenizer(model_path=tokenizer_path)
|
||||||
|
assert (
|
||||||
|
model_args.vocab_size == tokenizer.n_words
|
||||||
|
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||||
|
|
||||||
|
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
|
||||||
|
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||||
|
|
||||||
|
model = Transformer(model_args)
|
||||||
|
model.load_state_dict(checkpoint, strict=False)
|
||||||
|
|
||||||
|
if torch.cuda.is_bf16_supported():
|
||||||
|
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||||
|
else:
|
||||||
|
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||||
|
|
||||||
|
print(ckpt_path)
|
||||||
|
assert (
|
||||||
|
quantized_ckpt_dir is not None
|
||||||
|
), "QUantized checkpoint directory should not be None"
|
||||||
|
fp8_scales = {}
|
||||||
|
for block in model.layers:
|
||||||
|
if isinstance(block, TransformerBlock):
|
||||||
|
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||||
|
continue
|
||||||
|
|
||||||
|
fp8_weight = quantize_fp8(
|
||||||
|
block.feed_forward.w1.weight,
|
||||||
|
fp8_activation_scale_ub,
|
||||||
|
ffn_quantize_mode,
|
||||||
|
output_device=torch.device("cpu"),
|
||||||
|
)
|
||||||
|
with torch.inference_mode():
|
||||||
|
block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
|
||||||
|
fp8_scales[
|
||||||
|
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
|
||||||
|
] = fp8_weight.scale
|
||||||
|
|
||||||
|
fp8_weight = quantize_fp8(
|
||||||
|
block.feed_forward.w3.weight,
|
||||||
|
fp8_activation_scale_ub,
|
||||||
|
ffn_quantize_mode,
|
||||||
|
output_device=torch.device("cpu"),
|
||||||
|
)
|
||||||
|
with torch.inference_mode():
|
||||||
|
block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
|
||||||
|
fp8_scales[
|
||||||
|
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
|
||||||
|
] = fp8_weight.scale
|
||||||
|
|
||||||
|
fp8_weight = quantize_fp8(
|
||||||
|
block.feed_forward.w2.weight,
|
||||||
|
fp8_activation_scale_ub,
|
||||||
|
ffn_quantize_mode,
|
||||||
|
output_device=torch.device("cpu"),
|
||||||
|
)
|
||||||
|
with torch.inference_mode():
|
||||||
|
block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
|
||||||
|
fp8_scales[
|
||||||
|
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
|
||||||
|
] = fp8_weight.scale
|
||||||
|
|
||||||
|
fp8_scales_path = os.path.join(
|
||||||
|
quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
|
||||||
|
)
|
||||||
|
torch.save(fp8_scales, fp8_scales_path)
|
||||||
|
|
||||||
|
ckpt_path = os.path.join(
|
||||||
|
quantized_ckpt_dir,
|
||||||
|
"consolidated.{:02d}.pth".format(get_model_parallel_rank()),
|
||||||
|
)
|
||||||
|
torch.save(model.state_dict(), ckpt_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(main)
|
25
toolchain/inference/quantization/run_quantize_checkpoint.sh
Executable file
25
toolchain/inference/quantization/run_quantize_checkpoint.sh
Executable file
|
@ -0,0 +1,25 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
set -x
|
||||||
|
|
||||||
|
cd $(git rev-parse --show-toplevel)
|
||||||
|
|
||||||
|
MASTER_HOST=$1
|
||||||
|
RUN_ID=$2
|
||||||
|
CKPT_DIR=$3
|
||||||
|
QUANT_CKPT_DIR=$4
|
||||||
|
TOKENIZER_PATH=$5
|
||||||
|
NNODES=$6
|
||||||
|
NPROC=$7
|
||||||
|
|
||||||
|
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
|
||||||
|
|
||||||
|
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \
|
||||||
|
torchrun \
|
||||||
|
--nnodes=$NNODES --nproc_per_node=$NPROC \
|
||||||
|
--rdzv_id=$RUN_ID \
|
||||||
|
--rdzv_conf='timeout=120' \
|
||||||
|
--rdzv_backend=c10d \
|
||||||
|
--rdzv_endpoint="${MASTER_HOST}:29502" \
|
||||||
|
quantize_checkpoint.py $CKPT_DIR $TOKENIZER_PATH $QUANT_CKPT_DIR
|
102
toolchain/inference/quantization/test_fp8.py
Normal file
102
toolchain/inference/quantization/test_fp8.py
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from fp8_impls import attn_linear, ffn_swiglu_fp8_dynamic, quantize_fp8
|
||||||
|
from hypothesis import given, settings, strategies as st
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
not torch.cuda.is_available()
|
||||||
|
or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
|
||||||
|
"Skip when H100 is not available",
|
||||||
|
)
|
||||||
|
class FP8Tests(unittest.TestCase):
|
||||||
|
@settings(deadline=None)
|
||||||
|
@given(
|
||||||
|
D=st.sampled_from([4096, 8192]),
|
||||||
|
HD_L=st.sampled_from([1280, 2560]),
|
||||||
|
B=st.sampled_from([1, 2]),
|
||||||
|
T=st.sampled_from([2048, 4096]),
|
||||||
|
UB=st.sampled_from([1000, 10000]),
|
||||||
|
)
|
||||||
|
def test_fp8_ffn(
|
||||||
|
self,
|
||||||
|
D: int,
|
||||||
|
HD_L: int,
|
||||||
|
B: int,
|
||||||
|
T: int,
|
||||||
|
UB: float,
|
||||||
|
) -> None:
|
||||||
|
x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||||
|
w13 = (
|
||||||
|
torch.randn(size=(2 * HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||||
|
)
|
||||||
|
w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||||
|
|
||||||
|
x_q = quantize_fp8(x, UB)
|
||||||
|
w13_q = quantize_fp8(w13, UB)
|
||||||
|
w2_q = quantize_fp8(w2, UB)
|
||||||
|
|
||||||
|
def ref_ffn(x: Tensor, w13: Tensor, w2: Tensor) -> Tensor:
|
||||||
|
(B, T, D) = x.shape
|
||||||
|
(HD_L_2, D_) = w13.shape
|
||||||
|
assert D_ == D
|
||||||
|
HD_L = HD_L_2 // 2
|
||||||
|
|
||||||
|
y = x.view(B * T, D) @ w13.T
|
||||||
|
x1 = y[:, :HD_L]
|
||||||
|
x2 = y[:, HD_L:]
|
||||||
|
|
||||||
|
z = torch.nn.functional.silu(x1) * x2
|
||||||
|
return (z @ w2.T).view(B, T, D).to(torch.bfloat16)
|
||||||
|
|
||||||
|
v = ffn_swiglu_fp8_dynamic(x, w13_q, w2_q)
|
||||||
|
|
||||||
|
# Fake quant
|
||||||
|
x = x_q.weight.bfloat16() * x_q.scale
|
||||||
|
w13 = w13_q.weight.bfloat16() * w13_q.scale
|
||||||
|
w2 = w2_q.weight.bfloat16() * w2_q.scale
|
||||||
|
|
||||||
|
v_ref = ref_ffn(x, w13, w2)
|
||||||
|
|
||||||
|
torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)
|
||||||
|
|
||||||
|
@settings(deadline=None)
|
||||||
|
@given(
|
||||||
|
B_T=st.sampled_from([2048, 4096]),
|
||||||
|
D=st.sampled_from([128, 256]),
|
||||||
|
HD_L=st.sampled_from([256, 512]),
|
||||||
|
UB=st.sampled_from([1000, 10000]),
|
||||||
|
)
|
||||||
|
def test_fp8_attn_linear(self, B_T: int, D: int, HD_L: int, UB: int) -> None:
|
||||||
|
B_T = 4096
|
||||||
|
D = 256
|
||||||
|
HD_L = 512
|
||||||
|
UB = float(UB)
|
||||||
|
x = torch.randn(size=(B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||||
|
wqkv = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||||
|
|
||||||
|
x_q = quantize_fp8(x, UB)
|
||||||
|
wqkv_q = quantize_fp8(wqkv, UB)
|
||||||
|
|
||||||
|
num_tokens = torch.tensor(B_T, dtype=torch.int64, device="cuda")
|
||||||
|
|
||||||
|
y = attn_linear(x, wqkv_q)
|
||||||
|
y_nt = attn_linear(x, wqkv_q, num_tokens=num_tokens)
|
||||||
|
|
||||||
|
# Fake quant
|
||||||
|
x = x_q.weight.bfloat16() * x_q.scale
|
||||||
|
wqkv = wqkv_q.weight.bfloat16() * wqkv_q.scale
|
||||||
|
y_ref = (x @ wqkv.T).to(torch.bfloat16)
|
||||||
|
|
||||||
|
torch.testing.assert_close(y_ref, y, atol=1.0e-3, rtol=1.0e-3)
|
||||||
|
torch.testing.assert_close(y_ref, y_nt, atol=1.0e-3, rtol=1.0e-3)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
117
toolchain/inference/server.py
Normal file
117
toolchain/inference/server.py
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
import asyncio
|
||||||
|
import signal
|
||||||
|
|
||||||
|
import fire
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from toolchain.utils import get_config_dir, parse_config
|
||||||
|
from .api.config import ModelInferenceHydraConfig
|
||||||
|
from .api.endpoints import ChatCompletionRequest, ChatCompletionResponseStreamChunk
|
||||||
|
|
||||||
|
from .api_instance import get_inference_api_instance
|
||||||
|
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
GLOBAL_CONFIG = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_config():
|
||||||
|
return GLOBAL_CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
def handle_sigint(*args, **kwargs):
|
||||||
|
print("SIGINT or CTRL-C detected. Exiting gracefully", args)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
for task in asyncio.all_tasks(loop):
|
||||||
|
task.cancel()
|
||||||
|
loop.stop()
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup():
|
||||||
|
global InferenceApiInstance
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
hydra_config = ModelInferenceHydraConfig(
|
||||||
|
**OmegaConf.to_container(config["model_inference_config"], resolve=True)
|
||||||
|
)
|
||||||
|
model_inference_config = hydra_config.convert_to_model_inferene_config()
|
||||||
|
|
||||||
|
InferenceApiInstance = await get_inference_api_instance(
|
||||||
|
model_inference_config,
|
||||||
|
)
|
||||||
|
await InferenceApiInstance.initialize()
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown():
|
||||||
|
global InferenceApiInstance
|
||||||
|
|
||||||
|
print("shutting down")
|
||||||
|
await InferenceApiInstance.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
# there's a single model parallel process running serving the model. for now,
|
||||||
|
# we don't support multiple concurrent requests to this process.
|
||||||
|
semaphore = asyncio.Semaphore(1)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post(
|
||||||
|
"/inference/chat_completion", response_model=ChatCompletionResponseStreamChunk
|
||||||
|
)
|
||||||
|
def chat_completion(request: Request, exec_request: ChatCompletionRequest):
|
||||||
|
if semaphore.locked():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail="Only a single concurrent request allowed right now.",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def sse_generator(event_gen):
|
||||||
|
try:
|
||||||
|
async for event in event_gen:
|
||||||
|
yield f"data: {event.json()}\n\n"
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
print("Generator cancelled")
|
||||||
|
await event_gen.aclose()
|
||||||
|
finally:
|
||||||
|
semaphore.release()
|
||||||
|
|
||||||
|
async def event_gen():
|
||||||
|
async for event in InferenceApiInstance.chat_completion(exec_request):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
sse_generator(event_gen()),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(config_path: str, port: int = 5000, disable_ipv6: bool = False):
|
||||||
|
global GLOBAL_CONFIG
|
||||||
|
config_dir = get_config_dir()
|
||||||
|
GLOBAL_CONFIG = parse_config(config_dir, config_path)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, handle_sigint)
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
# FYI this does not do hot-reloads
|
||||||
|
listen_host = "::" if not disable_ipv6 else "0.0.0.0"
|
||||||
|
print(f"Listening on {listen_host}:{port}")
|
||||||
|
uvicorn.run(app, host=listen_host, port=port)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(main)
|
2
toolchain/memory/api/__init__.py
Normal file
2
toolchain/memory/api/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
from .datatypes import * # noqa: F401 F403
|
||||||
|
from .endpoints import * # noqa: F401 F403
|
19
toolchain/memory/api/datatypes.py
Normal file
19
toolchain/memory/api/datatypes.py
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from strong_typing.schema import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class MemoryBank(BaseModel):
|
||||||
|
memory_bank_id: str
|
||||||
|
memory_bank_name: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class MemoryBankDocument(BaseModel):
|
||||||
|
document_id: str
|
||||||
|
content: bytes
|
||||||
|
metadata: Dict[str, Any]
|
||||||
|
mime_type: str
|
55
toolchain/memory/api/endpoints.py
Normal file
55
toolchain/memory/api/endpoints.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
from typing import List, Protocol
|
||||||
|
|
||||||
|
from pyopenapi import webmethod
|
||||||
|
|
||||||
|
from .datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryBanks(Protocol):
|
||||||
|
@webmethod(route="/memory_banks/create")
|
||||||
|
def post_create_memory_bank(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
bank_name: str,
|
||||||
|
documents: List[MemoryBankDocument],
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@webmethod(route="/memory_banks/list")
|
||||||
|
def get_memory_banks(self) -> List[MemoryBank]: ...
|
||||||
|
|
||||||
|
@webmethod(route="/memory_banks/get")
|
||||||
|
def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ...
|
||||||
|
|
||||||
|
@webmethod(route="/memory_banks/drop")
|
||||||
|
def delete_memory_bank(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
) -> str: ...
|
||||||
|
|
||||||
|
@webmethod(route="/memory_bank/insert")
|
||||||
|
def post_insert_memory_documents(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
documents: List[MemoryBankDocument],
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@webmethod(route="/memory_bank/update")
|
||||||
|
def post_update_memory_documents(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
documents: List[MemoryBankDocument],
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@webmethod(route="/memory_bank/get")
|
||||||
|
def get_memory_documents(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
document_uuids: List[str],
|
||||||
|
) -> List[MemoryBankDocument]: ...
|
||||||
|
|
||||||
|
@webmethod(route="/memory_bank/delete")
|
||||||
|
def delete_memory_documents(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
document_uuids: List[str],
|
||||||
|
) -> List[str]: ...
|
2
toolchain/post_training/api/__init__.py
Normal file
2
toolchain/post_training/api/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
from .datatypes import * # noqa: F401 F403
|
||||||
|
from .endpoints import * # noqa: F401 F403
|
88
toolchain/post_training/api/datatypes.py
Normal file
88
toolchain/post_training/api/datatypes.py
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from strong_typing.schema import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizerType(Enum):
|
||||||
|
adam = "adam"
|
||||||
|
adamw = "adamw"
|
||||||
|
sgd = "sgd"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OptimizerConfig(BaseModel):
|
||||||
|
optimizer_type: OptimizerType
|
||||||
|
lr: float
|
||||||
|
lr_min: float
|
||||||
|
weight_decay: float
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class TrainingConfig(BaseModel):
|
||||||
|
n_epochs: int
|
||||||
|
batch_size: int
|
||||||
|
shuffle: bool
|
||||||
|
n_iters: int
|
||||||
|
|
||||||
|
enable_activation_checkpointing: bool
|
||||||
|
memory_efficient_fsdp_wrap: bool
|
||||||
|
fsdp_cpu_offload: bool
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class FinetuningAlgorithm(Enum):
|
||||||
|
full = "full"
|
||||||
|
lora = "lora"
|
||||||
|
qlora = "qlora"
|
||||||
|
dora = "dora"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class LoraFinetuningConfig(BaseModel):
|
||||||
|
lora_attn_modules: List[str]
|
||||||
|
apply_lora_to_mlp: bool
|
||||||
|
apply_lora_to_output: bool
|
||||||
|
rank: int
|
||||||
|
alpha: int
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class QLoraFinetuningConfig(LoraFinetuningConfig):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class DoraFinetuningConfig(LoraFinetuningConfig):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class PostTrainingJobLogStream(BaseModel):
|
||||||
|
"""Stream of logs from a finetuning job."""
|
||||||
|
|
||||||
|
job_uuid: str
|
||||||
|
log_lines: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class PostTrainingJobStatus(Enum):
|
||||||
|
running = "running"
|
||||||
|
completed = "completed"
|
||||||
|
failed = "failed"
|
||||||
|
scheduled = "scheduled"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class RLHFAlgorithm(Enum):
|
||||||
|
dpo = "dpo"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class DPOAlignmentConfig(BaseModel):
|
||||||
|
reward_scale: float
|
||||||
|
reward_clip: float
|
||||||
|
epsilon: float
|
||||||
|
gamma: float
|
123
toolchain/post_training/api/endpoints.py
Normal file
123
toolchain/post_training/api/endpoints.py
Normal file
|
@ -0,0 +1,123 @@
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Protocol
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from pyopenapi import webmethod
|
||||||
|
from strong_typing.schema import json_schema_type
|
||||||
|
|
||||||
|
from models.llama3.datatypes import * # noqa: F403
|
||||||
|
from toolchain.dataset.api.datatypes import * # noqa: F403
|
||||||
|
from toolchain.common.training_types import * # noqa: F403
|
||||||
|
from .datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class PostTrainingSFTRequest(BaseModel):
|
||||||
|
"""Request to finetune a model."""
|
||||||
|
|
||||||
|
job_uuid: str
|
||||||
|
|
||||||
|
model: PretrainedModel
|
||||||
|
dataset: TrainEvalDataset
|
||||||
|
validation_dataset: TrainEvalDataset
|
||||||
|
|
||||||
|
algorithm: FinetuningAlgorithm
|
||||||
|
algorithm_config: Union[
|
||||||
|
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
|
||||||
|
]
|
||||||
|
|
||||||
|
optimizer_config: OptimizerConfig
|
||||||
|
training_config: TrainingConfig
|
||||||
|
|
||||||
|
# TODO: define these
|
||||||
|
hyperparam_search_config: Dict[str, Any]
|
||||||
|
logger_config: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class PostTrainingRLHFRequest(BaseModel):
|
||||||
|
"""Request to finetune a model."""
|
||||||
|
|
||||||
|
job_uuid: str
|
||||||
|
|
||||||
|
finetuned_model: URL
|
||||||
|
|
||||||
|
dataset: TrainEvalDataset
|
||||||
|
validation_dataset: TrainEvalDataset
|
||||||
|
|
||||||
|
algorithm: RLHFAlgorithm
|
||||||
|
algorithm_config: Union[DPOAlignmentConfig]
|
||||||
|
|
||||||
|
optimizer_config: OptimizerConfig
|
||||||
|
training_config: TrainingConfig
|
||||||
|
|
||||||
|
# TODO: define these
|
||||||
|
hyperparam_search_config: Dict[str, Any]
|
||||||
|
logger_config: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class PostTrainingJob(BaseModel):
|
||||||
|
|
||||||
|
job_uuid: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class PostTrainingJobStatusResponse(BaseModel):
|
||||||
|
"""Status of a finetuning job."""
|
||||||
|
|
||||||
|
job_uuid: str
|
||||||
|
status: PostTrainingJobStatus
|
||||||
|
|
||||||
|
scheduled_at: Optional[datetime] = None
|
||||||
|
started_at: Optional[datetime] = None
|
||||||
|
completed_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
resources_allocated: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
checkpoints: List[Checkpoint] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class PostTrainingJobArtifactsResponse(BaseModel):
|
||||||
|
"""Artifacts of a finetuning job."""
|
||||||
|
|
||||||
|
job_uuid: str
|
||||||
|
checkpoints: List[Checkpoint] = Field(default_factory=list)
|
||||||
|
|
||||||
|
# TODO(ashwin): metrics, evals
|
||||||
|
|
||||||
|
|
||||||
|
class PostTraining(Protocol):
|
||||||
|
@webmethod(route="/post_training/supervised_fine_tune")
|
||||||
|
def post_supervised_fine_tune(
|
||||||
|
self,
|
||||||
|
request: PostTrainingSFTRequest,
|
||||||
|
) -> PostTrainingJob: ...
|
||||||
|
|
||||||
|
@webmethod(route="/post_training/preference_optimize")
|
||||||
|
def post_preference_optimize(
|
||||||
|
self,
|
||||||
|
request: PostTrainingRLHFRequest,
|
||||||
|
) -> PostTrainingJob: ...
|
||||||
|
|
||||||
|
@webmethod(route="/post_training/jobs")
|
||||||
|
def get_training_jobs(self) -> List[PostTrainingJob]: ...
|
||||||
|
|
||||||
|
# sends SSE stream of logs
|
||||||
|
@webmethod(route="/post_training/job/logs")
|
||||||
|
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
|
||||||
|
|
||||||
|
@webmethod(route="/post_training/job/status")
|
||||||
|
def get_training_job_status(
|
||||||
|
self, job_uuid: str
|
||||||
|
) -> PostTrainingJobStatusResponse: ...
|
||||||
|
|
||||||
|
@webmethod(route="/post_training/job/cancel")
|
||||||
|
def cancel_training_job(self, job_uuid: str) -> None: ...
|
||||||
|
|
||||||
|
@webmethod(route="/post_training/job/artifacts")
|
||||||
|
def get_training_job_artifacts(
|
||||||
|
self, job_uuid: str
|
||||||
|
) -> PostTrainingJobArtifactsResponse: ...
|
2
toolchain/reward_scoring/api/__init__.py
Normal file
2
toolchain/reward_scoring/api/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
from .datatypes import * # noqa: F401 F403
|
||||||
|
from .endpoints import * # noqa: F401 F403
|
25
toolchain/reward_scoring/api/datatypes.py
Normal file
25
toolchain/reward_scoring/api/datatypes.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from strong_typing.schema import json_schema_type
|
||||||
|
|
||||||
|
from models.llama3.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ScoredMessage(BaseModel):
|
||||||
|
message: Message
|
||||||
|
score: float
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class DialogGenerations(BaseModel):
|
||||||
|
dialog: List[Message]
|
||||||
|
sampled_generations: List[Message]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ScoredDialogGenerations(BaseModel):
|
||||||
|
dialog: List[Message]
|
||||||
|
scored_generations: List[ScoredMessage]
|
27
toolchain/reward_scoring/api/endpoints.py
Normal file
27
toolchain/reward_scoring/api/endpoints.py
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
from typing import List, Protocol, Union
|
||||||
|
from .datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
from pyopenapi import webmethod
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class RewardScoringRequest(BaseModel):
|
||||||
|
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
|
||||||
|
|
||||||
|
dialog_generations: List[DialogGenerations]
|
||||||
|
model: RewardModel
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class RewardScoringResponse(BaseModel):
|
||||||
|
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||||
|
|
||||||
|
scored_generations: List[ScoredDialogGenerations]
|
||||||
|
|
||||||
|
|
||||||
|
class RewardScoring(Protocol):
|
||||||
|
@webmethod(route="/reward_scoring/score")
|
||||||
|
def post_score(
|
||||||
|
self,
|
||||||
|
request: RewardScoringRequest,
|
||||||
|
) -> Union[RewardScoringResponse]: ...
|
0
toolchain/safety/__init__.py
Normal file
0
toolchain/safety/__init__.py
Normal file
0
toolchain/safety/api/__init__.py
Normal file
0
toolchain/safety/api/__init__.py
Normal file
19
toolchain/safety/api/config.py
Normal file
19
toolchain/safety/api/config.py
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaGuardShieldConfig(BaseModel):
|
||||||
|
model_dir: str
|
||||||
|
excluded_categories: List[str]
|
||||||
|
disable_input_check: bool = False
|
||||||
|
disable_output_check: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class PromptGuardShieldConfig(BaseModel):
|
||||||
|
model_dir: str
|
||||||
|
|
||||||
|
|
||||||
|
class SafetyConfig(BaseModel):
|
||||||
|
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
|
||||||
|
prompt_guard_shield: Optional[PromptGuardShieldConfig] = None
|
53
toolchain/safety/api/datatypes.py
Normal file
53
toolchain/safety/api/datatypes.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
from models.llama3.datatypes import ToolParamDefinition
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from strong_typing.schema import json_schema_type
|
||||||
|
|
||||||
|
from toolchain.common.deployment_types import RestAPIExecutionConfig
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class BuiltinShield(Enum):
|
||||||
|
llama_guard = "llama_guard"
|
||||||
|
prompt_guard = "prompt_guard"
|
||||||
|
code_scanner_guard = "code_scanner_guard"
|
||||||
|
third_party_shield = "third_party_shield"
|
||||||
|
|
||||||
|
|
||||||
|
ShieldType = Union[BuiltinShield, str]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OnViolationAction(Enum):
|
||||||
|
IGNORE = 0
|
||||||
|
WARN = 1
|
||||||
|
RAISE = 2
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ShieldDefinition(BaseModel):
|
||||||
|
shield_type: ShieldType
|
||||||
|
description: Optional[str] = None
|
||||||
|
parameters: Optional[Dict[str, ToolParamDefinition]] = None
|
||||||
|
on_violation_action: OnViolationAction = OnViolationAction.RAISE
|
||||||
|
execution_config: Optional[RestAPIExecutionConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ShieldCall(BaseModel):
|
||||||
|
call_id: str
|
||||||
|
shield_type: ShieldType
|
||||||
|
arguments: Dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ShieldResponse(BaseModel):
|
||||||
|
shield_type: ShieldType
|
||||||
|
# TODO(ashwin): clean this up
|
||||||
|
is_violation: bool
|
||||||
|
violation_type: Optional[str] = None
|
||||||
|
violation_return_message: Optional[str] = None
|
25
toolchain/safety/shields/__init__.py
Normal file
25
toolchain/safety/shields/__init__.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
# supress warnings and spew of logs from hugging face
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
from .base import ( # noqa: F401
|
||||||
|
DummyShield,
|
||||||
|
OnViolationAction,
|
||||||
|
ShieldBase,
|
||||||
|
ShieldResponse,
|
||||||
|
TextShield,
|
||||||
|
)
|
||||||
|
from .code_scanner import CodeScannerShield # noqa: F401
|
||||||
|
from .contrib.third_party_shield import ThirdPartyShield # noqa: F401
|
||||||
|
from .llama_guard import LlamaGuardShield # noqa: F401
|
||||||
|
from .prompt_guard import PromptGuardShield # noqa: F401
|
||||||
|
from .shield_runner import SafetyException, ShieldRunnerMixin # noqa: F401
|
||||||
|
|
||||||
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore")
|
65
toolchain/safety/shields/base.py
Normal file
65
toolchain/safety/shields/base.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from models.llama3.datatypes import Attachment, Message
|
||||||
|
from toolchain.safety.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||||
|
|
||||||
|
|
||||||
|
class ShieldBase(ABC):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||||
|
):
|
||||||
|
self.on_violation_action = on_violation_action
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_shield_type(self) -> ShieldType:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
def message_content_as_str(message: Message) -> str:
|
||||||
|
def _to_str(content: Union[str, Attachment]) -> str:
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
elif isinstance(content, Attachment):
|
||||||
|
return f"File: {str(content.url)}"
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
if isinstance(message.content, list) or isinstance(message.content, tuple):
|
||||||
|
return "\n".join([_to_str(c) for c in message.content])
|
||||||
|
else:
|
||||||
|
return _to_str(message.content)
|
||||||
|
|
||||||
|
|
||||||
|
# For shields that operate on simple strings
|
||||||
|
class TextShield(ShieldBase):
|
||||||
|
def convert_messages_to_text(self, messages: List[Message]) -> str:
|
||||||
|
return "\n".join([message_content_as_str(m) for m in messages])
|
||||||
|
|
||||||
|
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||||
|
text = self.convert_messages_to_text(messages)
|
||||||
|
return await self.run_impl(text)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def run_impl(self, text: str) -> ShieldResponse:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class DummyShield(TextShield):
|
||||||
|
|
||||||
|
def get_shield_type(self) -> ShieldType:
|
||||||
|
return "dummy"
|
||||||
|
|
||||||
|
async def run_impl(self, text: str) -> ShieldResponse:
|
||||||
|
# Dummy return LOW to test e2e
|
||||||
|
return ShieldResponse(
|
||||||
|
shield_type=BuiltinShield.third_party_shield, is_violation=False
|
||||||
|
)
|
28
toolchain/safety/shields/code_scanner.py
Normal file
28
toolchain/safety/shields/code_scanner.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
from codeshield.cs import CodeShield
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
|
from .base import ShieldResponse, TextShield
|
||||||
|
from toolchain.safety.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
class CodeScannerShield(TextShield):
|
||||||
|
|
||||||
|
def get_shield_type(self) -> ShieldType:
|
||||||
|
return BuiltinShield.code_scanner_guard
|
||||||
|
|
||||||
|
async def run_impl(self, text: str) -> ShieldResponse:
|
||||||
|
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
|
||||||
|
result = await CodeShield.scan_code(text)
|
||||||
|
if result.is_insecure:
|
||||||
|
return ShieldResponse(
|
||||||
|
shield_type=BuiltinShield.code_scanner_guard,
|
||||||
|
is_violation=True,
|
||||||
|
violation_type=",".join(
|
||||||
|
[issue.pattern_id for issue in result.issues_found]
|
||||||
|
),
|
||||||
|
violation_return_message="Sorry, I found security concerns in the code.",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return ShieldResponse(
|
||||||
|
shield_type=BuiltinShield.code_scanner_guard, is_violation=False
|
||||||
|
)
|
0
toolchain/safety/shields/contrib/__init__.py
Normal file
0
toolchain/safety/shields/contrib/__init__.py
Normal file
28
toolchain/safety/shields/contrib/third_party_shield.py
Normal file
28
toolchain/safety/shields/contrib/third_party_shield.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
import sys
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from models.llama3.datatypes import Message
|
||||||
|
|
||||||
|
parent_dir = "../.."
|
||||||
|
sys.path.append(parent_dir)
|
||||||
|
from toolchain.safety.shields.base import OnViolationAction, ShieldBase, ShieldResponse
|
||||||
|
|
||||||
|
_INSTANCE = None
|
||||||
|
|
||||||
|
|
||||||
|
class ThirdPartyShield(ShieldBase):
|
||||||
|
@staticmethod
|
||||||
|
def instance(on_violation_action=OnViolationAction.RAISE) -> "ThirdPartyShield":
|
||||||
|
global _INSTANCE
|
||||||
|
if _INSTANCE is None:
|
||||||
|
_INSTANCE = ThirdPartyShield(on_violation_action)
|
||||||
|
return _INSTANCE
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||||
|
):
|
||||||
|
super().__init__(on_violation_action)
|
||||||
|
|
||||||
|
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||||
|
super.run() # will raise NotImplementedError
|
248
toolchain/safety/shields/llama_guard.py
Normal file
248
toolchain/safety/shields/llama_guard.py
Normal file
|
@ -0,0 +1,248 @@
|
||||||
|
import re
|
||||||
|
|
||||||
|
from string import Template
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from models.llama3.datatypes import Message
|
||||||
|
from termcolor import cprint
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
|
||||||
|
from toolchain.safety.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
SAFE_RESPONSE = "safe"
|
||||||
|
_INSTANCE = None
|
||||||
|
|
||||||
|
CAT_VIOLENT_CRIMES = "Violent Crimes"
|
||||||
|
CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes"
|
||||||
|
CAT_SEX_CRIMES = "Sex Crimes"
|
||||||
|
CAT_CHILD_EXPLOITATION = "Child Exploitation"
|
||||||
|
CAT_DEFAMATION = "Defamation"
|
||||||
|
CAT_SPECIALIZED_ADVICE = "Specialized Advice"
|
||||||
|
CAT_PRIVACY = "Privacy"
|
||||||
|
CAT_INTELLECTUAL_PROPERTY = "Intellectual Property"
|
||||||
|
CAT_INDISCRIMINATE_WEAPONS = "Indiscriminate Weapons"
|
||||||
|
CAT_HATE = "Hate"
|
||||||
|
CAT_SELF_HARM = "Self-Harm"
|
||||||
|
CAT_SEXUAL_CONTENT = "Sexual Content"
|
||||||
|
CAT_ELECTIONS = "Elections"
|
||||||
|
CAT_CODE_INTERPRETER_ABUSE = "Code Interpreter Abuse"
|
||||||
|
|
||||||
|
|
||||||
|
SAFETY_CATEGORIES_TO_CODE_MAP = {
|
||||||
|
CAT_VIOLENT_CRIMES: "S1",
|
||||||
|
CAT_NON_VIOLENT_CRIMES: "S2",
|
||||||
|
CAT_SEX_CRIMES: "S3",
|
||||||
|
CAT_CHILD_EXPLOITATION: "S4",
|
||||||
|
CAT_DEFAMATION: "S5",
|
||||||
|
CAT_SPECIALIZED_ADVICE: "S6",
|
||||||
|
CAT_PRIVACY: "S7",
|
||||||
|
CAT_INTELLECTUAL_PROPERTY: "S8",
|
||||||
|
CAT_INDISCRIMINATE_WEAPONS: "S9",
|
||||||
|
CAT_HATE: "S10",
|
||||||
|
CAT_SELF_HARM: "S11",
|
||||||
|
CAT_SEXUAL_CONTENT: "S12",
|
||||||
|
CAT_ELECTIONS: "S13",
|
||||||
|
CAT_CODE_INTERPRETER_ABUSE: "S14",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
|
||||||
|
CAT_VIOLENT_CRIMES,
|
||||||
|
CAT_NON_VIOLENT_CRIMES,
|
||||||
|
CAT_SEX_CRIMES,
|
||||||
|
CAT_CHILD_EXPLOITATION,
|
||||||
|
CAT_DEFAMATION,
|
||||||
|
CAT_SPECIALIZED_ADVICE,
|
||||||
|
CAT_PRIVACY,
|
||||||
|
CAT_INTELLECTUAL_PROPERTY,
|
||||||
|
CAT_INDISCRIMINATE_WEAPONS,
|
||||||
|
CAT_HATE,
|
||||||
|
CAT_SELF_HARM,
|
||||||
|
CAT_SEXUAL_CONTENT,
|
||||||
|
CAT_ELECTIONS,
|
||||||
|
CAT_CODE_INTERPRETER_ABUSE,
|
||||||
|
]
|
||||||
|
|
||||||
|
PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
|
||||||
|
|
||||||
|
SAFETY_CATEGORIES = """
|
||||||
|
|
||||||
|
<BEGIN UNSAFE CONTENT CATEGORIES>
|
||||||
|
$categories
|
||||||
|
<END UNSAFE CONTENT CATEGORIES>
|
||||||
|
"""
|
||||||
|
|
||||||
|
PROMPT_CONVERSATION = """
|
||||||
|
<BEGIN CONVERSATION>
|
||||||
|
|
||||||
|
$conversations
|
||||||
|
|
||||||
|
<END CONVERSATION>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
PROMPT_INSTRUCTIONS = """
|
||||||
|
Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
|
||||||
|
- First line must read 'safe' or 'unsafe'.
|
||||||
|
- If unsafe, a second line must include a comma-separated list of violated categories."""
|
||||||
|
|
||||||
|
|
||||||
|
PROMPT_TEMPLATE = Template(
|
||||||
|
f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaGuardShield(ShieldBase):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def instance(
|
||||||
|
on_violation_action=OnViolationAction.RAISE,
|
||||||
|
model_dir: str = None,
|
||||||
|
excluded_categories: List[str] = [],
|
||||||
|
disable_input_check: bool = False,
|
||||||
|
disable_output_check: bool = False,
|
||||||
|
) -> "LlamaGuardShield":
|
||||||
|
global _INSTANCE
|
||||||
|
if _INSTANCE is None:
|
||||||
|
_INSTANCE = LlamaGuardShield(
|
||||||
|
on_violation_action,
|
||||||
|
model_dir,
|
||||||
|
excluded_categories,
|
||||||
|
disable_input_check,
|
||||||
|
disable_output_check,
|
||||||
|
)
|
||||||
|
return _INSTANCE
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||||
|
model_dir: str = None,
|
||||||
|
excluded_categories: List[str] = [],
|
||||||
|
disable_input_check: bool = False,
|
||||||
|
disable_output_check: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(on_violation_action)
|
||||||
|
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
assert model_dir is not None, "Llama Guard model_dir is None"
|
||||||
|
|
||||||
|
assert len(excluded_categories) == 0 or all(
|
||||||
|
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
|
||||||
|
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
|
||||||
|
|
||||||
|
self.device = "cuda"
|
||||||
|
self.excluded_categories = excluded_categories
|
||||||
|
self.disable_input_check = disable_input_check
|
||||||
|
self.disable_output_check = disable_output_check
|
||||||
|
|
||||||
|
# load model
|
||||||
|
torch_dtype = torch.bfloat16
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_dir, torch_dtype=torch_dtype, device_map=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_shield_type(self) -> ShieldType:
|
||||||
|
return BuiltinShield.llama_guard
|
||||||
|
|
||||||
|
def check_unsafe_response(self, response: str) -> Optional[str]:
|
||||||
|
match = re.match(r"^unsafe\n(.*)$", response)
|
||||||
|
if match:
|
||||||
|
# extracts the unsafe code
|
||||||
|
extracted = match.group(1)
|
||||||
|
return extracted
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_safety_categories(self) -> List[str]:
|
||||||
|
|
||||||
|
excluded_categories = self.excluded_categories
|
||||||
|
if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
|
||||||
|
excluded_categories = []
|
||||||
|
|
||||||
|
categories = []
|
||||||
|
for cat in DEFAULT_LG_V3_SAFETY_CATEGORIES:
|
||||||
|
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
|
||||||
|
if cat_code in excluded_categories:
|
||||||
|
continue
|
||||||
|
categories.append(f"{cat_code}: {cat}.")
|
||||||
|
|
||||||
|
return categories
|
||||||
|
|
||||||
|
def build_prompt(self, messages: List[Message]) -> str:
|
||||||
|
|
||||||
|
categories = self.get_safety_categories()
|
||||||
|
categories_str = "\n".join(categories)
|
||||||
|
conversations_str = "\n\n".join(
|
||||||
|
[f"{m.role.capitalize()}: {m.content}" for m in messages]
|
||||||
|
)
|
||||||
|
return PROMPT_TEMPLATE.substitute(
|
||||||
|
agent_type=messages[-1].role.capitalize(),
|
||||||
|
categories=categories_str,
|
||||||
|
conversations=conversations_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_shield_response(self, response: str) -> ShieldResponse:
|
||||||
|
if response == SAFE_RESPONSE:
|
||||||
|
return ShieldResponse(
|
||||||
|
shield_type=BuiltinShield.llama_guard, is_violation=False
|
||||||
|
)
|
||||||
|
unsafe_code = self.check_unsafe_response(response)
|
||||||
|
if unsafe_code:
|
||||||
|
unsafe_code_list = unsafe_code.split(",")
|
||||||
|
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
|
||||||
|
return ShieldResponse(
|
||||||
|
shield_type=BuiltinShield.llama_guard, is_violation=False
|
||||||
|
)
|
||||||
|
return ShieldResponse(
|
||||||
|
shield_type=BuiltinShield.llama_guard,
|
||||||
|
is_violation=True,
|
||||||
|
violation_type=unsafe_code,
|
||||||
|
violation_return_message=CANNED_RESPONSE_TEXT,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(f"Unexpected response: {response}")
|
||||||
|
|
||||||
|
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||||
|
|
||||||
|
if self.disable_input_check and messages[-1].role == Role.user.value:
|
||||||
|
return ShieldResponse(
|
||||||
|
shield_type=BuiltinShield.llama_guard, is_violation=False
|
||||||
|
)
|
||||||
|
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
|
||||||
|
return ShieldResponse(
|
||||||
|
shield_type=BuiltinShield.llama_guard,
|
||||||
|
is_violation=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
prompt = self.build_prompt(messages)
|
||||||
|
llama_guard_input = {
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
}
|
||||||
|
input_ids = self.tokenizer.apply_chat_template(
|
||||||
|
[llama_guard_input], return_tensors="pt", tokenize=True
|
||||||
|
).to(self.device)
|
||||||
|
prompt_len = input_ids.shape[1]
|
||||||
|
output = self.model.generate(
|
||||||
|
input_ids=input_ids,
|
||||||
|
max_new_tokens=20,
|
||||||
|
output_scores=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
pad_token_id=0,
|
||||||
|
)
|
||||||
|
generated_tokens = output.sequences[:, prompt_len:]
|
||||||
|
|
||||||
|
response = self.tokenizer.decode(
|
||||||
|
generated_tokens[0], skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
response = response.strip()
|
||||||
|
shield_response = self.get_shield_response(response)
|
||||||
|
|
||||||
|
cprint(f"Final Llama Guard response {shield_response}", color="magenta")
|
||||||
|
return shield_response
|
112
toolchain/safety/shields/prompt_guard.py
Normal file
112
toolchain/safety/shields/prompt_guard.py
Normal file
|
@ -0,0 +1,112 @@
|
||||||
|
from enum import auto, Enum
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from models.llama3.datatypes import Message
|
||||||
|
from termcolor import cprint
|
||||||
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||||
|
|
||||||
|
from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
|
||||||
|
from toolchain.safety.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
class PromptGuardShield(TextShield):
|
||||||
|
|
||||||
|
class Mode(Enum):
|
||||||
|
INJECTION = auto()
|
||||||
|
JAILBREAK = auto()
|
||||||
|
|
||||||
|
_instances = {}
|
||||||
|
_model_cache = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def instance(
|
||||||
|
model_dir: str,
|
||||||
|
threshold: float = 0.9,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
|
||||||
|
on_violation_action=OnViolationAction.RAISE,
|
||||||
|
) -> "PromptGuardShield":
|
||||||
|
action_value = on_violation_action.value
|
||||||
|
key = (model_dir, threshold, temperature, mode, action_value)
|
||||||
|
if key not in PromptGuardShield._instances:
|
||||||
|
PromptGuardShield._instances[key] = PromptGuardShield(
|
||||||
|
model_dir=model_dir,
|
||||||
|
threshold=threshold,
|
||||||
|
temperature=temperature,
|
||||||
|
mode=mode,
|
||||||
|
on_violation_action=on_violation_action,
|
||||||
|
)
|
||||||
|
return PromptGuardShield._instances[key]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_dir: str,
|
||||||
|
threshold: float = 0.9,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
|
||||||
|
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||||
|
):
|
||||||
|
super().__init__(on_violation_action)
|
||||||
|
assert (
|
||||||
|
model_dir is not None
|
||||||
|
), "Must provide a model directory for prompt injection shield"
|
||||||
|
if temperature <= 0:
|
||||||
|
raise ValueError("Temperature must be greater than 0")
|
||||||
|
self.device = "cuda"
|
||||||
|
if PromptGuardShield._model_cache is None:
|
||||||
|
# load model and tokenizer
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
model_dir, device_map=self.device
|
||||||
|
)
|
||||||
|
PromptGuardShield._model_cache = (tokenizer, model)
|
||||||
|
|
||||||
|
self.tokenizer, self.model = PromptGuardShield._model_cache
|
||||||
|
self.temperature = temperature
|
||||||
|
self.threshold = threshold
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def get_shield_type(self) -> ShieldType:
|
||||||
|
return BuiltinShield.prompt_guard
|
||||||
|
|
||||||
|
def convert_messages_to_text(self, messages: List[Message]) -> str:
|
||||||
|
return message_content_as_str(messages[-1])
|
||||||
|
|
||||||
|
async def run_impl(self, text: str) -> ShieldResponse:
|
||||||
|
# run model on messages and return response
|
||||||
|
inputs = self.tokenizer(text, return_tensors="pt")
|
||||||
|
inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
logits = outputs[0]
|
||||||
|
probabilities = torch.softmax(logits / self.temperature, dim=-1)
|
||||||
|
score_embedded = probabilities[0, 1].item()
|
||||||
|
score_malicious = probabilities[0, 2].item()
|
||||||
|
cprint(
|
||||||
|
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
|
||||||
|
color="magenta",
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.mode == self.Mode.INJECTION and (
|
||||||
|
score_embedded + score_malicious > self.threshold
|
||||||
|
):
|
||||||
|
return ShieldResponse(
|
||||||
|
shield_type=BuiltinShield.prompt_guard,
|
||||||
|
is_violation=True,
|
||||||
|
violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
|
||||||
|
violation_return_message="Sorry, I cannot do this.",
|
||||||
|
)
|
||||||
|
elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
|
||||||
|
return ShieldResponse(
|
||||||
|
shield_type=BuiltinShield.prompt_guard,
|
||||||
|
is_violation=True,
|
||||||
|
violation_type=f"prompt_injection:malicious={score_malicious}",
|
||||||
|
violation_return_message="Sorry, I cannot do this.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return ShieldResponse(
|
||||||
|
shield_type=BuiltinShield.prompt_guard,
|
||||||
|
is_violation=False,
|
||||||
|
)
|
46
toolchain/safety/shields/shield_runner.py
Normal file
46
toolchain/safety/shields/shield_runner.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
import asyncio
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from models.llama3.datatypes import Message, Role
|
||||||
|
|
||||||
|
from .base import OnViolationAction, ShieldBase, ShieldResponse
|
||||||
|
|
||||||
|
|
||||||
|
class SafetyException(Exception):
|
||||||
|
def __init__(self, response: ShieldResponse):
|
||||||
|
self.response = response
|
||||||
|
super().__init__(response.violation_return_message)
|
||||||
|
|
||||||
|
|
||||||
|
class ShieldRunnerMixin:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_shields: List[ShieldBase] = None,
|
||||||
|
output_shields: List[ShieldBase] = None,
|
||||||
|
):
|
||||||
|
self.input_shields = input_shields
|
||||||
|
self.output_shields = output_shields
|
||||||
|
|
||||||
|
async def run_shields(
|
||||||
|
self, messages: List[Message], shields: List[ShieldBase]
|
||||||
|
) -> List[ShieldResponse]:
|
||||||
|
# some shields like llama-guard require the first message to be a user message
|
||||||
|
# since this might be a tool call, first role might not be user
|
||||||
|
if len(messages) > 0 and messages[0].role != Role.user.value:
|
||||||
|
# TODO(ashwin): we need to change the type of the message, this kind of modification
|
||||||
|
# is no longer appropriate
|
||||||
|
messages[0].role = Role.user.value
|
||||||
|
|
||||||
|
results = await asyncio.gather(*[s.run(messages) for s in shields])
|
||||||
|
for shield, r in zip(shields, results):
|
||||||
|
if r.is_violation:
|
||||||
|
if shield.on_violation_action == OnViolationAction.RAISE:
|
||||||
|
raise SafetyException(r)
|
||||||
|
elif shield.on_violation_action == OnViolationAction.WARN:
|
||||||
|
cprint(
|
||||||
|
f"[Warn]{shield.__class__.__name__} raised a warning",
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
54
toolchain/spec/generate.py
Normal file
54
toolchain/spec/generate.py
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from pyopenapi import Info, Options, Server, Specification
|
||||||
|
|
||||||
|
from models.llama3.datatypes import * # noqa: F403
|
||||||
|
from toolchain.dataset.api import * # noqa: F403
|
||||||
|
from toolchain.evaluations.api import * # noqa: F403
|
||||||
|
from toolchain.inference.api import * # noqa: F403
|
||||||
|
from toolchain.memory.api import * # noqa: F403
|
||||||
|
from toolchain.post_training.api import * # noqa: F403
|
||||||
|
from toolchain.reward_scoring.api import * # noqa: F403
|
||||||
|
from toolchain.synthetic_data_generation.api import * # noqa: F403
|
||||||
|
from agentic_system.api import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaStackEndpoints(
|
||||||
|
ModelInference,
|
||||||
|
AgenticSystem,
|
||||||
|
RewardScoring,
|
||||||
|
SyntheticDataGeneration,
|
||||||
|
Datasets,
|
||||||
|
PostTraining,
|
||||||
|
MemoryBanks,
|
||||||
|
Evaluations,
|
||||||
|
): ...
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
now = str(datetime.now())
|
||||||
|
print(
|
||||||
|
"Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at " + now
|
||||||
|
)
|
||||||
|
spec = Specification(
|
||||||
|
LlamaStackEndpoints,
|
||||||
|
Options(
|
||||||
|
server=Server(url="http://any-hosted-llama-stack.com"),
|
||||||
|
info=Info(
|
||||||
|
title="[DRAFT] Llama Stack Specification",
|
||||||
|
version="0.0.1",
|
||||||
|
description="""This is the specification of the llama stack that provides
|
||||||
|
a set of endpoints and their corresponding interfaces that are tailored to
|
||||||
|
best leverage Llama Models. The specification is still in draft and subject to change.
|
||||||
|
Generated at """
|
||||||
|
+ now,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
with open("openapi.yaml", "w", encoding="utf-8") as fp:
|
||||||
|
yaml.dump(spec.get_json(), fp, allow_unicode=True)
|
||||||
|
|
||||||
|
with open("openapi.html", "w") as fp:
|
||||||
|
spec.write_html(fp, pretty_print=True)
|
4584
toolchain/spec/openapi.html
Normal file
4584
toolchain/spec/openapi.html
Normal file
File diff suppressed because it is too large
Load diff
2894
toolchain/spec/openapi.yaml
Normal file
2894
toolchain/spec/openapi.yaml
Normal file
File diff suppressed because it is too large
Load diff
22
toolchain/spec/package.sh
Normal file
22
toolchain/spec/package.sh
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
TMPDIR=$(mktemp -d)
|
||||||
|
echo "Using temporary directory: $TMPDIR"
|
||||||
|
|
||||||
|
rootdir=$(git rev-parse --show-toplevel)
|
||||||
|
|
||||||
|
files_to_copy=("toolchain/spec/openapi*" "models/llama3/datatypes.py" "toolchain/inference/api/*.py" "agentic_system/api/*.py" "toolchain/common/*.py" "toolchain/dataset/api/*.py" "toolchain/evaluations/api/*.py" "toolchain/reward_scoring/api/*.py" "toolchain/post_training/api/*.py" "toolchain/safety/api/*.py")
|
||||||
|
for file in "${files_to_copy[@]}"; do
|
||||||
|
relpath="$file"
|
||||||
|
set -x
|
||||||
|
mkdir -p "$TMPDIR/$(dirname $relpath)"
|
||||||
|
eval cp "$rootdir/$relpath" "$TMPDIR/$(dirname $relpath)"
|
||||||
|
set +x
|
||||||
|
done
|
||||||
|
|
||||||
|
cd "$TMPDIR"
|
||||||
|
zip -r output.zip .
|
||||||
|
|
||||||
|
echo "Zip at: $TMPDIR/output.zip"
|
105
toolchain/spec/post_training_types.py
Normal file
105
toolchain/spec/post_training_types.py
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from models.llama3.datatypes import URL
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from strong_typing.schema import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
|
class TrainEvalDatasetColumnType(Enum):
|
||||||
|
dialog = "dialog"
|
||||||
|
text = "text"
|
||||||
|
media = "media"
|
||||||
|
number = "number"
|
||||||
|
json = "json"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class TrainEvalDataset(BaseModel):
|
||||||
|
"""Dataset to be used for training or evaluating language models."""
|
||||||
|
|
||||||
|
# TODO(ashwin): figure out if we need to add an enum for a "dataset type"
|
||||||
|
|
||||||
|
columns: Dict[str, TrainEvalDatasetColumnType]
|
||||||
|
content_url: URL
|
||||||
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizerType(Enum):
|
||||||
|
adam = "adam"
|
||||||
|
adamw = "adamw"
|
||||||
|
sgd = "sgd"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OptimizerConfig(BaseModel):
|
||||||
|
optimizer_type: OptimizerType
|
||||||
|
lr: float
|
||||||
|
lr_min: float
|
||||||
|
weight_decay: float
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class TrainingConfig(BaseModel):
|
||||||
|
n_epochs: int
|
||||||
|
batch_size: int
|
||||||
|
shuffle: bool
|
||||||
|
n_iters: int
|
||||||
|
|
||||||
|
enable_activation_checkpointing: bool
|
||||||
|
memory_efficient_fsdp_wrap: bool
|
||||||
|
fsdp_cpu_offload: bool
|
||||||
|
|
||||||
|
|
||||||
|
class FinetuningAlgorithm(Enum):
|
||||||
|
full = "full"
|
||||||
|
lora = "lora"
|
||||||
|
qlora = "qlora"
|
||||||
|
dora = "dora"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class LoraFinetuningConfig(BaseModel):
|
||||||
|
lora_attn_modules: List[str]
|
||||||
|
apply_lora_to_mlp: bool
|
||||||
|
apply_lora_to_output: bool
|
||||||
|
rank: int
|
||||||
|
alpha: int
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class QLoraFinetuningConfig(LoraFinetuningConfig):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class DoraFinetuningConfig(LoraFinetuningConfig):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class PostTrainingJobLogStream(BaseModel):
|
||||||
|
"""Stream of logs from a finetuning job."""
|
||||||
|
|
||||||
|
job_uuid: str
|
||||||
|
log_lines: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class PostTrainingJobStatus(Enum):
|
||||||
|
running = "running"
|
||||||
|
completed = "completed"
|
||||||
|
failed = "failed"
|
||||||
|
scheduled = "scheduled"
|
||||||
|
|
||||||
|
|
||||||
|
class RLHFAlgorithm(Enum):
|
||||||
|
dpo = "dpo"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class DPOAlignmentConfig(BaseModel):
|
||||||
|
reward_scale: float
|
||||||
|
reward_clip: float
|
||||||
|
epsilon: float
|
||||||
|
gamma: float
|
5
toolchain/spec/run_openapi_generator.sh
Normal file
5
toolchain/spec/run_openapi_generator.sh
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -x
|
||||||
|
|
||||||
|
PYTHONPATH=../../../oss-ops:../.. python3 -m toolchain.spec.generate
|
2
toolchain/synthetic_data_generation/api/__init__.py
Normal file
2
toolchain/synthetic_data_generation/api/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
from .datatypes import * # noqa: F401 F403
|
||||||
|
from .endpoints import * # noqa: F401 F403
|
12
toolchain/synthetic_data_generation/api/datatypes.py
Normal file
12
toolchain/synthetic_data_generation/api/datatypes.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class FilteringFunction(Enum):
|
||||||
|
"""The type of filtering function."""
|
||||||
|
|
||||||
|
none = "none"
|
||||||
|
random = "random"
|
||||||
|
top_k = "top_k"
|
||||||
|
top_p = "top_p"
|
||||||
|
top_k_top_p = "top_k_top_p"
|
||||||
|
sigmoid = "sigmoid"
|
35
toolchain/synthetic_data_generation/api/endpoints.py
Normal file
35
toolchain/synthetic_data_generation/api/endpoints.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
from typing import Any, Dict, List, Optional, Protocol
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from pyopenapi import webmethod
|
||||||
|
from strong_typing.schema import json_schema_type
|
||||||
|
|
||||||
|
from models.llama3.datatypes import * # noqa: F403
|
||||||
|
from toolchain.reward_scoring.api.datatypes import * # noqa: F403
|
||||||
|
from .datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class SyntheticDataGenerationRequest(BaseModel):
|
||||||
|
"""Request to generate synthetic data. A small batch of prompts and a filtering function"""
|
||||||
|
|
||||||
|
dialogs: List[Message]
|
||||||
|
filtering_function: FilteringFunction = FilteringFunction.none
|
||||||
|
model: Optional[RewardModel] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class SyntheticDataGenerationResponse(BaseModel):
|
||||||
|
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||||
|
|
||||||
|
synthetic_data: List[ScoredDialogGenerations]
|
||||||
|
statistics: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SyntheticDataGeneration(Protocol):
|
||||||
|
@webmethod(route="/synthetic_data_generation/generate")
|
||||||
|
def post_generate(
|
||||||
|
self,
|
||||||
|
request: SyntheticDataGenerationRequest,
|
||||||
|
) -> Union[SyntheticDataGenerationResponse]: ...
|
55
toolchain/utils.py
Normal file
55
toolchain/utils.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
import getpass
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from hydra import compose, initialize, MissingConfigException
|
||||||
|
from hydra.core.global_hydra import GlobalHydra
|
||||||
|
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
|
||||||
|
def get_root_directory():
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
while os.path.isfile(os.path.join(current_dir, "__init__.py")):
|
||||||
|
current_dir = os.path.dirname(current_dir)
|
||||||
|
|
||||||
|
return current_dir
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_dir():
|
||||||
|
return os.path.join(get_root_directory(), "toolchain", "configs")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_config(config_dir: str, config_path: Optional[str] = None) -> str:
|
||||||
|
# Configs can be
|
||||||
|
# 1. relative paths in {config_dir}/
|
||||||
|
# 2. or default to file {config_dir}/{user}.yaml
|
||||||
|
# 3. or ultimate default to {config_dir}/default.yaml
|
||||||
|
|
||||||
|
# Get the relative path from the current file to the config directory
|
||||||
|
current_file_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
relative_path = os.path.relpath(config_dir, current_file_directory)
|
||||||
|
|
||||||
|
GlobalHydra.instance().clear()
|
||||||
|
initialize(config_path=relative_path)
|
||||||
|
|
||||||
|
if config_path is None:
|
||||||
|
try:
|
||||||
|
user = getpass.getuser()
|
||||||
|
config_name = user
|
||||||
|
except MissingConfigException:
|
||||||
|
print(f"No user-specific {user}.yaml, using default")
|
||||||
|
config_name = "default"
|
||||||
|
else:
|
||||||
|
config_name = config_path
|
||||||
|
|
||||||
|
config_abs_path = os.path.abspath(os.path.join(config_dir, f"{config_name}.yaml"))
|
||||||
|
print(f"Loading config from : {config_abs_path}")
|
||||||
|
config = compose(config_name=config_name)
|
||||||
|
|
||||||
|
print("Yaml config:")
|
||||||
|
print("------------------------")
|
||||||
|
print(OmegaConf.to_yaml(config, resolve=True))
|
||||||
|
print("------------------------")
|
||||||
|
|
||||||
|
return config
|
Loading…
Add table
Add a link
Reference in a new issue