From 95781ec85dfc18e12ee503bd7bcbf2cd94133544 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 19 Jul 2024 12:30:35 -0700 Subject: [PATCH] Add toolchain from agentic system here --- toolchain/__init__.py | 0 toolchain/cli/__init__.py | 0 toolchain/cli/download.py | 93 + toolchain/cli/inference/inference.py | 28 + toolchain/cli/inference/start.py | 53 + toolchain/cli/llama.py | 40 + toolchain/cli/subcommand.py | 12 + toolchain/common/deployment_types.py | 25 + toolchain/common/training_types.py | 7 + toolchain/configs/ashwin.yaml | 9 + toolchain/configs/chrisluc.yaml | 9 + toolchain/configs/default.yaml | 9 + toolchain/configs/hjshah.yaml | 9 + toolchain/configs/long_seqlen.yaml | 9 + toolchain/dataset/api/__init__.py | 2 + toolchain/dataset/api/datatypes.py | 27 + toolchain/dataset/api/endpoints.py | 36 + toolchain/evaluations/api/__init__.py | 2 + toolchain/evaluations/api/datatypes.py | 29 + toolchain/evaluations/api/endpoints.py | 93 + toolchain/inference/__init__.py | 0 toolchain/inference/api/__init__.py | 2 + toolchain/inference/api/config.py | 146 + toolchain/inference/api/datatypes.py | 68 + toolchain/inference/api/endpoints.py | 117 + toolchain/inference/api_instance.py | 12 + toolchain/inference/client.py | 73 + toolchain/inference/generation.py | 298 ++ toolchain/inference/inference.py | 173 + toolchain/inference/model_parallel.py | 100 + toolchain/inference/parallel_utils.py | 259 + .../inference/quantization/build_conda.sh | 45 + toolchain/inference/quantization/fp8_impls.py | 165 + .../quantization/fp8_requirements.txt | 5 + .../inference/quantization/generation.py | 455 ++ toolchain/inference/quantization/model.py | 355 ++ .../quantization/quantize_checkpoint.py | 155 + .../quantization/run_quantize_checkpoint.sh | 25 + toolchain/inference/quantization/test_fp8.py | 102 + toolchain/inference/server.py | 117 + toolchain/memory/api/__init__.py | 2 + toolchain/memory/api/datatypes.py | 19 + toolchain/memory/api/endpoints.py | 55 + toolchain/post_training/api/__init__.py | 2 + toolchain/post_training/api/datatypes.py | 88 + toolchain/post_training/api/endpoints.py | 123 + toolchain/reward_scoring/api/__init__.py | 2 + toolchain/reward_scoring/api/datatypes.py | 25 + toolchain/reward_scoring/api/endpoints.py | 27 + toolchain/safety/__init__.py | 0 toolchain/safety/api/__init__.py | 0 toolchain/safety/api/config.py | 19 + toolchain/safety/api/datatypes.py | 53 + toolchain/safety/shields/__init__.py | 25 + toolchain/safety/shields/base.py | 65 + toolchain/safety/shields/code_scanner.py | 28 + toolchain/safety/shields/contrib/__init__.py | 0 .../shields/contrib/third_party_shield.py | 28 + toolchain/safety/shields/llama_guard.py | 248 + toolchain/safety/shields/prompt_guard.py | 112 + toolchain/safety/shields/shield_runner.py | 46 + toolchain/spec/generate.py | 54 + toolchain/spec/openapi.html | 4584 +++++++++++++++++ toolchain/spec/openapi.yaml | 2894 +++++++++++ toolchain/spec/package.sh | 22 + toolchain/spec/post_training_types.py | 105 + toolchain/spec/run_openapi_generator.sh | 5 + .../synthetic_data_generation/api/__init__.py | 2 + .../api/datatypes.py | 12 + .../api/endpoints.py | 35 + toolchain/utils.py | 55 + 71 files changed, 11899 insertions(+) create mode 100644 toolchain/__init__.py create mode 100644 toolchain/cli/__init__.py create mode 100644 toolchain/cli/download.py create mode 100644 toolchain/cli/inference/inference.py create mode 100644 toolchain/cli/inference/start.py create mode 100644 toolchain/cli/llama.py create mode 100644 toolchain/cli/subcommand.py create mode 100644 toolchain/common/deployment_types.py create mode 100644 toolchain/common/training_types.py create mode 100644 toolchain/configs/ashwin.yaml create mode 100644 toolchain/configs/chrisluc.yaml create mode 100644 toolchain/configs/default.yaml create mode 100644 toolchain/configs/hjshah.yaml create mode 100644 toolchain/configs/long_seqlen.yaml create mode 100644 toolchain/dataset/api/__init__.py create mode 100644 toolchain/dataset/api/datatypes.py create mode 100644 toolchain/dataset/api/endpoints.py create mode 100644 toolchain/evaluations/api/__init__.py create mode 100644 toolchain/evaluations/api/datatypes.py create mode 100644 toolchain/evaluations/api/endpoints.py create mode 100644 toolchain/inference/__init__.py create mode 100644 toolchain/inference/api/__init__.py create mode 100644 toolchain/inference/api/config.py create mode 100644 toolchain/inference/api/datatypes.py create mode 100644 toolchain/inference/api/endpoints.py create mode 100644 toolchain/inference/api_instance.py create mode 100644 toolchain/inference/client.py create mode 100644 toolchain/inference/generation.py create mode 100644 toolchain/inference/inference.py create mode 100644 toolchain/inference/model_parallel.py create mode 100644 toolchain/inference/parallel_utils.py create mode 100644 toolchain/inference/quantization/build_conda.sh create mode 100644 toolchain/inference/quantization/fp8_impls.py create mode 100644 toolchain/inference/quantization/fp8_requirements.txt create mode 100644 toolchain/inference/quantization/generation.py create mode 100644 toolchain/inference/quantization/model.py create mode 100644 toolchain/inference/quantization/quantize_checkpoint.py create mode 100755 toolchain/inference/quantization/run_quantize_checkpoint.sh create mode 100644 toolchain/inference/quantization/test_fp8.py create mode 100644 toolchain/inference/server.py create mode 100644 toolchain/memory/api/__init__.py create mode 100644 toolchain/memory/api/datatypes.py create mode 100644 toolchain/memory/api/endpoints.py create mode 100644 toolchain/post_training/api/__init__.py create mode 100644 toolchain/post_training/api/datatypes.py create mode 100644 toolchain/post_training/api/endpoints.py create mode 100644 toolchain/reward_scoring/api/__init__.py create mode 100644 toolchain/reward_scoring/api/datatypes.py create mode 100644 toolchain/reward_scoring/api/endpoints.py create mode 100644 toolchain/safety/__init__.py create mode 100644 toolchain/safety/api/__init__.py create mode 100644 toolchain/safety/api/config.py create mode 100644 toolchain/safety/api/datatypes.py create mode 100644 toolchain/safety/shields/__init__.py create mode 100644 toolchain/safety/shields/base.py create mode 100644 toolchain/safety/shields/code_scanner.py create mode 100644 toolchain/safety/shields/contrib/__init__.py create mode 100644 toolchain/safety/shields/contrib/third_party_shield.py create mode 100644 toolchain/safety/shields/llama_guard.py create mode 100644 toolchain/safety/shields/prompt_guard.py create mode 100644 toolchain/safety/shields/shield_runner.py create mode 100644 toolchain/spec/generate.py create mode 100644 toolchain/spec/openapi.html create mode 100644 toolchain/spec/openapi.yaml create mode 100644 toolchain/spec/package.sh create mode 100644 toolchain/spec/post_training_types.py create mode 100644 toolchain/spec/run_openapi_generator.sh create mode 100644 toolchain/synthetic_data_generation/api/__init__.py create mode 100644 toolchain/synthetic_data_generation/api/datatypes.py create mode 100644 toolchain/synthetic_data_generation/api/endpoints.py create mode 100644 toolchain/utils.py diff --git a/toolchain/__init__.py b/toolchain/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/toolchain/cli/__init__.py b/toolchain/cli/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/toolchain/cli/download.py b/toolchain/cli/download.py new file mode 100644 index 000000000..edb1eb3a3 --- /dev/null +++ b/toolchain/cli/download.py @@ -0,0 +1,93 @@ +import argparse +import os +import textwrap +from pathlib import Path + +from huggingface_hub import snapshot_download +from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError + +from toolchain.cli.subcommand import Subcommand + + +DEFAULT_OUTPUT_DIR = "/tmp/llama_toolchain_cache/" + + +class Download(Subcommand): + """Llama cli for downloading llama toolchain assets""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "download", + prog="llama download", + description="Download a model from the Hugging Face Hub", + epilog=textwrap.dedent( + """\ + # Here are some examples on how to use this command: + + llama download --repo-id meta-llama/Llama-2-7b-hf --hf-token + llama download --repo-id meta-llama/Llama-2-7b-hf --output-dir /data/my_custom_dir --hf-token + HF_TOKEN= llama download --repo-id meta-llama/Llama-2-7b-hf + + The output directory will be used to load models and tokenizers for inference. + """ + ), + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_download_cmd) + + def _add_arguments(self): + self.parser.add_argument( + "repo_id", + type=str, + help="Name of the repository on Hugging Face Hub eg. llhf/Meta-Llama-3.1-70B-Instruct", + ) + self.parser.add_argument( + "--output-dir", + type=Path, + required=False, + default=None, + help=f"Directory in which to save the model. Defaults to `{DEFAULT_OUTPUT_DIR}`.", + ) + self.parser.add_argument( + "--hf-token", + type=str, + required=False, + default=os.getenv("HF_TOKEN", None), + help="Hugging Face API token. Needed for gated models like Llama2. Will also try to read environment variable `HF_TOKEN` as default.", + ) + + def _run_download_cmd(self, args: argparse.Namespace): + model_name = args.repo_id.split("/")[-1] + + os.makedirs(DEFAULT_OUTPUT_DIR, exist_ok=True) + output_dir = args.output_dir + if output_dir is None: + model_name = args.repo_id.split("/")[-1] + output_dir = Path(DEFAULT_OUTPUT_DIR) / model_name + + try: + true_output_dir = snapshot_download( + args.repo_id, + local_dir=output_dir, + # "auto" will download to cache_dir and symlink files to local_dir + # avoiding unnecessary duplicate copies + local_dir_use_symlinks="auto", + token=args.hf_token, + ) + except GatedRepoError: + self.parser.error( + "It looks like you are trying to access a gated repository. Please ensure you " + "have access to the repository and have provided the proper Hugging Face API token " + "using the option `--hf-token` or by running `huggingface-cli login`." + "You can find your token by visiting https://huggingface.co/settings/tokens" + ) + except RepositoryNotFoundError: + self.parser.error( + f"Repository '{args.repo_id}' not found on the Hugging Face Hub." + ) + except Exception as e: + self.parser.error(e) + + print(f"Successfully downloaded model to {true_output_dir}") diff --git a/toolchain/cli/inference/inference.py b/toolchain/cli/inference/inference.py new file mode 100644 index 000000000..d8cf7b1ba --- /dev/null +++ b/toolchain/cli/inference/inference.py @@ -0,0 +1,28 @@ +import argparse +import textwrap + +from toolchain.cli.inference.start import InferenceStart +from toolchain.cli.subcommand import Subcommand + + +class InferenceParser(Subcommand): + """Llama cli for inference apis""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "inference", + prog="llama inference", + description="Run inference on a llama model", + epilog=textwrap.dedent( + """ + Example: + llama inference start + """ + ), + ) + + subparsers = self.parser.add_subparsers(title="inference_subcommands") + + # Add sub-commandsa + InferenceStart.create(subparsers) diff --git a/toolchain/cli/inference/start.py b/toolchain/cli/inference/start.py new file mode 100644 index 000000000..ab447b644 --- /dev/null +++ b/toolchain/cli/inference/start.py @@ -0,0 +1,53 @@ +import argparse +import textwrap + +from toolchain.cli.subcommand import Subcommand + +from toolchain.inference.server import main as inference_server_init + + +class InferenceStart(Subcommand): + """Llama Inference cli for starting inference server""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "start", + prog="llama inference start", + description="Start an inference server", + epilog=textwrap.dedent( + """ + Example: + llama inference start + """ + ), + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_inference_start_cmd) + + def _add_arguments(self): + self.parser.add_argument( + "--port", + type=int, + help="Port to run the server on. Defaults to 5000", + default=5000, + ) + self.parser.add_argument( + "--disable-ipv6", + action="store_true", + help="Disable IPv6 support", + default=False, + ) + self.parser.add_argument( + "--config", + type=str, + help="Path to config file", + ) + + def _run_inference_start_cmd(self, args: argparse.Namespace) -> None: + inference_server_init( + config_path=args.config, + port=args.port, + disable_ipv6=args.disable_ipv6, + ) diff --git a/toolchain/cli/llama.py b/toolchain/cli/llama.py new file mode 100644 index 000000000..860451cff --- /dev/null +++ b/toolchain/cli/llama.py @@ -0,0 +1,40 @@ +import argparse + +from toolchain.cli.download import Download +from toolchain.cli.inference.inference import InferenceParser + + +class LlamaCLIParser: + """Defines CLI parser for Llama CLI""" + + def __init__(self): + self.parser = argparse.ArgumentParser( + prog="llama", + description="Welcome to the LLama toolchain cli", + add_help=True, + ) + + # Default command is to print help + self.parser.set_defaults(func=lambda args: self.parser.print_help()) + + subparsers = self.parser.add_subparsers(title="subcommands") + + # Add sub-commands + Download.create(subparsers) + InferenceParser.create(subparsers) + + def parse_args(self) -> argparse.Namespace: + return self.parser.parse_args() + + def run(self, args: argparse.Namespace) -> None: + args.func(args) + + +def main(): + parser = LlamaCLIParser() + args = parser.parse_args() + parser.run(args) + + +if __name__ == "__main__": + main() diff --git a/toolchain/cli/subcommand.py b/toolchain/cli/subcommand.py new file mode 100644 index 000000000..10bb6667d --- /dev/null +++ b/toolchain/cli/subcommand.py @@ -0,0 +1,12 @@ +class Subcommand: + """All llama cli subcommands must inherit from this class""" + + def __init__(self, *args, **kwargs): + pass + + @classmethod + def create(cls, *args, **kwargs): + return cls(*args, **kwargs) + + def _add_arguments(self): + pass diff --git a/toolchain/common/deployment_types.py b/toolchain/common/deployment_types.py new file mode 100644 index 000000000..88831f41c --- /dev/null +++ b/toolchain/common/deployment_types.py @@ -0,0 +1,25 @@ +from enum import Enum +from typing import Dict, Optional + +from models.llama3.datatypes import URL + +from pydantic import BaseModel + +from strong_typing.schema import json_schema_type + + +@json_schema_type +class RestAPIMethod(Enum): + GET = "GET" + POST = "POST" + PUT = "PUT" + DELETE = "DELETE" + + +@json_schema_type +class RestAPIExecutionConfig(BaseModel): + url: URL + method: RestAPIMethod + params: Optional[Dict[str, str]] = None + headers: Optional[Dict[str, str]] = None + body: Optional[Dict[str, str]] = None diff --git a/toolchain/common/training_types.py b/toolchain/common/training_types.py new file mode 100644 index 000000000..d5e756bba --- /dev/null +++ b/toolchain/common/training_types.py @@ -0,0 +1,7 @@ +from models.llama3.datatypes import URL +from pydantic import BaseModel + + +class Checkpoint(BaseModel): + iters: int + path: URL diff --git a/toolchain/configs/ashwin.yaml b/toolchain/configs/ashwin.yaml new file mode 100644 index 000000000..c2f1ca245 --- /dev/null +++ b/toolchain/configs/ashwin.yaml @@ -0,0 +1,9 @@ +model_inference_config: + impl_type: "inline" + inline_config: + checkpoint_type: "pytorch" + checkpoint_dir: /home/ashwin/local/checkpoints/Meta-Llama-3.1-8B-Instruct-20240710150000 + tokenizer_path: /home/ashwin/local/checkpoints/Meta-Llama-3.1-8B-Instruct-20240710150000/tokenizer.model + model_parallel_size: 1 + max_seq_len: 2048 + max_batch_size: 1 diff --git a/toolchain/configs/chrisluc.yaml b/toolchain/configs/chrisluc.yaml new file mode 100644 index 000000000..be51a534c --- /dev/null +++ b/toolchain/configs/chrisluc.yaml @@ -0,0 +1,9 @@ +model_inference_config: + impl_type: "inline" + inline_config: + checkpoint_type: "pytorch" + checkpoint_dir: /home/chrisluc/models/Meta-Llama-3.1-8B-Instruct-20240710150000 + tokenizer_path: /home/chrisluc/models/Meta-Llama-3.1-8B-Instruct-20240710150000/tokenizer.model + model_parallel_size: 1 + max_seq_len: 2048 + max_batch_size: 1 diff --git a/toolchain/configs/default.yaml b/toolchain/configs/default.yaml new file mode 100644 index 000000000..642a55f22 --- /dev/null +++ b/toolchain/configs/default.yaml @@ -0,0 +1,9 @@ +model_inference_config: + impl_type: "inline" + inline_config: + checkpoint_type: "pytorch" + checkpoint_dir: /home/dalton/models/Meta-Llama-3.1-8B-Instruct-20240710150000 + tokenizer_path: /home/dalton/models/Meta-Llama-3.1-8B-Instruct-20240710150000/tokenizer.model + model_parallel_size: 1 + max_seq_len: 2048 + max_batch_size: 1 diff --git a/toolchain/configs/hjshah.yaml b/toolchain/configs/hjshah.yaml new file mode 100644 index 000000000..98e2660ea --- /dev/null +++ b/toolchain/configs/hjshah.yaml @@ -0,0 +1,9 @@ +model_inference_config: + impl_type: "inline" + inline_config: + checkpoint_type: "pytorch" + checkpoint_dir: /home/hjshah/local/checkpoints/Meta-Llama-3.1-8B-Instruct-20240710150000 + tokenizer_path: /home/hjshah/local/checkpoints/Meta-Llama-3.1-8B-Instruct-20240710150000/tokenizer.model + model_parallel_size: 1 + max_seq_len: 2048 + max_batch_size: 1 diff --git a/toolchain/configs/long_seqlen.yaml b/toolchain/configs/long_seqlen.yaml new file mode 100644 index 000000000..e137d0273 --- /dev/null +++ b/toolchain/configs/long_seqlen.yaml @@ -0,0 +1,9 @@ +model_inference_config: + impl_type: "inline" + inline_config: + checkpoint_type: "pytorch" + checkpoint_dir: /home/hjshah/local/checkpoints/Meta-Llama-3.1-8B-Instruct-20240710150000 + tokenizer_path: /home/hjshah/local/checkpoints/Meta-Llama-3.1-8B-Instruct-20240710150000/tokenizer.model + model_parallel_size: 1 + max_seq_len: 8192 + max_batch_size: 1 diff --git a/toolchain/dataset/api/__init__.py b/toolchain/dataset/api/__init__.py new file mode 100644 index 000000000..38413ff60 --- /dev/null +++ b/toolchain/dataset/api/__init__.py @@ -0,0 +1,2 @@ +from .datatypes import * # noqa: F401 F403 +from .endpoints import * # noqa: F401 F403 diff --git a/toolchain/dataset/api/datatypes.py b/toolchain/dataset/api/datatypes.py new file mode 100644 index 000000000..4fed2a0a3 --- /dev/null +++ b/toolchain/dataset/api/datatypes.py @@ -0,0 +1,27 @@ +from enum import Enum +from typing import Any, Dict, Optional + +from pydantic import BaseModel + +from strong_typing.schema import json_schema_type +from models.llama3.datatypes import URL + + +@json_schema_type +class TrainEvalDatasetColumnType(Enum): + dialog = "dialog" + text = "text" + media = "media" + number = "number" + json = "json" + + +@json_schema_type +class TrainEvalDataset(BaseModel): + """Dataset to be used for training or evaluating language models.""" + + # TODO(ashwin): figure out if we need to add an enum for a "dataset type" + + columns: Dict[str, TrainEvalDatasetColumnType] + content_url: URL + metadata: Optional[Dict[str, Any]] = None diff --git a/toolchain/dataset/api/endpoints.py b/toolchain/dataset/api/endpoints.py new file mode 100644 index 000000000..023f91259 --- /dev/null +++ b/toolchain/dataset/api/endpoints.py @@ -0,0 +1,36 @@ +from typing import Protocol + +from pydantic import BaseModel + +from pyopenapi import webmethod +from strong_typing.schema import json_schema_type + +from .datatypes import * # noqa: F403 + + +@json_schema_type +class CreateDatasetRequest(BaseModel): + """Request to create a dataset.""" + + uuid: str + dataset: TrainEvalDataset + + +class Datasets(Protocol): + @webmethod(route="/datasets/create") + def create_dataset( + self, + request: CreateDatasetRequest, + ) -> None: ... + + @webmethod(route="/datasets/get") + def get_dataset( + self, + dataset_uuid: str, + ) -> TrainEvalDataset: ... + + @webmethod(route="/datasets/delete") + def delete_dataset( + self, + dataset_uuid: str, + ) -> None: ... diff --git a/toolchain/evaluations/api/__init__.py b/toolchain/evaluations/api/__init__.py new file mode 100644 index 000000000..38413ff60 --- /dev/null +++ b/toolchain/evaluations/api/__init__.py @@ -0,0 +1,2 @@ +from .datatypes import * # noqa: F401 F403 +from .endpoints import * # noqa: F401 F403 diff --git a/toolchain/evaluations/api/datatypes.py b/toolchain/evaluations/api/datatypes.py new file mode 100644 index 000000000..692664846 --- /dev/null +++ b/toolchain/evaluations/api/datatypes.py @@ -0,0 +1,29 @@ +from enum import Enum + +from pydantic import BaseModel + + +class TextGenerationMetric(Enum): + perplexity = "perplexity" + rouge = "rouge" + bleu = "bleu" + + +class QuestionAnsweringMetric(Enum): + em = "em" + f1 = "f1" + + +class SummarizationMetric(Enum): + rouge = "rouge" + bleu = "bleu" + + +class EvaluationJob(BaseModel): + + job_uuid: str + + +class EvaluationJobLogStream(BaseModel): + + job_uuid: str diff --git a/toolchain/evaluations/api/endpoints.py b/toolchain/evaluations/api/endpoints.py new file mode 100644 index 000000000..7fef5d9da --- /dev/null +++ b/toolchain/evaluations/api/endpoints.py @@ -0,0 +1,93 @@ +from typing import List, Protocol + +from pydantic import BaseModel + +from pyopenapi import webmethod + +from models.llama3.datatypes import * # noqa: F403 +from .datatypes import * # noqa: F403 +from toolchain.dataset.api.datatypes import * # noqa: F403 +from toolchain.common.training_types import * # noqa: F403 + + +class EvaluateTaskRequestCommon(BaseModel): + job_uuid: str + dataset: TrainEvalDataset + + checkpoint: Checkpoint + + # generation params + sampling_params: SamplingParams = SamplingParams() + + +@json_schema_type +class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon): + """Request to evaluate text generation.""" + + metrics: List[TextGenerationMetric] + + +@json_schema_type +class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon): + """Request to evaluate question answering.""" + + metrics: List[QuestionAnsweringMetric] + + +@json_schema_type +class EvaluateSummarizationRequest(EvaluateTaskRequestCommon): + """Request to evaluate summarization.""" + + metrics: List[SummarizationMetric] + + +class EvaluationJobStatusResponse(BaseModel): + + job_uuid: str + + +@json_schema_type +class EvaluationJobArtifactsResponse(BaseModel): + """Artifacts of a evaluation job.""" + + job_uuid: str + + +class Evaluations(Protocol): + @webmethod(route="/evaluate/text_generation/") + def post_evaluate_text_generation( + self, + request: EvaluateTextGenerationRequest, + ) -> EvaluationJob: ... + + @webmethod(route="/evaluate/question_answering/") + def post_evaluate_question_answering( + self, + request: EvaluateQuestionAnsweringRequest, + ) -> EvaluationJob: ... + + @webmethod(route="/evaluate/summarization/") + def post_evaluate_summarization( + self, + request: EvaluateSummarizationRequest, + ) -> EvaluationJob: ... + + @webmethod(route="/evaluate/jobs") + def get_evaluation_jobs(self) -> List[EvaluationJob]: ... + + @webmethod(route="/evaluate/job/status") + def get_evaluation_job_status( + self, job_uuid: str + ) -> EvaluationJobStatusResponse: ... + + # sends SSE stream of logs + @webmethod(route="/evaluate/job/logs") + def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ... + + @webmethod(route="/evaluate/job/cancel") + def cancel_evaluation_job(self, job_uuid: str) -> None: ... + + @webmethod(route="/evaluate/job/artifacts") + def get_evaluation_job_artifacts( + self, job_uuid: str + ) -> EvaluationJobArtifactsResponse: ... diff --git a/toolchain/inference/__init__.py b/toolchain/inference/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/toolchain/inference/api/__init__.py b/toolchain/inference/api/__init__.py new file mode 100644 index 000000000..38413ff60 --- /dev/null +++ b/toolchain/inference/api/__init__.py @@ -0,0 +1,2 @@ +from .datatypes import * # noqa: F401 F403 +from .endpoints import * # noqa: F401 F403 diff --git a/toolchain/inference/api/config.py b/toolchain/inference/api/config.py new file mode 100644 index 000000000..2340e2d32 --- /dev/null +++ b/toolchain/inference/api/config.py @@ -0,0 +1,146 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Literal, Optional, Union + +from hydra.core.config_store import ConfigStore + +from pydantic import BaseModel, Field +from typing_extensions import Annotated + + +@dataclass +class GeneratorArgs: + ckpt_dir: str + tokenizer_path: str + model_parallel_size: Optional[int] = None + max_seq_len: int = 2048 + max_batch_size: int = 4 + + +class ImplType(Enum): + inline = "inline" + remote = "remote" + + +class CheckpointType(Enum): + pytorch = "pytorch" + huggingface = "huggingface" + + +class PytorchCheckpoint(BaseModel): + checkpoint_type: Literal[CheckpointType.pytorch.value] = ( + CheckpointType.pytorch.value + ) + checkpoint_dir: str + tokenizer_path: str + model_parallel_size: int + + +class HuggingFaceCheckpoint(BaseModel): + checkpoint_type: Literal[CheckpointType.huggingface.value] = ( + CheckpointType.huggingface.value + ) + repo_id: str # or model_name ? + model_parallel_size: int + + +class ModelCheckpointConfig(BaseModel): + checkpoint: Annotated[ + Union[PytorchCheckpoint, HuggingFaceCheckpoint], + Field(discriminator="checkpoint_type"), + ] + + +# NOTE: this same config will be used when instantiating an inference server naturally +class InlineImplConfig(BaseModel): + impl_type: Literal[ImplType.inline.value] = ImplType.inline.value + checkpoint_config: ModelCheckpointConfig + max_seq_len: int + max_batch_size: int = 1 + + +class RemoteImplConfig(BaseModel): + impl_type: Literal[ImplType.remote.value] = ImplType.remote.value + url: str = Field(..., description="The URL of the remote module") + + +class ModelInferenceConfig(BaseModel): + impl_config: Annotated[ + Union[InlineImplConfig, RemoteImplConfig], + Field(discriminator="impl_type"), + ] + + +# Hydra does not like unions of containers and +# Pydantic does not like Literals +# Adding a simple dataclass with custom coversion +# to config classes + + +@dataclass +class InlineImplHydraConfig: + checkpoint_type: str # "pytorch" / "HF" + # pytorch checkpoint required args + checkpoint_dir: str + tokenizer_path: str + model_parallel_size: int + max_seq_len: int + max_batch_size: int = 1 + # TODO: huggingface checkpoint required args + + def convert_to_inline_impl_config(self): + if self.checkpoint_type == "pytorch": + return InlineImplConfig( + checkpoint_config=ModelCheckpointConfig( + checkpoint=PytorchCheckpoint( + checkpoint_type=CheckpointType.pytorch.value, + checkpoint_dir=self.checkpoint_dir, + tokenizer_path=self.tokenizer_path, + model_parallel_size=self.model_parallel_size, + ) + ), + max_seq_len=self.max_seq_len, + max_batch_size=self.max_batch_size, + ) + else: + raise NotImplementedError("HF Checkpoint not supported yet") + + +@dataclass +class RemoteImplHydraConfig: + url: str + + def convert_to_remote_impl_config(self): + return RemoteImplConfig( + url=self.url, + ) + + +@dataclass +class ModelInferenceHydraConfig: + impl_type: str + inline_config: Optional[InlineImplHydraConfig] = None + remote_config: Optional[RemoteImplHydraConfig] = None + + def __post_init__(self): + assert self.impl_type in ["inline", "remote"] + if self.impl_type == "inline": + assert self.inline_config is not None + if self.impl_type == "remote": + assert self.remote_config is not None + + def convert_to_model_inferene_config(self): + if self.impl_type == "inline": + inline_config = InlineImplHydraConfig(**self.inline_config) + return ModelInferenceConfig( + impl_config=inline_config.convert_to_inline_impl_config() + ) + elif self.impl_type == "remote": + remote_config = RemoteImplHydraConfig(**self.remote_config) + return ModelInferenceConfig( + impl_config=remote_config.convert_to_remote_impl_config() + ) + + +cs = ConfigStore.instance() +cs.store(name="model_inference_config", node=ModelInferenceHydraConfig) diff --git a/toolchain/inference/api/datatypes.py b/toolchain/inference/api/datatypes.py new file mode 100644 index 000000000..ef54d869d --- /dev/null +++ b/toolchain/inference/api/datatypes.py @@ -0,0 +1,68 @@ +from enum import Enum +from typing import List, Literal, Optional, Union + +from pydantic import BaseModel, Field + +from strong_typing.schema import json_schema_type +from typing_extensions import Annotated + +from models.llama3.datatypes import * # noqa: F403 + + +class LogProbConfig(BaseModel): + top_k: Optional[int] = 0 + + +@json_schema_type +class QuantizationType(Enum): + bf16 = "bf16" + fp8 = "fp8" + + +@json_schema_type +class Fp8QuantizationConfig(BaseModel): + quantization_type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value + + +@json_schema_type +class Bf16QuantizationConfig(BaseModel): + quantization_type: Literal[QuantizationType.bf16.value] = ( + QuantizationType.bf16.value + ) + + +QuantizationConfig = Annotated[ + Union[Bf16QuantizationConfig, Fp8QuantizationConfig], + Field(discriminator="quantization_type"), +] + + +@json_schema_type +class ChatCompletionResponseEventType(Enum): + start = "start" + complete = "complete" + progress = "progress" + + +@json_schema_type +class ToolCallParseStatus(Enum): + started = "started" + in_progress = "in_progress" + failure = "failure" + success = "success" + + +@json_schema_type +class ToolCallDelta(BaseModel): + content: Union[str, ToolCall] + parse_status: ToolCallParseStatus + + +@json_schema_type +class ChatCompletionResponseEvent(BaseModel): + """Chat completion response event.""" + + event_type: ChatCompletionResponseEventType + delta: Union[str, ToolCallDelta] + logprobs: Optional[List[TokenLogProbs]] = None + stop_reason: Optional[StopReason] = None diff --git a/toolchain/inference/api/endpoints.py b/toolchain/inference/api/endpoints.py new file mode 100644 index 000000000..5b262a99c --- /dev/null +++ b/toolchain/inference/api/endpoints.py @@ -0,0 +1,117 @@ +from .datatypes import * # noqa: F403 +from typing import Optional, Protocol + +# this dependency is annoying and we need a forked up version anyway +from pyopenapi import webmethod + + +@json_schema_type +class CompletionRequest(BaseModel): + model: PretrainedModel + content: InterleavedTextAttachment + sampling_params: Optional[SamplingParams] = SamplingParams() + + stream: Optional[bool] = False + logprobs: Optional[LogProbConfig] = None + quantization_config: Optional[QuantizationConfig] = None + + +@json_schema_type +class CompletionResponse(BaseModel): + completion_message: CompletionMessage + logprobs: Optional[List[TokenLogProbs]] = None + + +@json_schema_type +class CompletionResponseStreamChunk(BaseModel): + """streamed completion response.""" + + delta: str + stop_reason: Optional[StopReason] = None + logprobs: Optional[List[TokenLogProbs]] = None + + +@json_schema_type +class BatchCompletionRequest(BaseModel): + model: PretrainedModel + content_batch: List[InterleavedTextAttachment] + sampling_params: Optional[SamplingParams] = SamplingParams() + logprobs: Optional[LogProbConfig] = None + quantization_config: Optional[QuantizationConfig] = None + + +@json_schema_type +class BatchCompletionResponse(BaseModel): + completion_message_batch: List[CompletionMessage] + + +@json_schema_type +class ChatCompletionRequest(BaseModel): + model: InstructModel + messages: List[Message] + sampling_params: Optional[SamplingParams] = SamplingParams() + + # zero-shot tool definitions as input to the model + available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list) + + stream: Optional[bool] = False + logprobs: Optional[LogProbConfig] = None + quantization_config: Optional[QuantizationConfig] = None + + +@json_schema_type +class ChatCompletionResponseStreamChunk(BaseModel): + """SSE-stream of these events.""" + + event: ChatCompletionResponseEvent + + +@json_schema_type +class ChatCompletionResponse(BaseModel): + completion_message: CompletionMessage + logprobs: Optional[List[TokenLogProbs]] = None + + +@json_schema_type +class BatchChatCompletionRequest(BaseModel): + model: InstructModel + messages_batch: List[List[Message]] + sampling_params: Optional[SamplingParams] = SamplingParams() + + # zero-shot tool definitions as input to the model + available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list) + + logprobs: Optional[LogProbConfig] = None + quantization_config: Optional[QuantizationConfig] = None + + +@json_schema_type +class BatchChatCompletionResponse(BaseModel): + completion_message_batch: List[CompletionMessage] + + +class ModelInference(Protocol): + + @webmethod(route="/inference/completion") + async def completion( + self, + request: CompletionRequest, + ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ... + + @webmethod(route="/inference/chat_completion") + async def chat_completion( + self, + request: ChatCompletionRequest, + ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ... + + @webmethod(route="/inference/batch_completion") + async def batch_completion( + self, + request: BatchCompletionRequest, + ) -> List[CompletionResponse]: ... + + @webmethod(route="/inference/batch_chat_completion") + async def batch_chat_completion( + self, + request: BatchChatCompletionRequest, + ) -> List[ChatCompletionResponse]: ... diff --git a/toolchain/inference/api_instance.py b/toolchain/inference/api_instance.py new file mode 100644 index 000000000..6110fd257 --- /dev/null +++ b/toolchain/inference/api_instance.py @@ -0,0 +1,12 @@ +from .api.config import ImplType, ModelInferenceConfig + + +async def get_inference_api_instance(config: ModelInferenceConfig): + if config.impl_config.impl_type == ImplType.inline.value: + from .inference import ModelInferenceImpl + + return ModelInferenceImpl(config.impl_config) + + from .client import ModelInferenceClient + + return ModelInferenceClient(config.impl_config.url) diff --git a/toolchain/inference/client.py b/toolchain/inference/client.py new file mode 100644 index 000000000..0cb14e4c7 --- /dev/null +++ b/toolchain/inference/client.py @@ -0,0 +1,73 @@ +import asyncio +import json +from typing import AsyncGenerator + +import fire +import httpx + +from .api.endpoints import ( + ChatCompletionRequest, + ChatCompletionResponseStreamChunk, + CompletionRequest, + InstructModel, + ModelInference, +) + + +class ModelInferenceClient(ModelInference): + def __init__(self, base_url: str): + self.base_url = base_url + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def completion(self, request: CompletionRequest) -> AsyncGenerator: + raise NotImplementedError() + + async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: + async with httpx.AsyncClient() as client: + async with client.stream( + "POST", + f"{self.base_url}/inference/chat_completion", + data=request.json(), + headers={"Content-Type": "application/json"}, + timeout=20, + ) as response: + async for line in response.aiter_lines(): + if line.startswith("data:"): + data = line[len("data: ") :] + try: + yield ChatCompletionResponseStreamChunk(**json.loads(data)) + except Exception as e: + print(data) + print(f"Error with parsing or validation: {e}") + + +async def run_main(host: str, port: int): + client = ModelInferenceClient(f"http://{host}:{port}") + + message = UserMessage(content="hello world, help me out here") + req = ChatCompletionRequest( + model=InstructModel.llama3_70b_chat, + messages=[message], + stream=True, + ) + async for event in client.chat_completion( + ChatCompletionRequest( + model=InstructModel.llama3_70b_chat, + messages=[message], + stream=True, + ) + ): + print(event) + + +def main(host: str, port: int): + asyncio.run(run_main(host, port)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/toolchain/inference/generation.py b/toolchain/inference/generation.py new file mode 100644 index 000000000..1a067071b --- /dev/null +++ b/toolchain/inference/generation.py @@ -0,0 +1,298 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import json +import os +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Generator, List, Optional, TypedDict + +import torch +import torch.nn.functional as F +from fairscale.nn.model_parallel.initialize import ( + get_model_parallel_rank, + initialize_model_parallel, + model_parallel_is_initialized, +) +from models.llama3.args import ModelArgs +from models.llama3.chat_format import ChatFormat, ModelInput +from models.llama3.datatypes import Message +from models.llama3.model import Transformer +from models.llama3.tokenizer import Tokenizer +from termcolor import cprint + + +@dataclass +class TokenResult: + token: int + text: str + logprobs: Optional[List[float]] = None + + +class CompletionPrediction(TypedDict, total=False): + generation: str + tokens: List[str] # not required + logprobs: List[float] # not required + + +class Llama: + @staticmethod + def build( + ckpt_dir: str, + tokenizer_path: str, + max_seq_len: int, + max_batch_size: int, + model_parallel_size: Optional[int] = None, + seed: int = 1, + ) -> "Llama": + """ + Build a Llama instance by initializing and loading a model checkpoint. + + Args: + ckpt_dir (str): Path to the directory containing checkpoint files. + tokenizer_path (str): Path to the tokenizer file. + max_seq_len (int): Maximum sequence length for input text. + max_batch_size (int): Maximum batch size for inference. + model_parallel_size (Optional[int], optional): Number of model parallel processes. + If not provided, it's determined from the environment. Defaults to None. + + Returns: + Llama: An instance of the Llama class with the loaded model and tokenizer. + + Raises: + AssertionError: If there are no checkpoint files in the specified directory, + or if the model parallel size does not match the number of checkpoint files. + + Note: + This method initializes the distributed process group, sets the device to CUDA, + and loads the pre-trained model and tokenizer. + """ + + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group("nccl") + if not model_parallel_is_initialized(): + if model_parallel_size is None: + model_parallel_size = int(os.environ.get("WORLD_SIZE", 1)) + initialize_model_parallel(model_parallel_size) + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + + # seed must be the same in all processes + torch.manual_seed(seed) + + if local_rank > 0: + sys.stdout = open(os.devnull, "w") + + start_time = time.time() + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" + assert model_parallel_size == len( + checkpoints + ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" + ckpt_path = checkpoints[get_model_parallel_rank()] + checkpoint = torch.load(ckpt_path, map_location="cpu") + with open(Path(ckpt_dir) / "params.json", "r") as f: + params = json.loads(f.read()) + + # TODO(ashwin): this block is so we can load internal checkpoints without additional + # fuss. the final code should _not_ have this blurb + if "model" in params: + params = params["model"] + + model_args: ModelArgs = ModelArgs( + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + **params, + ) + tokenizer = Tokenizer(model_path=tokenizer_path) + + assert ( + model_args.vocab_size == tokenizer.n_words + ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}" + if torch.cuda.is_bf16_supported(): + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) + else: + torch.set_default_tensor_type(torch.cuda.HalfTensor) + + model = Transformer(model_args) + model.load_state_dict(checkpoint, strict=False) + print(f"Loaded in {time.time() - start_time:.2f} seconds") + + return Llama(model, tokenizer, model_args) + + def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs): + self.args = args + self.model = model + self.tokenizer = tokenizer + self.formatter = ChatFormat(tokenizer) + + @torch.inference_mode() + def generate( + self, + model_input: ModelInput, + max_gen_len: int, + temperature: float = 0.6, + top_p: float = 0.9, + logprobs: bool = False, + echo: bool = False, + include_stop_token: bool = False, + ) -> Generator: + params = self.model.params + + # cprint("Input to model -> " + self.tokenizer.decode(model_input.tokens), "red") + prompt_tokens = [model_input.tokens] + + bsz = 1 + assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + + min_prompt_len = min(len(t) for t in prompt_tokens) + max_prompt_len = max(len(t) for t in prompt_tokens) + + if max_prompt_len >= params.max_seq_len: + cprint( + f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red" + ) + return + + total_len = min(max_gen_len + max_prompt_len, params.max_seq_len) + pad_id = self.tokenizer.pad_id + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + if logprobs: + token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + + prev_pos = 0 + eos_reached = torch.tensor([False] * bsz, device="cuda") + input_text_mask = tokens != pad_id + if min_prompt_len == total_len: + # TODO(ashwin): unify this branch with the one below and figure out multimodal crap + logits = self.model.forward(tokens, prev_pos) + token_logprobs = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens, + reduction="none", + ignore_index=pad_id, + ) + + stop_tokens = torch.tensor(self.tokenizer.stop_tokens) + + for cur_pos in range(min_prompt_len, total_len): + logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + + if temperature > 0: + probs = torch.softmax(logits[:, -1] / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits[:, -1], dim=-1) + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where( + input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token + ) + tokens[:, cur_pos] = next_token + + target = tokens[:, prev_pos + 1 : cur_pos + 1] + if logprobs: + token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens[:, prev_pos + 1 : cur_pos + 1], + reduction="none", + ignore_index=pad_id, + ) + eos_reached |= (~input_text_mask[:, cur_pos]) & ( + torch.isin(next_token, stop_tokens) + ) + yield TokenResult( + token=next_token[0].item(), + text=self.tokenizer.decode(next_token.tolist()), + logprobs=( + token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist() + if logprobs + else None + ), + ) + + prev_pos = cur_pos + if all(eos_reached): + break + + def text_completion( + self, + prompt: str, + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + echo: bool = False, + ) -> Generator: + if ( + max_gen_len is None + or max_gen_len == 0 + or max_gen_len >= self.model.params.max_seq_len + ): + max_gen_len = self.model.params.max_seq_len - 1 + + prompt_tokens = self.tokenizer.encode(x, bos=True, eos=False) + + yield from self.generate( + model_input=ModelInput(tokens=prompt_tokens), + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + echo=echo, + ) + + def chat_completion( + self, + messages: List[Message], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + ) -> Generator: + if ( + max_gen_len is None + or max_gen_len == 0 + or max_gen_len >= self.model.params.max_seq_len + ): + max_gen_len = self.model.params.max_seq_len - 1 + + yield from self.generate( + model_input=self.formatter.encode_dialog_prompt(messages), + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + include_stop_token=True, + ) + + +def sample_top_p(probs, p): + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token diff --git a/toolchain/inference/inference.py b/toolchain/inference/inference.py new file mode 100644 index 000000000..7ef32e93b --- /dev/null +++ b/toolchain/inference/inference.py @@ -0,0 +1,173 @@ +from typing import AsyncGenerator + +from models.llama3.datatypes import StopReason + +from .api.config import CheckpointType, GeneratorArgs, InlineImplConfig +from .api.datatypes import ( + ChatCompletionResponseEvent, + ChatCompletionResponseEventType, + ToolCallDelta, + ToolCallParseStatus, +) +from .api.endpoints import ( + ChatCompletionRequest, + ChatCompletionResponseStreamChunk, + CompletionRequest, + ModelInference, +) +from .model_parallel import LlamaModelParallelGenerator + + +def generator_args_from_config(config: InlineImplConfig) -> GeneratorArgs: + if ( + config.checkpoint_config.checkpoint.checkpoint_type + == CheckpointType.pytorch.value + ): + pt_checkpoint = config.checkpoint_config.checkpoint + return GeneratorArgs( + ckpt_dir=pt_checkpoint.checkpoint_dir, + tokenizer_path=pt_checkpoint.tokenizer_path, + model_parallel_size=pt_checkpoint.model_parallel_size, + max_seq_len=config.max_seq_len, + max_batch_size=config.max_batch_size, + ) + else: + raise NotImplementedError("HF Checkpoint not supported yet") + + +class ModelInferenceImpl(ModelInference): + + def __init__(self, config: InlineImplConfig) -> None: + self.config = config + + async def initialize(self) -> None: + generator_args = generator_args_from_config(self.config) + self.generator = LlamaModelParallelGenerator( + args=generator_args, + ) + self.generator.start() + + async def shutdown(self) -> None: + self.generator.stop() + + async def completion(self, request: CompletionRequest) -> AsyncGenerator: + raise NotImplementedError() + + async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.start, + delta="", + ) + ) + + tokens = [] + logprobs = [] + + stop_reason = None + + buffer = "" + ipython = False + + for token_result in self.generator.chat_completion( + messages=request.messages, + temperature=request.sampling_params.temperature, + top_p=request.sampling_params.top_p, + max_gen_len=request.sampling_params.max_tokens, + logprobs=request.logprobs, + ): + buffer += token_result.text + tokens.append(token_result.token) + + if not ipython and buffer.startswith("<|python_tag|>"): + ipython = True + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content="", + parse_status=ToolCallParseStatus.started, + ), + ) + ) + buffer = buffer[len("<|python_tag|>") :] + continue + + if not request.stream: + if request.logprobs: + logprobs.append(token_result.logprob) + + continue + + if token_result.text == "<|eot_id|>": + stop_reason = StopReason.end_of_turn + text = "" + elif token_result.text == "<|eom_id|>": + stop_reason = StopReason.end_of_message + text = "" + else: + text = token_result.text + + if ipython: + delta = ToolCallDelta( + content=text, + parse_status=ToolCallParseStatus.in_progress, + ) + else: + delta = text + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=delta, + stop_reason=stop_reason, + ) + ) + + if stop_reason is None: + stop_reason = StopReason.out_of_tokens + + # TODO(ashwin): parse tool calls separately here and report errors? + # if someone breaks the iteration before coming here we are toast + message = self.generator.formatter.decode_assistant_message(tokens, stop_reason) + if request.stream: + parsed_tool_calls = len(message.tool_calls) > 0 + if ipython and not parsed_tool_calls: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content="", + parse_status=ToolCallParseStatus.failure, + ), + stop_reason=stop_reason, + ) + ) + + for tool_call in message.tool_calls: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content=tool_call, + parse_status=ToolCallParseStatus.success, + ), + stop_reason=stop_reason, + ) + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta="", + stop_reason=stop_reason, + ) + ) + + # TODO(ashwin): what else do we need to send out here when everything finishes? + else: + yield ChatCompletionResponse( + content=message.content, + tool_calls=message.tool_calls, + stop_reason=stop_reason, + logprobs=logprobs if request.logprobs else None, + ) diff --git a/toolchain/inference/model_parallel.py b/toolchain/inference/model_parallel.py new file mode 100644 index 000000000..8be991681 --- /dev/null +++ b/toolchain/inference/model_parallel.py @@ -0,0 +1,100 @@ +from dataclasses import dataclass +from functools import partial +from typing import Generator, List, Optional + +from models.llama3.chat_format import ChatFormat +from models.llama3.datatypes import Message +from models.llama3.tokenizer import Tokenizer + +from .api.config import GeneratorArgs +from .generation import Llama +from .parallel_utils import ModelParallelProcessGroup + + +@dataclass +class InferenceArgs: + messages: List[Message] + temperature: float + top_p: float + max_gen_len: int + logprobs: bool + + +class ModelRunner: + def __init__(self, llama): + self.llama = llama + + # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()` + def __call__(self, task: InferenceArgs): + return self.llama.chat_completion( + task.messages, + task.temperature, + task.top_p, + task.max_gen_len, + task.logprobs, + ) + + +def init_model_cb(args: GeneratorArgs): + llama = Llama.build( + args.ckpt_dir, + args.tokenizer_path, + args.max_seq_len, + args.max_batch_size, + ) + return ModelRunner(llama) + + +class LlamaModelParallelGenerator: + """ + This abstraction exists so + - we can run model parallel code without needing to run the CLIs via torchrun + - this also enables use model parallel code within a notebook context. + + A Context Manager is used to ensure that the model parallel process is started and stopped + correctly. This does make the ergonomics a little awkward, because it isn't immediately + clear at the callsite why we need to use a context manager. + """ + + def __init__(self, args: GeneratorArgs): + self.args = args + + # this is a hack because Agent's loop uses this to tokenize and check if input is too long + # while the tool-use loop is going + self.formatter = ChatFormat(Tokenizer(self.args.tokenizer_path)) + + def start(self): + self.__enter__() + + def stop(self): + self.__exit__(None, None, None) + + def __enter__(self): + self.group = ModelParallelProcessGroup( + self.args.model_parallel_size, + init_model_cb=partial(init_model_cb, self.args), + ) + self.group.start() + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.group.stop() + + def chat_completion( + self, + messages: List[Message], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + ) -> Generator: + req_obj = InferenceArgs( + messages=messages, + temperature=temperature, + top_p=top_p, + max_gen_len=max_gen_len, + logprobs=logprobs, + ) + + gen = self.group.run_inference(req_obj) + yield from gen diff --git a/toolchain/inference/parallel_utils.py b/toolchain/inference/parallel_utils.py new file mode 100644 index 000000000..daa061792 --- /dev/null +++ b/toolchain/inference/parallel_utils.py @@ -0,0 +1,259 @@ +import multiprocessing +import os +import pickle +import tempfile +import time +import uuid + +from typing import Callable, Generator + +import torch + +import zmq + +from fairscale.nn.model_parallel.initialize import ( + get_model_parallel_group, + get_model_parallel_rank, + get_model_parallel_src_rank, +) + +from torch.distributed.launcher.api import elastic_launch, LaunchConfig + + +_END_SENTINEL = "__end_sentinel__" +_CANCEL_SENTINEL = "__cancel_sentinel__" + + +def mp_rank_0() -> bool: + return get_model_parallel_rank() == 0 + + +def retrieve_requests(reply_socket_url: str): + if mp_rank_0(): + context = zmq.Context() + reply_socket = context.socket(zmq.ROUTER) + reply_socket.connect(reply_socket_url) + + while True: + client_id, obj = maybe_get_work(reply_socket) + if obj is None: + time.sleep(0.01) + continue + + reply_socket.send_multipart([client_id, pickle.dumps("YES READY")]) + break + + def send_obj(obj): + reply_socket.send_multipart([client_id, pickle.dumps(obj)]) + + while True: + tasks = [None] + if mp_rank_0(): + client_id, task = maybe_get_work(reply_socket) + # there is still an unknown unclean GeneratorExit happening resulting in a + # cancel sentinel getting queued _after_ we have finished sending everything :/ + # kind of a hack this is :/ + if task != _CANCEL_SENTINEL: + tasks = [task] + + torch.distributed.broadcast_object_list( + tasks, + src=get_model_parallel_src_rank(), + group=get_model_parallel_group(), + ) + + task = tasks[0] + if task is None: + time.sleep(0.1) + else: + try: + out = yield task + if out is None: + break + + for obj in out: + updates = [None] + if mp_rank_0(): + _, update = maybe_get_work(reply_socket) + if update == _CANCEL_SENTINEL: + updates = [update] + else: + # only send the update if it's not cancelled otherwise the object sits in the socket + # and gets pulled in the next request lol + send_obj(obj) + + torch.distributed.broadcast_object_list( + updates, + src=get_model_parallel_src_rank(), + group=get_model_parallel_group(), + ) + if updates[0] == _CANCEL_SENTINEL: + print("quitting generation loop because request was cancelled") + break + + if mp_rank_0(): + send_obj(_END_SENTINEL) + except Exception as e: + print(f"[debug] got exception {e}") + import traceback + + traceback.print_exc() + if mp_rank_0(): + send_obj(e) + + if mp_rank_0(): + send_obj("DONE") + + +def maybe_get_work(sock: zmq.Socket): + message = None + client_id = None + try: + client_id, obj = sock.recv_multipart(zmq.NOBLOCK) + message = pickle.loads(obj) + except zmq.ZMQError as e: + if e.errno != zmq.EAGAIN: + raise e + + return client_id, message + + +def worker_process_entrypoint( + reply_socket_url: str, + init_model_cb: Callable, +) -> None: + model = init_model_cb() + torch.distributed.barrier() + time.sleep(1) + + # run the requests co-routine which retrieves requests from the socket + # and sends responses (we provide) back to the caller + req_gen = retrieve_requests(reply_socket_url) + result = None + while True: + try: + task = req_gen.send(result) + if isinstance(task, str) and task == _END_SENTINEL: + break + + result = model(task) + except StopIteration: + break + + print("[debug] worker process done") + + +def launch_dist_group( + reply_socket_url: str, + model_parallel_size: int, + init_model_cb: Callable, + **kwargs, +) -> None: + id = uuid.uuid4().hex + dist_url = f"file:///tmp/llama3_{id}_{time.time()}" + + with tempfile.TemporaryDirectory() as tmpdir: + # TODO: track workers and if they terminate, tell parent process about it so cleanup can happen + launch_config = LaunchConfig( + max_nodes=1, + min_nodes=1, + nproc_per_node=model_parallel_size, + start_method="fork", + rdzv_backend="c10d", + rdzv_endpoint=os.path.join(tmpdir, "rdzv"), + rdzv_configs={"store_type": "file", "timeout": 90}, + max_restarts=0, + monitor_interval=1, + run_id=str(uuid.uuid4()), + ) + elastic_launch(launch_config, entrypoint=worker_process_entrypoint)( + reply_socket_url, + init_model_cb, + ) + + +def start_model_parallel_process( + model_parallel_size: int, + init_model_cb: Callable, + **kwargs, +): + context = zmq.Context() + request_socket = context.socket(zmq.DEALER) + + # Binding the request socket to a random port + request_socket.bind("tcp://127.0.0.1:0") + + main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT) + + ctx = multiprocessing.get_context("fork") + process = ctx.Process( + target=launch_dist_group, + args=( + main_process_url, + model_parallel_size, + init_model_cb, + ), + kwargs=kwargs, + ) + process.start() + + # wait until the model is loaded; rank 0 will send a message to indicate it's ready + + request_socket.send_pyobj("READY?") + response = request_socket.recv_pyobj() + print(f"Finished model load {response}") + + return request_socket, process + + +class ModelParallelProcessGroup: + def __init__( + self, + model_parallel_size: int, + init_model_cb: Callable, + **kwargs, + ): + self.model_parallel_size = model_parallel_size + self.init_model_cb = init_model_cb + self.started = False + self.running = False + + def start(self): + assert not self.started, "process group already started" + self.request_socket, self.process = start_model_parallel_process( + self.model_parallel_size, + self.init_model_cb, + ) + self.started = True + + def stop(self): + assert self.started, "process group not started" + if self.process.is_alive(): + self.request_socket.send_pyobj(_END_SENTINEL, zmq.NOBLOCK) + self.process.join() + self.started = False + + def run_inference(self, request) -> Generator: + assert not self.running, "inference already running" + + self.running = True + self.request_socket.send_pyobj(request) + try: + while True: + obj = self.request_socket.recv_pyobj() + if obj == _END_SENTINEL: + break + + if isinstance(obj, Exception): + print(f"[debug] got exception {obj}") + raise obj + + yield obj + except GeneratorExit as e: + self.request_socket.send_pyobj(_CANCEL_SENTINEL) + while True: + obj = self.request_socket.recv_pyobj() + if obj == _END_SENTINEL: + break + finally: + self.running = False diff --git a/toolchain/inference/quantization/build_conda.sh b/toolchain/inference/quantization/build_conda.sh new file mode 100644 index 000000000..624f6e831 --- /dev/null +++ b/toolchain/inference/quantization/build_conda.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +if [[ $# -ne 1 ]]; then + echo "Error: Please provide the name of CONDA environment you wish to create" + exit 1 +fi + +ENV_NAME=$1 + +set -eu +eval "$(conda shell.bash hook)" + +echo "Will build env (or overwrite) named '$ENV_NAME'" + +set -x + +run_build() { + # Set CUDA 9.0a targets + export CUDA_ARCH_LIST="8.0;9.0a" + export NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" + export TORCH_CUDA_ARCH_LIST=$CUDA_ARCH_LIST + + # Set up the conda environment + yes | conda remove --name $ENV_NAME --all + yes | conda create -n $ENV_NAME python=3.10 + conda activate $ENV_NAME + yes | conda install --channel "nvidia/label/cuda-12.1.0" cuda + yes | conda install cuda-nvtx cuda-nvtx-dev conda-forge::nccl + + + # ############# Hack to get CUDA path ############# + ln -s $CONDA_PREFIX/targets/x86_64-linux/include/* $CONDA_PREFIX/include/ || true + export CUDA_HOME=$CONDA_PREFIX + export CUDA_BIN_PATH=$CUDA_HOME + # ################################################# + + # PT nightly + pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 + pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 + + # install dependencies for `llama-agentic-system` + pip install -r fp8_requirements.txt +} + +run_build diff --git a/toolchain/inference/quantization/fp8_impls.py b/toolchain/inference/quantization/fp8_impls.py new file mode 100644 index 000000000..095039b24 --- /dev/null +++ b/toolchain/inference/quantization/fp8_impls.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import collections +from enum import Enum, unique +from typing import Optional, Type + +try: + import fbgemm_gpu.experimental.gen_ai # noqa: F401 + + print("Using efficient FP8 operators in FBGEMM.") +except (ImportError, ModuleNotFoundError): + print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.") + +import torch +from torch import nn, Tensor + + +@unique +class FfnQuantizeMode(Enum): + FP8_ROWWISE = "fp8_rowwise" + NONE = "none" + + def __str__(self) -> str: + return self.value + + +class Fp8ScaledWeights: + # TODO: Ugly trick so torch allows us to replace parameters + # with our custom Fp8Weights instance. Do this properly. + @property + def __class__(self) -> Type[nn.parameter.Parameter]: + return nn.Parameter + + @property + def grad_fn(self) -> None: + return None + + +# pyre-fixme[4]: Attribute annotation cannot be `Any`. +# pyre-fixme[2]: Parameter annotation cannot be `Any`. +class Fp8RowwiseWeights( + Fp8ScaledWeights, + collections.namedtuple( + "Fp8RowwiseWeights", + ["weight", "scale", "shape", "activation_scale_ub"], + ), +): + pass + + +def ffn_swiglu( + x: Tensor, + w1: Fp8RowwiseWeights, + w3: Fp8RowwiseWeights, + w2: Fp8RowwiseWeights, + num_tokens: Optional[Tensor] = None, + is_memory_bounded: bool = False, +) -> Tensor: + if ( + isinstance(w1, Fp8ScaledWeights) + and isinstance(w3, Fp8ScaledWeights) + and isinstance(w2, Fp8ScaledWeights) + ): + return ffn_swiglu_fp8_dynamic( + x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded + ) + + (B, T, D) = x.shape + (HD_L, D_) = w1.shape + assert D_ == D + + assert isinstance(w1, Tensor) + assert isinstance(w3, Tensor) + x1 = x.view(B * T, D) @ w1.T + x2 = x.view(B * T, D) @ w3.T + z = torch.nn.functional.silu(x1) * x2 + del x1, x2 + assert isinstance(w2, Tensor) + return (z @ w2.T).view(B, T, D) + + +@torch.inference_mode() +def quantize_fp8( + w: Tensor, + fp8_activation_scale_ub: float, + mode: Optional[FfnQuantizeMode] = None, + output_device: Optional[torch.device] = None, +) -> Fp8RowwiseWeights: + """Quantize [n, k] weight tensor. + + Args: + w (Tensor): [n, k] input high precision tensor to quantize. + fp8_activation_scale_ub (float): Upper bound for activation max. + mode (FfnQuantizeMode): Quantization mode. + """ + activation_scale_ub = torch.tensor( + [fp8_activation_scale_ub], + dtype=torch.float, + device="cuda", + ) + if mode is not None and mode == FfnQuantizeMode.FP8_ROWWISE: # rowwise + wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w) + del w + return Fp8RowwiseWeights( + weight=wq, + scale=w_scale, + shape=wq.shape, + activation_scale_ub=activation_scale_ub, + ) + + +def fc_fp8_dynamic( + x: Tensor, + w: Fp8RowwiseWeights, + activation_scale_ub: Optional[Tensor] = None, + num_tokens: Optional[Tensor] = None, + is_memory_bounded: bool = False, +) -> Tensor: + """ + Single w8a8 fc layer with dynamic row-wise scaling. + """ + if isinstance(w, Fp8RowwiseWeights): + xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( + x, num_tokens, activation_scale_ub + ) + y = torch.ops.fbgemm.f8f8bf16_rowwise( + xq, w.weight, x_scale, w.scale, use_fast_accum=True + ) + del xq + return y + + +def ffn_swiglu_fp8_dynamic( + x: Tensor, + w1: Fp8RowwiseWeights, + w3: Fp8RowwiseWeights, + w2: Fp8RowwiseWeights, + activation_scale_ub: Optional[Tensor] = None, + num_tokens: Optional[Tensor] = None, + is_memory_bounded: bool = False, +) -> Tensor: + (B, T, D) = x.shape + HD_L = w1.shape[0] + assert HD_L == w3.shape[0] + x1 = fc_fp8_dynamic( + x.view(B * T, D), + w1, + activation_scale_ub, + num_tokens, + is_memory_bounded, + ) + x2 = fc_fp8_dynamic( + x.view(B * T, D), + w3, + activation_scale_ub, + num_tokens, + is_memory_bounded, + ) + z = torch.nn.functional.silu(x1) * x2 + del x1, x2 + + z_ = fc_fp8_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded) + + return z_.view(B, T, D) diff --git a/toolchain/inference/quantization/fp8_requirements.txt b/toolchain/inference/quantization/fp8_requirements.txt new file mode 100644 index 000000000..dfae3b092 --- /dev/null +++ b/toolchain/inference/quantization/fp8_requirements.txt @@ -0,0 +1,5 @@ +fairscale +fire +tiktoken +blobfile +fbgemm-gpu==0.8.0rc4 diff --git a/toolchain/inference/quantization/generation.py b/toolchain/inference/quantization/generation.py new file mode 100644 index 000000000..ed70485d9 --- /dev/null +++ b/toolchain/inference/quantization/generation.py @@ -0,0 +1,455 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import json +import os +import sys +import time +from pathlib import Path +from typing import List, Optional, Tuple, TypedDict + +import torch +import torch.nn.functional as F +from fairscale.nn.model_parallel.initialize import ( + get_model_parallel_rank, + initialize_model_parallel, + model_parallel_is_initialized, +) +from fp8.fp8_impls import ( + FfnQuantizeMode, + Fp8ScaledWeights, + load_fp8, + ModelLoadMode, + quantize_fp8, +) + +from llama.model import ModelArgs, Transformer, TransformerBlock +from llama.tokenizer import ChatFormat, Dialog, Message, ModelInput, Tokenizer + + +class CompletionPrediction(TypedDict, total=False): + generation: str + tokens: List[str] # not required + logprobs: List[float] # not required + + +class ChatPrediction(TypedDict, total=False): + generation: Message + tokens: List[str] # not required + logprobs: List[float] # not required + + +class Llama: + @staticmethod + def build( + ckpt_dir: str, + tokenizer_path: str, + max_seq_len: int, + max_batch_size: int, + model_parallel_size: Optional[int] = None, + ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.NONE, + model_load_mode: Optional[ModelLoadMode] = ModelLoadMode.BF16, + fp8_activation_scale_ub: Optional[float] = 1200.0, + seed: int = 1, + ) -> "Llama": + """ + Build a Llama instance by initializing and loading a model checkpoint. + + Args: + ckpt_dir (str): Path to the directory containing checkpoint files. + tokenizer_path (str): Path to the tokenizer file. + max_seq_len (int): Maximum sequence length for input text. + max_batch_size (int): Maximum batch size for inference. + model_parallel_size (Optional[int], optional): Number of model parallel processes. + If not provided, it's determined from the environment. Defaults to None. + + Returns: + Llama: An instance of the Llama class with the loaded model and tokenizer. + + Raises: + AssertionError: If there are no checkpoint files in the specified directory, + or if the model parallel size does not match the number of checkpoint files. + + Note: + This method initializes the distributed process group, sets the device to CUDA, + and loads the pre-trained model and tokenizer. + """ + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group("nccl") + if not model_parallel_is_initialized(): + if model_parallel_size is None: + model_parallel_size = int(os.environ.get("WORLD_SIZE", 1)) + initialize_model_parallel(model_parallel_size) + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + + # seed must be the same in all processes + torch.manual_seed(seed) + + if local_rank > 0: + sys.stdout = open(os.devnull, "w") + + start_time = time.time() + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" + assert model_parallel_size == len( + checkpoints + ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" + ckpt_path = checkpoints[get_model_parallel_rank()] + checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) + with open(Path(ckpt_dir) / "params.json", "r") as f: + params = json.loads(f.read()) + + model_args: ModelArgs = ModelArgs( + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + **params, + ) + tokenizer = Tokenizer(model_path=tokenizer_path) + assert ( + model_args.vocab_size == tokenizer.n_words + ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}" + + # load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype + torch.set_default_tensor_type(torch.BFloat16Tensor) + + model = Transformer(model_args) + model.load_state_dict(checkpoint, strict=False) + + if torch.cuda.is_bf16_supported(): + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) + else: + torch.set_default_tensor_type(torch.cuda.HalfTensor) + + print("ffn_quantize_mode: ", ffn_quantize_mode) + if ffn_quantize_mode == FfnQuantizeMode.FP8_ROWWISE: + # Move weights to GPU with quantization + if model_load_mode == ModelLoadMode.FP8: + fp8_scales_path = os.path.join( + ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt" + ) + assert os.path.isfile( + fp8_scales_path + ), f"fp8_scales_path not found for rank {get_model_parallel_rank()}" + fp8_scales = torch.load(fp8_scales_path, weights_only=True) + + for block in model.layers: + if isinstance(block, TransformerBlock): + if block.layer_id == 0 or block.layer_id == ( + model.n_layers - 1 + ): + continue + block.feed_forward.w1.weight = load_fp8( + block.feed_forward.w1.weight, + fp8_scales[ + f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}" + ], + fp8_activation_scale_ub, + ) + block.feed_forward.w3.weight = load_fp8( + block.feed_forward.w3.weight, + fp8_scales[ + f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}" + ], + fp8_activation_scale_ub, + ) + block.feed_forward.w2.weight = load_fp8( + block.feed_forward.w2.weight, + fp8_scales[ + f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}" + ], + fp8_activation_scale_ub, + ) + else: + for block in model.layers: + if isinstance(block, TransformerBlock): + if block.layer_id == 0 or block.layer_id == ( + model.n_layers - 1 + ): + continue + block.feed_forward.w1.weight = quantize_fp8( + block.feed_forward.w1.weight, + fp8_activation_scale_ub, + ffn_quantize_mode, + output_device=torch.device("cuda"), + ) + block.feed_forward.w3.weight = quantize_fp8( + block.feed_forward.w3.weight, + fp8_activation_scale_ub, + ffn_quantize_mode, + output_device=torch.device("cuda"), + ) + block.feed_forward.w2.weight = quantize_fp8( + block.feed_forward.w2.weight, + fp8_activation_scale_ub, + ffn_quantize_mode, + output_device=torch.device("cuda"), + ) + + for _, parameter in model.named_parameters(): + if not isinstance(parameter, Fp8ScaledWeights): + parameter.data = parameter.to(device="cuda") + else: + for _, parameter in model.named_parameters(): + parameter.data = parameter.to(device="cuda") + + print(f"Loaded in {time.time() - start_time:.2f} seconds") + + return Llama(model, tokenizer, model_args) + + def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs): + self.args = args + self.model = model + self.tokenizer = tokenizer + self.formatter = ChatFormat(tokenizer) + + @torch.inference_mode() + def generate( + self, + model_inputs: List[ModelInput], + max_gen_len: int, + temperature: float = 0.6, + top_p: float = 0.9, + logprobs: bool = False, + echo: bool = False, + include_stop_token: bool = False, + ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: + """ + Generate text sequences based on provided prompts using the language generation model. + + Args: + prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers. + max_gen_len (int): Maximum length of the generated text sequence. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities. + + Note: + This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness. + If logprobs is True, token log probabilities are computed for each generated token. + + """ + params = self.model.params + prompt_tokens = [m.tokens for m in model_inputs] + bsz = len(prompt_tokens) + assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + + min_prompt_len = min(len(t) for t in prompt_tokens) + max_prompt_len = max(len(t) for t in prompt_tokens) + assert max_prompt_len <= params.max_seq_len + total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) + + pad_id = self.tokenizer.pad_id + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + if logprobs: + token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + + prev_pos = 0 + eos_reached = torch.tensor([False] * bsz, device="cuda") + input_text_mask = tokens != pad_id + if min_prompt_len == total_len: + logits = self.model.forward(tokens, prev_pos) + token_logprobs = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens, + reduction="none", + ignore_index=pad_id, + ) + + stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens)) + + for cur_pos in range(min_prompt_len, total_len): + logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + + if temperature > 0: + probs = torch.softmax(logits[:, -1] / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits[:, -1], dim=-1) + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where( + input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token + ) + tokens[:, cur_pos] = next_token + + target = tokens[:, prev_pos + 1 : cur_pos + 1] + + if logprobs: + token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens[:, prev_pos + 1 : cur_pos + 1], + reduction="none", + ignore_index=pad_id, + ) + eos_reached |= (~input_text_mask[:, cur_pos]) & ( + torch.isin(next_token, stop_tokens) + ) + prev_pos = cur_pos + if all(eos_reached): + break + + if logprobs: + token_logprobs = token_logprobs.tolist() + out_tokens, out_logprobs = [], [] + for i, toks in enumerate(tokens.tolist()): + # cut to max gen len + start = 0 if echo else len(prompt_tokens[i]) + toks = toks[start : len(prompt_tokens[i]) + max_gen_len] + probs = None + if logprobs: + probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] + # cut to after eos tok if any + for stop_token in self.tokenizer.stop_tokens: + try: + eos_idx = toks.index(stop_token) + if include_stop_token: + eos_idx += 1 + toks = toks[:eos_idx] + probs = probs[:eos_idx] if logprobs else None + except ValueError: + pass + out_tokens.append(toks) + out_logprobs.append(probs) + return (out_tokens, out_logprobs if logprobs else None) + + def text_completion( + self, + prompts: List[str], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + echo: bool = False, + ) -> List[CompletionPrediction]: + """ + Perform text completion for a list of prompts using the language generation model. + + Args: + prompts (List[str]): List of text prompts for completion. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence. + If not provided, it's set to the model's maximum sequence length minus 1. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + List[CompletionPrediction]: List of completion predictions, each containing the generated text completion. + + Note: + This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness. + If logprobs is True, token log probabilities are computed for each generated token. + + """ + if max_gen_len is None: + max_gen_len = self.model.params.max_seq_len - 1 + prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] + generation_tokens, generation_logprobs = self.generate( + model_inputs=[ModelInput(tokens=pt) for pt in prompt_tokens], + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + echo=echo, + ) + if logprobs: + return [ + { + "generation": self.tokenizer.decode(t), + "tokens": [self.tokenizer.decode([x]) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i in zip(generation_tokens, generation_logprobs) + ] + return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens] + + def chat_completion( + self, + dialogs: List[Dialog], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + ) -> List[ChatPrediction]: + """ + Generate assistant responses for a list of conversational dialogs using the language generation model. + + Args: + dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + max_gen_len (Optional[int], optional): Maximum length of the generated response sequence. + If not provided, it's set to the model's maximum sequence length minus 1. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + + Returns: + List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response. + + Note: + This method generates assistant responses for the provided conversational dialogs. + It employs nucleus sampling to introduce controlled randomness in text generation. + If logprobs is True, token log probabilities are computed for each generated token. + """ + if max_gen_len is None: + max_gen_len = self.model.params.max_seq_len - 1 + + model_inputs = [ + self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs + ] + generation_tokens, generation_logprobs = self.generate( + model_inputs=model_inputs, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + include_stop_token=True, + ) + if logprobs: + return [ + { + "generation": self.formatter.decode_assistant_message(t), + "tokens": [self.tokenizer.decode([x]) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i in zip(generation_tokens, generation_logprobs) + ] + return [ + { + "generation": self.formatter.decode_assistant_message(t), + } + for t in generation_tokens + ] + + +def sample_top_p(probs, p): + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token diff --git a/toolchain/inference/quantization/model.py b/toolchain/inference/quantization/model.py new file mode 100644 index 000000000..ce806e697 --- /dev/null +++ b/toolchain/inference/quantization/model.py @@ -0,0 +1,355 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import fairscale.nn.model_parallel.initialize as fs_init +import torch +import torch.nn.functional as F +from fairscale.nn.model_parallel.layers import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) +from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region +from fp8.fp8_impls import ffn_swiglu +from torch import nn + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 500000 + use_scaled_rope: bool = False + + max_batch_size: int = 32 + max_seq_len: int = 2048 + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if hasattr(self, k): + setattr(self, k, v) + + if self.n_kv_heads is None: + self.n_kv_heads = self.n_heads + assert self.n_kv_heads <= self.n_heads + assert self.n_heads % self.n_kv_heads == 0 + assert self.dim % self.n_heads == 0 + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def apply_scaling(freqs: torch.Tensor): + # Values obtained from grid search + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False +): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + if use_scaled: + freqs = apply_scaling(freqs) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + model_parallel_size = fs_init.get_model_parallel_world_size() + self.n_local_heads = args.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + + self.wq = ColumnParallelLinear( + args.dim, + args.n_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.wk = ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.wv = ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.wo = RowParallelLinear( + args.n_heads * self.head_dim, + args.dim, + bias=False, + input_is_parallel=True, + init_method=lambda x: x, + ) + + self.cache_k = torch.zeros( + ( + args.max_batch_size, + args.max_seq_len, + self.n_local_kv_heads, + self.head_dim, + ) + ).cuda() + self.cache_v = torch.zeros( + ( + args.max_batch_size, + args.max_seq_len, + self.n_local_kv_heads, + self.head_dim, + ) + ).cuda() + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + bsz, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + self.cache_k = self.cache_k.to(xq) + self.cache_v = self.cache_v.to(xq) + + self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk + self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv + + keys = self.cache_k[:bsz, : start_pos + seqlen] + values = self.cache_v[:bsz, : start_pos + seqlen] + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv( + keys, self.n_rep + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) + values = repeat_kv( + values, self.n_rep + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) + values = values.transpose( + 1, 2 + ) # (bs, n_local_heads, cache_len + seqlen, head_dim) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = ColumnParallelLinear( + dim, + hidden_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.w3 = ColumnParallelLinear( + dim, + hidden_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.w2 = RowParallelLinear( + hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x + ) + + def forward(self, x): + out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight) + return reduce_from_model_parallel_region(out) + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: int, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(args) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Transformer(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + + self.tok_embeddings = VocabParallelEmbedding( + params.vocab_size, params.dim, init_method=lambda x: x + ) + + self.layers = torch.nn.ModuleList() + for layer_id in range(params.n_layers): + self.layers.append(TransformerBlock(layer_id, params)) + + self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.output = ColumnParallelLinear( + params.dim, params.vocab_size, bias=False, init_method=lambda x: x + ) + + self.freqs_cis = precompute_freqs_cis( + params.dim // params.n_heads, + params.max_seq_len * 2, + params.rope_theta, + params.use_scaled_rope, + ) + + @torch.inference_mode() + def forward(self, tokens: torch.Tensor, start_pos: int): + _bsz, seqlen = tokens.shape + h = self.tok_embeddings(tokens) + self.freqs_cis = self.freqs_cis.to(h.device) + freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + + mask = None + if seqlen > 1: + mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device) + + mask = torch.triu(mask, diagonal=1) + + # When performing key-value caching, we compute the attention scores + # only for the new sequence. Thus, the matrix of scores is of size + # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for + # j > cache_len + i, since row i corresponds to token cache_len + i. + mask = torch.hstack( + [torch.zeros((seqlen, start_pos), device=tokens.device), mask] + ).type_as(h) + + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + output = self.output(h).float() + return output diff --git a/toolchain/inference/quantization/quantize_checkpoint.py b/toolchain/inference/quantization/quantize_checkpoint.py new file mode 100644 index 000000000..6fe66e13f --- /dev/null +++ b/toolchain/inference/quantization/quantize_checkpoint.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import json +import os +import shutil +import sys +from pathlib import Path +from typing import Optional + +import fire + +import torch +from fairscale.nn.model_parallel.initialize import ( + get_model_parallel_rank, + initialize_model_parallel, + model_parallel_is_initialized, +) +from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8 + +from llama.model import ModelArgs, Transformer, TransformerBlock +from llama.tokenizer import Tokenizer +from torch.nn.parameter import Parameter + + +def main( + ckpt_dir: str, + tokenizer_path: str, + quantized_ckpt_dir: str, + max_seq_len: Optional[int] = 512, + max_batch_size: Optional[int] = 4, + model_parallel_size: Optional[int] = None, + ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE, + fp8_activation_scale_ub: Optional[float] = 1200.0, + seed: int = 1, +): + """ """ + if not os.path.exists(quantized_ckpt_dir): + os.makedirs(quantized_ckpt_dir) + shutil.copy( + os.path.join(ckpt_dir, "params.json"), + os.path.join(quantized_ckpt_dir, "params.json"), + ) + shutil.copy( + os.path.join(ckpt_dir, "tokenizer.model"), + os.path.join(quantized_ckpt_dir, "tokenizer.model"), + ) + + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group("nccl") + if not model_parallel_is_initialized(): + if model_parallel_size is None: + model_parallel_size = int(os.environ.get("WORLD_SIZE", 1)) + initialize_model_parallel(model_parallel_size) + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + + # seed must be the same in all processes + torch.manual_seed(seed) + + if local_rank > 0: + sys.stdout = open(os.devnull, "w") + + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" + assert model_parallel_size == len( + checkpoints + ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" + ckpt_path = checkpoints[get_model_parallel_rank()] + checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) + with open(Path(ckpt_dir) / "params.json", "r") as f: + params = json.loads(f.read()) + + model_args: ModelArgs = ModelArgs( + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + **params, + ) + tokenizer = Tokenizer(model_path=tokenizer_path) + assert ( + model_args.vocab_size == tokenizer.n_words + ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}" + + # load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype + torch.set_default_tensor_type(torch.BFloat16Tensor) + + model = Transformer(model_args) + model.load_state_dict(checkpoint, strict=False) + + if torch.cuda.is_bf16_supported(): + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) + else: + torch.set_default_tensor_type(torch.cuda.HalfTensor) + + print(ckpt_path) + assert ( + quantized_ckpt_dir is not None + ), "QUantized checkpoint directory should not be None" + fp8_scales = {} + for block in model.layers: + if isinstance(block, TransformerBlock): + if block.layer_id == 0 or block.layer_id == (model.n_layers - 1): + continue + + fp8_weight = quantize_fp8( + block.feed_forward.w1.weight, + fp8_activation_scale_ub, + ffn_quantize_mode, + output_device=torch.device("cpu"), + ) + with torch.inference_mode(): + block.feed_forward.w1.weight = Parameter(fp8_weight.weight) + fp8_scales[ + f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}" + ] = fp8_weight.scale + + fp8_weight = quantize_fp8( + block.feed_forward.w3.weight, + fp8_activation_scale_ub, + ffn_quantize_mode, + output_device=torch.device("cpu"), + ) + with torch.inference_mode(): + block.feed_forward.w3.weight = Parameter(fp8_weight.weight) + fp8_scales[ + f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}" + ] = fp8_weight.scale + + fp8_weight = quantize_fp8( + block.feed_forward.w2.weight, + fp8_activation_scale_ub, + ffn_quantize_mode, + output_device=torch.device("cpu"), + ) + with torch.inference_mode(): + block.feed_forward.w2.weight = Parameter(fp8_weight.weight) + fp8_scales[ + f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}" + ] = fp8_weight.scale + + fp8_scales_path = os.path.join( + quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt" + ) + torch.save(fp8_scales, fp8_scales_path) + + ckpt_path = os.path.join( + quantized_ckpt_dir, + "consolidated.{:02d}.pth".format(get_model_parallel_rank()), + ) + torch.save(model.state_dict(), ckpt_path) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/toolchain/inference/quantization/run_quantize_checkpoint.sh b/toolchain/inference/quantization/run_quantize_checkpoint.sh new file mode 100755 index 000000000..a61180907 --- /dev/null +++ b/toolchain/inference/quantization/run_quantize_checkpoint.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +set -euo pipefail +set -x + +cd $(git rev-parse --show-toplevel) + +MASTER_HOST=$1 +RUN_ID=$2 +CKPT_DIR=$3 +QUANT_CKPT_DIR=$4 +TOKENIZER_PATH=$5 +NNODES=$6 +NPROC=$7 + +echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR + +NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \ + torchrun \ + --nnodes=$NNODES --nproc_per_node=$NPROC \ + --rdzv_id=$RUN_ID \ + --rdzv_conf='timeout=120' \ + --rdzv_backend=c10d \ + --rdzv_endpoint="${MASTER_HOST}:29502" \ + quantize_checkpoint.py $CKPT_DIR $TOKENIZER_PATH $QUANT_CKPT_DIR diff --git a/toolchain/inference/quantization/test_fp8.py b/toolchain/inference/quantization/test_fp8.py new file mode 100644 index 000000000..3e6f75213 --- /dev/null +++ b/toolchain/inference/quantization/test_fp8.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import unittest + +import torch + +from fp8_impls import attn_linear, ffn_swiglu_fp8_dynamic, quantize_fp8 +from hypothesis import given, settings, strategies as st +from torch import Tensor + + +@unittest.skipIf( + not torch.cuda.is_available() + or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9, + "Skip when H100 is not available", +) +class FP8Tests(unittest.TestCase): + @settings(deadline=None) + @given( + D=st.sampled_from([4096, 8192]), + HD_L=st.sampled_from([1280, 2560]), + B=st.sampled_from([1, 2]), + T=st.sampled_from([2048, 4096]), + UB=st.sampled_from([1000, 10000]), + ) + def test_fp8_ffn( + self, + D: int, + HD_L: int, + B: int, + T: int, + UB: float, + ) -> None: + x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1 + w13 = ( + torch.randn(size=(2 * HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01 + ) + w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1 + + x_q = quantize_fp8(x, UB) + w13_q = quantize_fp8(w13, UB) + w2_q = quantize_fp8(w2, UB) + + def ref_ffn(x: Tensor, w13: Tensor, w2: Tensor) -> Tensor: + (B, T, D) = x.shape + (HD_L_2, D_) = w13.shape + assert D_ == D + HD_L = HD_L_2 // 2 + + y = x.view(B * T, D) @ w13.T + x1 = y[:, :HD_L] + x2 = y[:, HD_L:] + + z = torch.nn.functional.silu(x1) * x2 + return (z @ w2.T).view(B, T, D).to(torch.bfloat16) + + v = ffn_swiglu_fp8_dynamic(x, w13_q, w2_q) + + # Fake quant + x = x_q.weight.bfloat16() * x_q.scale + w13 = w13_q.weight.bfloat16() * w13_q.scale + w2 = w2_q.weight.bfloat16() * w2_q.scale + + v_ref = ref_ffn(x, w13, w2) + + torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3) + + @settings(deadline=None) + @given( + B_T=st.sampled_from([2048, 4096]), + D=st.sampled_from([128, 256]), + HD_L=st.sampled_from([256, 512]), + UB=st.sampled_from([1000, 10000]), + ) + def test_fp8_attn_linear(self, B_T: int, D: int, HD_L: int, UB: int) -> None: + B_T = 4096 + D = 256 + HD_L = 512 + UB = float(UB) + x = torch.randn(size=(B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1 + wqkv = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01 + + x_q = quantize_fp8(x, UB) + wqkv_q = quantize_fp8(wqkv, UB) + + num_tokens = torch.tensor(B_T, dtype=torch.int64, device="cuda") + + y = attn_linear(x, wqkv_q) + y_nt = attn_linear(x, wqkv_q, num_tokens=num_tokens) + + # Fake quant + x = x_q.weight.bfloat16() * x_q.scale + wqkv = wqkv_q.weight.bfloat16() * wqkv_q.scale + y_ref = (x @ wqkv.T).to(torch.bfloat16) + + torch.testing.assert_close(y_ref, y, atol=1.0e-3, rtol=1.0e-3) + torch.testing.assert_close(y_ref, y_nt, atol=1.0e-3, rtol=1.0e-3) + + +if __name__ == "__main__": + unittest.main() diff --git a/toolchain/inference/server.py b/toolchain/inference/server.py new file mode 100644 index 000000000..52aac3dda --- /dev/null +++ b/toolchain/inference/server.py @@ -0,0 +1,117 @@ +import asyncio +import signal + +import fire + +from dotenv import load_dotenv + +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import StreamingResponse + +from omegaconf import OmegaConf + +from toolchain.utils import get_config_dir, parse_config +from .api.config import ModelInferenceHydraConfig +from .api.endpoints import ChatCompletionRequest, ChatCompletionResponseStreamChunk + +from .api_instance import get_inference_api_instance + + +load_dotenv() + + +GLOBAL_CONFIG = None + + +def get_config(): + return GLOBAL_CONFIG + + +def handle_sigint(*args, **kwargs): + print("SIGINT or CTRL-C detected. Exiting gracefully", args) + loop = asyncio.get_event_loop() + for task in asyncio.all_tasks(loop): + task.cancel() + loop.stop() + + +app = FastAPI() + + +@app.on_event("startup") +async def startup(): + global InferenceApiInstance + + config = get_config() + hydra_config = ModelInferenceHydraConfig( + **OmegaConf.to_container(config["model_inference_config"], resolve=True) + ) + model_inference_config = hydra_config.convert_to_model_inferene_config() + + InferenceApiInstance = await get_inference_api_instance( + model_inference_config, + ) + await InferenceApiInstance.initialize() + + +@app.on_event("shutdown") +async def shutdown(): + global InferenceApiInstance + + print("shutting down") + await InferenceApiInstance.shutdown() + + +# there's a single model parallel process running serving the model. for now, +# we don't support multiple concurrent requests to this process. +semaphore = asyncio.Semaphore(1) + + +@app.post( + "/inference/chat_completion", response_model=ChatCompletionResponseStreamChunk +) +def chat_completion(request: Request, exec_request: ChatCompletionRequest): + if semaphore.locked(): + raise HTTPException( + status_code=429, + detail="Only a single concurrent request allowed right now.", + ) + + async def sse_generator(event_gen): + try: + async for event in event_gen: + yield f"data: {event.json()}\n\n" + await asyncio.sleep(0.01) + except asyncio.CancelledError: + print("Generator cancelled") + await event_gen.aclose() + finally: + semaphore.release() + + async def event_gen(): + async for event in InferenceApiInstance.chat_completion(exec_request): + yield event + + return StreamingResponse( + sse_generator(event_gen()), + media_type="text/event-stream", + ) + + +def main(config_path: str, port: int = 5000, disable_ipv6: bool = False): + global GLOBAL_CONFIG + config_dir = get_config_dir() + GLOBAL_CONFIG = parse_config(config_dir, config_path) + + signal.signal(signal.SIGINT, handle_sigint) + + import uvicorn + + # FYI this does not do hot-reloads + listen_host = "::" if not disable_ipv6 else "0.0.0.0" + print(f"Listening on {listen_host}:{port}") + uvicorn.run(app, host=listen_host, port=port) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/toolchain/memory/api/__init__.py b/toolchain/memory/api/__init__.py new file mode 100644 index 000000000..38413ff60 --- /dev/null +++ b/toolchain/memory/api/__init__.py @@ -0,0 +1,2 @@ +from .datatypes import * # noqa: F401 F403 +from .endpoints import * # noqa: F401 F403 diff --git a/toolchain/memory/api/datatypes.py b/toolchain/memory/api/datatypes.py new file mode 100644 index 000000000..0969203f6 --- /dev/null +++ b/toolchain/memory/api/datatypes.py @@ -0,0 +1,19 @@ +from typing import Any, Dict + +from pydantic import BaseModel + +from strong_typing.schema import json_schema_type + + +@json_schema_type +class MemoryBank(BaseModel): + memory_bank_id: str + memory_bank_name: str + + +@json_schema_type +class MemoryBankDocument(BaseModel): + document_id: str + content: bytes + metadata: Dict[str, Any] + mime_type: str diff --git a/toolchain/memory/api/endpoints.py b/toolchain/memory/api/endpoints.py new file mode 100644 index 000000000..441c1d777 --- /dev/null +++ b/toolchain/memory/api/endpoints.py @@ -0,0 +1,55 @@ +from typing import List, Protocol + +from pyopenapi import webmethod + +from .datatypes import * # noqa: F403 + + +class MemoryBanks(Protocol): + @webmethod(route="/memory_banks/create") + def post_create_memory_bank( + self, + bank_id: str, + bank_name: str, + documents: List[MemoryBankDocument], + ) -> None: ... + + @webmethod(route="/memory_banks/list") + def get_memory_banks(self) -> List[MemoryBank]: ... + + @webmethod(route="/memory_banks/get") + def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ... + + @webmethod(route="/memory_banks/drop") + def delete_memory_bank( + self, + bank_id: str, + ) -> str: ... + + @webmethod(route="/memory_bank/insert") + def post_insert_memory_documents( + self, + bank_id: str, + documents: List[MemoryBankDocument], + ) -> None: ... + + @webmethod(route="/memory_bank/update") + def post_update_memory_documents( + self, + bank_id: str, + documents: List[MemoryBankDocument], + ) -> None: ... + + @webmethod(route="/memory_bank/get") + def get_memory_documents( + self, + bank_id: str, + document_uuids: List[str], + ) -> List[MemoryBankDocument]: ... + + @webmethod(route="/memory_bank/delete") + def delete_memory_documents( + self, + bank_id: str, + document_uuids: List[str], + ) -> List[str]: ... diff --git a/toolchain/post_training/api/__init__.py b/toolchain/post_training/api/__init__.py new file mode 100644 index 000000000..38413ff60 --- /dev/null +++ b/toolchain/post_training/api/__init__.py @@ -0,0 +1,2 @@ +from .datatypes import * # noqa: F401 F403 +from .endpoints import * # noqa: F401 F403 diff --git a/toolchain/post_training/api/datatypes.py b/toolchain/post_training/api/datatypes.py new file mode 100644 index 000000000..50b491c73 --- /dev/null +++ b/toolchain/post_training/api/datatypes.py @@ -0,0 +1,88 @@ +from enum import Enum +from typing import List + +from pydantic import BaseModel + +from strong_typing.schema import json_schema_type + + +class OptimizerType(Enum): + adam = "adam" + adamw = "adamw" + sgd = "sgd" + + +@json_schema_type +class OptimizerConfig(BaseModel): + optimizer_type: OptimizerType + lr: float + lr_min: float + weight_decay: float + + +@json_schema_type +class TrainingConfig(BaseModel): + n_epochs: int + batch_size: int + shuffle: bool + n_iters: int + + enable_activation_checkpointing: bool + memory_efficient_fsdp_wrap: bool + fsdp_cpu_offload: bool + + +@json_schema_type +class FinetuningAlgorithm(Enum): + full = "full" + lora = "lora" + qlora = "qlora" + dora = "dora" + + +@json_schema_type +class LoraFinetuningConfig(BaseModel): + lora_attn_modules: List[str] + apply_lora_to_mlp: bool + apply_lora_to_output: bool + rank: int + alpha: int + + +@json_schema_type +class QLoraFinetuningConfig(LoraFinetuningConfig): + pass + + +@json_schema_type +class DoraFinetuningConfig(LoraFinetuningConfig): + pass + + +@json_schema_type +class PostTrainingJobLogStream(BaseModel): + """Stream of logs from a finetuning job.""" + + job_uuid: str + log_lines: List[str] + + +@json_schema_type +class PostTrainingJobStatus(Enum): + running = "running" + completed = "completed" + failed = "failed" + scheduled = "scheduled" + + +@json_schema_type +class RLHFAlgorithm(Enum): + dpo = "dpo" + + +@json_schema_type +class DPOAlignmentConfig(BaseModel): + reward_scale: float + reward_clip: float + epsilon: float + gamma: float diff --git a/toolchain/post_training/api/endpoints.py b/toolchain/post_training/api/endpoints.py new file mode 100644 index 000000000..8bcaafc3b --- /dev/null +++ b/toolchain/post_training/api/endpoints.py @@ -0,0 +1,123 @@ +from datetime import datetime + +from typing import Any, Dict, List, Optional, Protocol + +from pydantic import BaseModel, Field + +from pyopenapi import webmethod +from strong_typing.schema import json_schema_type + +from models.llama3.datatypes import * # noqa: F403 +from toolchain.dataset.api.datatypes import * # noqa: F403 +from toolchain.common.training_types import * # noqa: F403 +from .datatypes import * # noqa: F403 + + +@json_schema_type +class PostTrainingSFTRequest(BaseModel): + """Request to finetune a model.""" + + job_uuid: str + + model: PretrainedModel + dataset: TrainEvalDataset + validation_dataset: TrainEvalDataset + + algorithm: FinetuningAlgorithm + algorithm_config: Union[ + LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig + ] + + optimizer_config: OptimizerConfig + training_config: TrainingConfig + + # TODO: define these + hyperparam_search_config: Dict[str, Any] + logger_config: Dict[str, Any] + + +@json_schema_type +class PostTrainingRLHFRequest(BaseModel): + """Request to finetune a model.""" + + job_uuid: str + + finetuned_model: URL + + dataset: TrainEvalDataset + validation_dataset: TrainEvalDataset + + algorithm: RLHFAlgorithm + algorithm_config: Union[DPOAlignmentConfig] + + optimizer_config: OptimizerConfig + training_config: TrainingConfig + + # TODO: define these + hyperparam_search_config: Dict[str, Any] + logger_config: Dict[str, Any] + + +class PostTrainingJob(BaseModel): + + job_uuid: str + + +@json_schema_type +class PostTrainingJobStatusResponse(BaseModel): + """Status of a finetuning job.""" + + job_uuid: str + status: PostTrainingJobStatus + + scheduled_at: Optional[datetime] = None + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + + resources_allocated: Optional[Dict[str, Any]] = None + + checkpoints: List[Checkpoint] = Field(default_factory=list) + + +@json_schema_type +class PostTrainingJobArtifactsResponse(BaseModel): + """Artifacts of a finetuning job.""" + + job_uuid: str + checkpoints: List[Checkpoint] = Field(default_factory=list) + + # TODO(ashwin): metrics, evals + + +class PostTraining(Protocol): + @webmethod(route="/post_training/supervised_fine_tune") + def post_supervised_fine_tune( + self, + request: PostTrainingSFTRequest, + ) -> PostTrainingJob: ... + + @webmethod(route="/post_training/preference_optimize") + def post_preference_optimize( + self, + request: PostTrainingRLHFRequest, + ) -> PostTrainingJob: ... + + @webmethod(route="/post_training/jobs") + def get_training_jobs(self) -> List[PostTrainingJob]: ... + + # sends SSE stream of logs + @webmethod(route="/post_training/job/logs") + def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ... + + @webmethod(route="/post_training/job/status") + def get_training_job_status( + self, job_uuid: str + ) -> PostTrainingJobStatusResponse: ... + + @webmethod(route="/post_training/job/cancel") + def cancel_training_job(self, job_uuid: str) -> None: ... + + @webmethod(route="/post_training/job/artifacts") + def get_training_job_artifacts( + self, job_uuid: str + ) -> PostTrainingJobArtifactsResponse: ... diff --git a/toolchain/reward_scoring/api/__init__.py b/toolchain/reward_scoring/api/__init__.py new file mode 100644 index 000000000..38413ff60 --- /dev/null +++ b/toolchain/reward_scoring/api/__init__.py @@ -0,0 +1,2 @@ +from .datatypes import * # noqa: F401 F403 +from .endpoints import * # noqa: F401 F403 diff --git a/toolchain/reward_scoring/api/datatypes.py b/toolchain/reward_scoring/api/datatypes.py new file mode 100644 index 000000000..c7acdb1a3 --- /dev/null +++ b/toolchain/reward_scoring/api/datatypes.py @@ -0,0 +1,25 @@ +from typing import List + +from pydantic import BaseModel + +from strong_typing.schema import json_schema_type + +from models.llama3.datatypes import * # noqa: F403 + + +@json_schema_type +class ScoredMessage(BaseModel): + message: Message + score: float + + +@json_schema_type +class DialogGenerations(BaseModel): + dialog: List[Message] + sampled_generations: List[Message] + + +@json_schema_type +class ScoredDialogGenerations(BaseModel): + dialog: List[Message] + scored_generations: List[ScoredMessage] diff --git a/toolchain/reward_scoring/api/endpoints.py b/toolchain/reward_scoring/api/endpoints.py new file mode 100644 index 000000000..72de43498 --- /dev/null +++ b/toolchain/reward_scoring/api/endpoints.py @@ -0,0 +1,27 @@ +from typing import List, Protocol, Union +from .datatypes import * # noqa: F403 + +from pyopenapi import webmethod + + +@json_schema_type +class RewardScoringRequest(BaseModel): + """Request to score a reward function. A list of prompts and a list of responses per prompt.""" + + dialog_generations: List[DialogGenerations] + model: RewardModel + + +@json_schema_type +class RewardScoringResponse(BaseModel): + """Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.""" + + scored_generations: List[ScoredDialogGenerations] + + +class RewardScoring(Protocol): + @webmethod(route="/reward_scoring/score") + def post_score( + self, + request: RewardScoringRequest, + ) -> Union[RewardScoringResponse]: ... diff --git a/toolchain/safety/__init__.py b/toolchain/safety/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/toolchain/safety/api/__init__.py b/toolchain/safety/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/toolchain/safety/api/config.py b/toolchain/safety/api/config.py new file mode 100644 index 000000000..4858bc53c --- /dev/null +++ b/toolchain/safety/api/config.py @@ -0,0 +1,19 @@ +from typing import List, Optional + +from pydantic import BaseModel + + +class LlamaGuardShieldConfig(BaseModel): + model_dir: str + excluded_categories: List[str] + disable_input_check: bool = False + disable_output_check: bool = False + + +class PromptGuardShieldConfig(BaseModel): + model_dir: str + + +class SafetyConfig(BaseModel): + llama_guard_shield: Optional[LlamaGuardShieldConfig] = None + prompt_guard_shield: Optional[PromptGuardShieldConfig] = None diff --git a/toolchain/safety/api/datatypes.py b/toolchain/safety/api/datatypes.py new file mode 100644 index 000000000..45866d026 --- /dev/null +++ b/toolchain/safety/api/datatypes.py @@ -0,0 +1,53 @@ +from enum import Enum +from typing import Dict, Optional, Union + +from models.llama3.datatypes import ToolParamDefinition + +from pydantic import BaseModel + +from strong_typing.schema import json_schema_type + +from toolchain.common.deployment_types import RestAPIExecutionConfig + + +@json_schema_type +class BuiltinShield(Enum): + llama_guard = "llama_guard" + prompt_guard = "prompt_guard" + code_scanner_guard = "code_scanner_guard" + third_party_shield = "third_party_shield" + + +ShieldType = Union[BuiltinShield, str] + + +@json_schema_type +class OnViolationAction(Enum): + IGNORE = 0 + WARN = 1 + RAISE = 2 + + +@json_schema_type +class ShieldDefinition(BaseModel): + shield_type: ShieldType + description: Optional[str] = None + parameters: Optional[Dict[str, ToolParamDefinition]] = None + on_violation_action: OnViolationAction = OnViolationAction.RAISE + execution_config: Optional[RestAPIExecutionConfig] = None + + +@json_schema_type +class ShieldCall(BaseModel): + call_id: str + shield_type: ShieldType + arguments: Dict[str, str] + + +@json_schema_type +class ShieldResponse(BaseModel): + shield_type: ShieldType + # TODO(ashwin): clean this up + is_violation: bool + violation_type: Optional[str] = None + violation_return_message: Optional[str] = None diff --git a/toolchain/safety/shields/__init__.py b/toolchain/safety/shields/__init__.py new file mode 100644 index 000000000..d9ee5ea38 --- /dev/null +++ b/toolchain/safety/shields/__init__.py @@ -0,0 +1,25 @@ +# supress warnings and spew of logs from hugging face +import transformers + +from .base import ( # noqa: F401 + DummyShield, + OnViolationAction, + ShieldBase, + ShieldResponse, + TextShield, +) +from .code_scanner import CodeScannerShield # noqa: F401 +from .contrib.third_party_shield import ThirdPartyShield # noqa: F401 +from .llama_guard import LlamaGuardShield # noqa: F401 +from .prompt_guard import PromptGuardShield # noqa: F401 +from .shield_runner import SafetyException, ShieldRunnerMixin # noqa: F401 + +transformers.logging.set_verbosity_error() + +import os + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +import warnings + +warnings.filterwarnings("ignore") diff --git a/toolchain/safety/shields/base.py b/toolchain/safety/shields/base.py new file mode 100644 index 000000000..dc7a04879 --- /dev/null +++ b/toolchain/safety/shields/base.py @@ -0,0 +1,65 @@ +from abc import ABC, abstractmethod +from typing import List, Union + +from models.llama3.datatypes import Attachment, Message +from toolchain.safety.api.datatypes import * # noqa: F403 + +CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" + + +class ShieldBase(ABC): + + def __init__( + self, + on_violation_action: OnViolationAction = OnViolationAction.RAISE, + ): + self.on_violation_action = on_violation_action + + @abstractmethod + def get_shield_type(self) -> ShieldType: + raise NotImplementedError() + + @abstractmethod + async def run(self, messages: List[Message]) -> ShieldResponse: + raise NotImplementedError() + + +def message_content_as_str(message: Message) -> str: + def _to_str(content: Union[str, Attachment]) -> str: + if isinstance(content, str): + return content + elif isinstance(content, Attachment): + return f"File: {str(content.url)}" + else: + raise + + if isinstance(message.content, list) or isinstance(message.content, tuple): + return "\n".join([_to_str(c) for c in message.content]) + else: + return _to_str(message.content) + + +# For shields that operate on simple strings +class TextShield(ShieldBase): + def convert_messages_to_text(self, messages: List[Message]) -> str: + return "\n".join([message_content_as_str(m) for m in messages]) + + async def run(self, messages: List[Message]) -> ShieldResponse: + text = self.convert_messages_to_text(messages) + return await self.run_impl(text) + + @abstractmethod + async def run_impl(self, text: str) -> ShieldResponse: + raise NotImplementedError() + + +class DummyShield(TextShield): + + def get_shield_type(self) -> ShieldType: + return "dummy" + + async def run_impl(self, text: str) -> ShieldResponse: + # Dummy return LOW to test e2e + return ShieldResponse( + shield_type=BuiltinShield.third_party_shield, is_violation=False + ) diff --git a/toolchain/safety/shields/code_scanner.py b/toolchain/safety/shields/code_scanner.py new file mode 100644 index 000000000..aa877da68 --- /dev/null +++ b/toolchain/safety/shields/code_scanner.py @@ -0,0 +1,28 @@ +from codeshield.cs import CodeShield +from termcolor import cprint + +from .base import ShieldResponse, TextShield +from toolchain.safety.api.datatypes import * # noqa: F403 + + +class CodeScannerShield(TextShield): + + def get_shield_type(self) -> ShieldType: + return BuiltinShield.code_scanner_guard + + async def run_impl(self, text: str) -> ShieldResponse: + cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta") + result = await CodeShield.scan_code(text) + if result.is_insecure: + return ShieldResponse( + shield_type=BuiltinShield.code_scanner_guard, + is_violation=True, + violation_type=",".join( + [issue.pattern_id for issue in result.issues_found] + ), + violation_return_message="Sorry, I found security concerns in the code.", + ) + else: + return ShieldResponse( + shield_type=BuiltinShield.code_scanner_guard, is_violation=False + ) diff --git a/toolchain/safety/shields/contrib/__init__.py b/toolchain/safety/shields/contrib/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/toolchain/safety/shields/contrib/third_party_shield.py b/toolchain/safety/shields/contrib/third_party_shield.py new file mode 100644 index 000000000..13ee9556b --- /dev/null +++ b/toolchain/safety/shields/contrib/third_party_shield.py @@ -0,0 +1,28 @@ +import sys +from typing import List + +from models.llama3.datatypes import Message + +parent_dir = "../.." +sys.path.append(parent_dir) +from toolchain.safety.shields.base import OnViolationAction, ShieldBase, ShieldResponse + +_INSTANCE = None + + +class ThirdPartyShield(ShieldBase): + @staticmethod + def instance(on_violation_action=OnViolationAction.RAISE) -> "ThirdPartyShield": + global _INSTANCE + if _INSTANCE is None: + _INSTANCE = ThirdPartyShield(on_violation_action) + return _INSTANCE + + def __init__( + self, + on_violation_action: OnViolationAction = OnViolationAction.RAISE, + ): + super().__init__(on_violation_action) + + async def run(self, messages: List[Message]) -> ShieldResponse: + super.run() # will raise NotImplementedError diff --git a/toolchain/safety/shields/llama_guard.py b/toolchain/safety/shields/llama_guard.py new file mode 100644 index 000000000..4d8f7c95b --- /dev/null +++ b/toolchain/safety/shields/llama_guard.py @@ -0,0 +1,248 @@ +import re + +from string import Template +from typing import List, Optional + +import torch +from models.llama3.datatypes import Message +from termcolor import cprint +from transformers import AutoModelForCausalLM, AutoTokenizer + +from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse +from toolchain.safety.api.datatypes import * # noqa: F403 + +SAFE_RESPONSE = "safe" +_INSTANCE = None + +CAT_VIOLENT_CRIMES = "Violent Crimes" +CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes" +CAT_SEX_CRIMES = "Sex Crimes" +CAT_CHILD_EXPLOITATION = "Child Exploitation" +CAT_DEFAMATION = "Defamation" +CAT_SPECIALIZED_ADVICE = "Specialized Advice" +CAT_PRIVACY = "Privacy" +CAT_INTELLECTUAL_PROPERTY = "Intellectual Property" +CAT_INDISCRIMINATE_WEAPONS = "Indiscriminate Weapons" +CAT_HATE = "Hate" +CAT_SELF_HARM = "Self-Harm" +CAT_SEXUAL_CONTENT = "Sexual Content" +CAT_ELECTIONS = "Elections" +CAT_CODE_INTERPRETER_ABUSE = "Code Interpreter Abuse" + + +SAFETY_CATEGORIES_TO_CODE_MAP = { + CAT_VIOLENT_CRIMES: "S1", + CAT_NON_VIOLENT_CRIMES: "S2", + CAT_SEX_CRIMES: "S3", + CAT_CHILD_EXPLOITATION: "S4", + CAT_DEFAMATION: "S5", + CAT_SPECIALIZED_ADVICE: "S6", + CAT_PRIVACY: "S7", + CAT_INTELLECTUAL_PROPERTY: "S8", + CAT_INDISCRIMINATE_WEAPONS: "S9", + CAT_HATE: "S10", + CAT_SELF_HARM: "S11", + CAT_SEXUAL_CONTENT: "S12", + CAT_ELECTIONS: "S13", + CAT_CODE_INTERPRETER_ABUSE: "S14", +} + + +DEFAULT_LG_V3_SAFETY_CATEGORIES = [ + CAT_VIOLENT_CRIMES, + CAT_NON_VIOLENT_CRIMES, + CAT_SEX_CRIMES, + CAT_CHILD_EXPLOITATION, + CAT_DEFAMATION, + CAT_SPECIALIZED_ADVICE, + CAT_PRIVACY, + CAT_INTELLECTUAL_PROPERTY, + CAT_INDISCRIMINATE_WEAPONS, + CAT_HATE, + CAT_SELF_HARM, + CAT_SEXUAL_CONTENT, + CAT_ELECTIONS, + CAT_CODE_INTERPRETER_ABUSE, +] + +PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories." + +SAFETY_CATEGORIES = """ + + +$categories + +""" + +PROMPT_CONVERSATION = """ + + +$conversations + + +""" + + +PROMPT_INSTRUCTIONS = """ +Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation: +- First line must read 'safe' or 'unsafe'. +- If unsafe, a second line must include a comma-separated list of violated categories.""" + + +PROMPT_TEMPLATE = Template( + f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}" +) + + +class LlamaGuardShield(ShieldBase): + + @staticmethod + def instance( + on_violation_action=OnViolationAction.RAISE, + model_dir: str = None, + excluded_categories: List[str] = [], + disable_input_check: bool = False, + disable_output_check: bool = False, + ) -> "LlamaGuardShield": + global _INSTANCE + if _INSTANCE is None: + _INSTANCE = LlamaGuardShield( + on_violation_action, + model_dir, + excluded_categories, + disable_input_check, + disable_output_check, + ) + return _INSTANCE + + def __init__( + self, + on_violation_action: OnViolationAction = OnViolationAction.RAISE, + model_dir: str = None, + excluded_categories: List[str] = [], + disable_input_check: bool = False, + disable_output_check: bool = False, + ): + super().__init__(on_violation_action) + + dtype = torch.bfloat16 + + assert model_dir is not None, "Llama Guard model_dir is None" + + assert len(excluded_categories) == 0 or all( + x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories + ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]" + + self.device = "cuda" + self.excluded_categories = excluded_categories + self.disable_input_check = disable_input_check + self.disable_output_check = disable_output_check + + # load model + torch_dtype = torch.bfloat16 + self.tokenizer = AutoTokenizer.from_pretrained(model_dir) + self.model = AutoModelForCausalLM.from_pretrained( + model_dir, torch_dtype=torch_dtype, device_map=self.device + ) + + def get_shield_type(self) -> ShieldType: + return BuiltinShield.llama_guard + + def check_unsafe_response(self, response: str) -> Optional[str]: + match = re.match(r"^unsafe\n(.*)$", response) + if match: + # extracts the unsafe code + extracted = match.group(1) + return extracted + + return None + + def get_safety_categories(self) -> List[str]: + + excluded_categories = self.excluded_categories + if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()): + excluded_categories = [] + + categories = [] + for cat in DEFAULT_LG_V3_SAFETY_CATEGORIES: + cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat] + if cat_code in excluded_categories: + continue + categories.append(f"{cat_code}: {cat}.") + + return categories + + def build_prompt(self, messages: List[Message]) -> str: + + categories = self.get_safety_categories() + categories_str = "\n".join(categories) + conversations_str = "\n\n".join( + [f"{m.role.capitalize()}: {m.content}" for m in messages] + ) + return PROMPT_TEMPLATE.substitute( + agent_type=messages[-1].role.capitalize(), + categories=categories_str, + conversations=conversations_str, + ) + + def get_shield_response(self, response: str) -> ShieldResponse: + if response == SAFE_RESPONSE: + return ShieldResponse( + shield_type=BuiltinShield.llama_guard, is_violation=False + ) + unsafe_code = self.check_unsafe_response(response) + if unsafe_code: + unsafe_code_list = unsafe_code.split(",") + if set(unsafe_code_list).issubset(set(self.excluded_categories)): + return ShieldResponse( + shield_type=BuiltinShield.llama_guard, is_violation=False + ) + return ShieldResponse( + shield_type=BuiltinShield.llama_guard, + is_violation=True, + violation_type=unsafe_code, + violation_return_message=CANNED_RESPONSE_TEXT, + ) + + raise ValueError(f"Unexpected response: {response}") + + async def run(self, messages: List[Message]) -> ShieldResponse: + + if self.disable_input_check and messages[-1].role == Role.user.value: + return ShieldResponse( + shield_type=BuiltinShield.llama_guard, is_violation=False + ) + elif self.disable_output_check and messages[-1].role == Role.assistant.value: + return ShieldResponse( + shield_type=BuiltinShield.llama_guard, + is_violation=False, + ) + else: + + prompt = self.build_prompt(messages) + llama_guard_input = { + "role": "user", + "content": prompt, + } + input_ids = self.tokenizer.apply_chat_template( + [llama_guard_input], return_tensors="pt", tokenize=True + ).to(self.device) + prompt_len = input_ids.shape[1] + output = self.model.generate( + input_ids=input_ids, + max_new_tokens=20, + output_scores=True, + return_dict_in_generate=True, + pad_token_id=0, + ) + generated_tokens = output.sequences[:, prompt_len:] + + response = self.tokenizer.decode( + generated_tokens[0], skip_special_tokens=True + ) + + response = response.strip() + shield_response = self.get_shield_response(response) + + cprint(f"Final Llama Guard response {shield_response}", color="magenta") + return shield_response diff --git a/toolchain/safety/shields/prompt_guard.py b/toolchain/safety/shields/prompt_guard.py new file mode 100644 index 000000000..e5082ff8c --- /dev/null +++ b/toolchain/safety/shields/prompt_guard.py @@ -0,0 +1,112 @@ +from enum import auto, Enum +from typing import List + +import torch + +from models.llama3.datatypes import Message +from termcolor import cprint +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield +from toolchain.safety.api.datatypes import * # noqa: F403 + + +class PromptGuardShield(TextShield): + + class Mode(Enum): + INJECTION = auto() + JAILBREAK = auto() + + _instances = {} + _model_cache = None + + @staticmethod + def instance( + model_dir: str, + threshold: float = 0.9, + temperature: float = 1.0, + mode: "PromptGuardShield.Mode" = Mode.JAILBREAK, + on_violation_action=OnViolationAction.RAISE, + ) -> "PromptGuardShield": + action_value = on_violation_action.value + key = (model_dir, threshold, temperature, mode, action_value) + if key not in PromptGuardShield._instances: + PromptGuardShield._instances[key] = PromptGuardShield( + model_dir=model_dir, + threshold=threshold, + temperature=temperature, + mode=mode, + on_violation_action=on_violation_action, + ) + return PromptGuardShield._instances[key] + + def __init__( + self, + model_dir: str, + threshold: float = 0.9, + temperature: float = 1.0, + mode: "PromptGuardShield.Mode" = Mode.JAILBREAK, + on_violation_action: OnViolationAction = OnViolationAction.RAISE, + ): + super().__init__(on_violation_action) + assert ( + model_dir is not None + ), "Must provide a model directory for prompt injection shield" + if temperature <= 0: + raise ValueError("Temperature must be greater than 0") + self.device = "cuda" + if PromptGuardShield._model_cache is None: + # load model and tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_dir) + model = AutoModelForSequenceClassification.from_pretrained( + model_dir, device_map=self.device + ) + PromptGuardShield._model_cache = (tokenizer, model) + + self.tokenizer, self.model = PromptGuardShield._model_cache + self.temperature = temperature + self.threshold = threshold + self.mode = mode + + def get_shield_type(self) -> ShieldType: + return BuiltinShield.prompt_guard + + def convert_messages_to_text(self, messages: List[Message]) -> str: + return message_content_as_str(messages[-1]) + + async def run_impl(self, text: str) -> ShieldResponse: + # run model on messages and return response + inputs = self.tokenizer(text, return_tensors="pt") + inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()} + with torch.no_grad(): + outputs = self.model(**inputs) + logits = outputs[0] + probabilities = torch.softmax(logits / self.temperature, dim=-1) + score_embedded = probabilities[0, 1].item() + score_malicious = probabilities[0, 2].item() + cprint( + f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}", + color="magenta", + ) + + if self.mode == self.Mode.INJECTION and ( + score_embedded + score_malicious > self.threshold + ): + return ShieldResponse( + shield_type=BuiltinShield.prompt_guard, + is_violation=True, + violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}", + violation_return_message="Sorry, I cannot do this.", + ) + elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold: + return ShieldResponse( + shield_type=BuiltinShield.prompt_guard, + is_violation=True, + violation_type=f"prompt_injection:malicious={score_malicious}", + violation_return_message="Sorry, I cannot do this.", + ) + + return ShieldResponse( + shield_type=BuiltinShield.prompt_guard, + is_violation=False, + ) diff --git a/toolchain/safety/shields/shield_runner.py b/toolchain/safety/shields/shield_runner.py new file mode 100644 index 000000000..cb0c23302 --- /dev/null +++ b/toolchain/safety/shields/shield_runner.py @@ -0,0 +1,46 @@ +import asyncio +from typing import List + +from models.llama3.datatypes import Message, Role + +from .base import OnViolationAction, ShieldBase, ShieldResponse + + +class SafetyException(Exception): + def __init__(self, response: ShieldResponse): + self.response = response + super().__init__(response.violation_return_message) + + +class ShieldRunnerMixin: + + def __init__( + self, + input_shields: List[ShieldBase] = None, + output_shields: List[ShieldBase] = None, + ): + self.input_shields = input_shields + self.output_shields = output_shields + + async def run_shields( + self, messages: List[Message], shields: List[ShieldBase] + ) -> List[ShieldResponse]: + # some shields like llama-guard require the first message to be a user message + # since this might be a tool call, first role might not be user + if len(messages) > 0 and messages[0].role != Role.user.value: + # TODO(ashwin): we need to change the type of the message, this kind of modification + # is no longer appropriate + messages[0].role = Role.user.value + + results = await asyncio.gather(*[s.run(messages) for s in shields]) + for shield, r in zip(shields, results): + if r.is_violation: + if shield.on_violation_action == OnViolationAction.RAISE: + raise SafetyException(r) + elif shield.on_violation_action == OnViolationAction.WARN: + cprint( + f"[Warn]{shield.__class__.__name__} raised a warning", + color="red", + ) + + return results diff --git a/toolchain/spec/generate.py b/toolchain/spec/generate.py new file mode 100644 index 000000000..8afb77cf1 --- /dev/null +++ b/toolchain/spec/generate.py @@ -0,0 +1,54 @@ +from datetime import datetime + +import yaml + +from pyopenapi import Info, Options, Server, Specification + +from models.llama3.datatypes import * # noqa: F403 +from toolchain.dataset.api import * # noqa: F403 +from toolchain.evaluations.api import * # noqa: F403 +from toolchain.inference.api import * # noqa: F403 +from toolchain.memory.api import * # noqa: F403 +from toolchain.post_training.api import * # noqa: F403 +from toolchain.reward_scoring.api import * # noqa: F403 +from toolchain.synthetic_data_generation.api import * # noqa: F403 +from agentic_system.api import * # noqa: F403 + + +class LlamaStackEndpoints( + ModelInference, + AgenticSystem, + RewardScoring, + SyntheticDataGeneration, + Datasets, + PostTraining, + MemoryBanks, + Evaluations, +): ... + + +if __name__ == "__main__": + now = str(datetime.now()) + print( + "Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at " + now + ) + spec = Specification( + LlamaStackEndpoints, + Options( + server=Server(url="http://any-hosted-llama-stack.com"), + info=Info( + title="[DRAFT] Llama Stack Specification", + version="0.0.1", + description="""This is the specification of the llama stack that provides + a set of endpoints and their corresponding interfaces that are tailored to + best leverage Llama Models. The specification is still in draft and subject to change. + Generated at """ + + now, + ), + ), + ) + with open("openapi.yaml", "w", encoding="utf-8") as fp: + yaml.dump(spec.get_json(), fp, allow_unicode=True) + + with open("openapi.html", "w") as fp: + spec.write_html(fp, pretty_print=True) diff --git a/toolchain/spec/openapi.html b/toolchain/spec/openapi.html new file mode 100644 index 000000000..b09bf6c48 --- /dev/null +++ b/toolchain/spec/openapi.html @@ -0,0 +1,4584 @@ + + + + + + + OpenAPI specification + + + + + + + +
+ + + diff --git a/toolchain/spec/openapi.yaml b/toolchain/spec/openapi.yaml new file mode 100644 index 000000000..06f735cc5 --- /dev/null +++ b/toolchain/spec/openapi.yaml @@ -0,0 +1,2894 @@ +components: + responses: {} + schemas: + AgenticSystemCreateRequest: + additionalProperties: false + properties: + instance_config: + $ref: '#/components/schemas/AgenticSystemInstanceConfig' + model: + $ref: '#/components/schemas/InstructModel' + required: + - model + - instance_config + type: object + AgenticSystemCreateResponse: + additionalProperties: false + properties: + system_id: + type: string + required: + - system_id + type: object + AgenticSystemInstanceConfig: + additionalProperties: false + properties: + available_tools: + items: + $ref: '#/components/schemas/AgenticSystemToolDefinition' + type: array + debug_prefix_messages: + items: + oneOf: + - $ref: '#/components/schemas/UserMessage' + - $ref: '#/components/schemas/SystemMessage' + - $ref: '#/components/schemas/ToolResponseMessage' + - $ref: '#/components/schemas/CompletionMessage' + type: array + input_shields: + items: + $ref: '#/components/schemas/ShieldDefinition' + type: array + instructions: + type: string + output_shields: + items: + $ref: '#/components/schemas/ShieldDefinition' + type: array + quantization_config: + oneOf: + - $ref: '#/components/schemas/Bf16QuantizationConfig' + - $ref: '#/components/schemas/Fp8QuantizationConfig' + sampling_params: + $ref: '#/components/schemas/SamplingParams' + required: + - instructions + type: object + AgenticSystemSessionCreateRequest: + additionalProperties: false + properties: + session_name: + type: string + system_id: + type: string + required: + - system_id + - session_name + type: object + AgenticSystemSessionCreateResponse: + additionalProperties: false + properties: + session_id: + type: string + required: + - session_id + type: object + AgenticSystemToolDefinition: + additionalProperties: false + properties: + description: + type: string + execution_config: + $ref: '#/components/schemas/RestAPIExecutionConfig' + input_shields: + items: + $ref: '#/components/schemas/ShieldDefinition' + type: array + output_shields: + items: + $ref: '#/components/schemas/ShieldDefinition' + type: array + parameters: + additionalProperties: + $ref: '#/components/schemas/ToolParamDefinition' + type: object + tool_name: + oneOf: + - enum: + - brave_search + - wolfram_alpha + - photogen + - code_interpreter + type: string + - type: string + required: + - tool_name + type: object + AgenticSystemTurnCreateRequest: + additionalProperties: false + properties: + messages: + items: + oneOf: + - $ref: '#/components/schemas/UserMessage' + - $ref: '#/components/schemas/ToolResponseMessage' + type: array + override_config: + $ref: '#/components/schemas/AgenticSystemInstanceConfig' + session_id: + type: string + stream: + type: boolean + system_id: + type: string + required: + - system_id + - session_id + - messages + type: object + AgenticSystemTurnResponseEvent: + additionalProperties: false + properties: + payload: + oneOf: + - $ref: '#/components/schemas/AgenticSystemTurnResponseStepStartPayload' + - $ref: '#/components/schemas/AgenticSystemTurnResponseStepProgressPayload' + - $ref: '#/components/schemas/AgenticSystemTurnResponseStepCompletePayload' + - $ref: '#/components/schemas/AgenticSystemTurnResponseTurnStartPayload' + - $ref: '#/components/schemas/AgenticSystemTurnResponseTurnCompletePayload' + required: + - payload + title: Streamed agent execution response. + type: object + AgenticSystemTurnResponseStepCompletePayload: + additionalProperties: false + properties: + event_type: + const: step_complete + type: string + step_details: + oneOf: + - $ref: '#/components/schemas/ModelInferenceStep' + - $ref: '#/components/schemas/ToolExecutionStep' + - $ref: '#/components/schemas/ShieldCallStep' + - $ref: '#/components/schemas/MemoryRetrievalStep' + step_type: + enum: + - model_inference + - tool_execution + - shield_call + - memory_retrieval + type: string + required: + - event_type + - step_type + - step_details + type: object + AgenticSystemTurnResponseStepProgressPayload: + additionalProperties: false + properties: + event_type: + const: step_progress + type: string + model_response_text_delta: + type: string + step_id: + type: string + step_type: + enum: + - model_inference + - tool_execution + - shield_call + - memory_retrieval + type: string + tool_call_delta: + $ref: '#/components/schemas/ToolCallDelta' + tool_response_text_delta: + type: string + required: + - event_type + - step_type + - step_id + type: object + AgenticSystemTurnResponseStepStartPayload: + additionalProperties: false + properties: + event_type: + const: step_start + type: string + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + step_id: + type: string + step_type: + enum: + - model_inference + - tool_execution + - shield_call + - memory_retrieval + type: string + required: + - event_type + - step_type + - step_id + type: object + AgenticSystemTurnResponseStreamChunk: + additionalProperties: false + properties: + event: + $ref: '#/components/schemas/AgenticSystemTurnResponseEvent' + required: + - event + type: object + AgenticSystemTurnResponseTurnCompletePayload: + additionalProperties: false + properties: + event_type: + const: turn_complete + type: string + turn: + $ref: '#/components/schemas/Turn' + required: + - event_type + - turn + type: object + AgenticSystemTurnResponseTurnStartPayload: + additionalProperties: false + properties: + event_type: + const: turn_start + type: string + turn_id: + type: string + required: + - event_type + - turn_id + type: object + Attachment: + additionalProperties: false + properties: + mime_type: + type: string + url: + $ref: '#/components/schemas/URL' + required: + - url + - mime_type + type: object + BatchChatCompletionRequest: + additionalProperties: false + properties: + available_tools: + items: + $ref: '#/components/schemas/ToolDefinition' + type: array + logprobs: + additionalProperties: false + properties: + top_k: + type: integer + type: object + messages_batch: + items: + items: + oneOf: + - $ref: '#/components/schemas/UserMessage' + - $ref: '#/components/schemas/SystemMessage' + - $ref: '#/components/schemas/ToolResponseMessage' + - $ref: '#/components/schemas/CompletionMessage' + type: array + type: array + model: + $ref: '#/components/schemas/InstructModel' + quantization_config: + oneOf: + - $ref: '#/components/schemas/Bf16QuantizationConfig' + - $ref: '#/components/schemas/Fp8QuantizationConfig' + sampling_params: + $ref: '#/components/schemas/SamplingParams' + required: + - model + - messages_batch + type: object + BatchCompletionRequest: + additionalProperties: false + properties: + content_batch: + items: + oneOf: + - type: string + - $ref: '#/components/schemas/Attachment' + - items: + oneOf: + - type: string + - $ref: '#/components/schemas/Attachment' + type: array + type: array + logprobs: + additionalProperties: false + properties: + top_k: + type: integer + type: object + model: + $ref: '#/components/schemas/PretrainedModel' + quantization_config: + oneOf: + - $ref: '#/components/schemas/Bf16QuantizationConfig' + - $ref: '#/components/schemas/Fp8QuantizationConfig' + sampling_params: + $ref: '#/components/schemas/SamplingParams' + required: + - model + - content_batch + type: object + Bf16QuantizationConfig: + additionalProperties: false + properties: + quantization_type: + const: bf16 + type: string + required: + - quantization_type + type: object + BuiltinShield: + enum: + - llama_guard + - prompt_guard + - code_scanner_guard + - third_party_shield + type: string + ChatCompletionRequest: + additionalProperties: false + properties: + available_tools: + items: + $ref: '#/components/schemas/ToolDefinition' + type: array + logprobs: + additionalProperties: false + properties: + top_k: + type: integer + type: object + messages: + items: + oneOf: + - $ref: '#/components/schemas/UserMessage' + - $ref: '#/components/schemas/SystemMessage' + - $ref: '#/components/schemas/ToolResponseMessage' + - $ref: '#/components/schemas/CompletionMessage' + type: array + model: + $ref: '#/components/schemas/InstructModel' + quantization_config: + oneOf: + - $ref: '#/components/schemas/Bf16QuantizationConfig' + - $ref: '#/components/schemas/Fp8QuantizationConfig' + sampling_params: + $ref: '#/components/schemas/SamplingParams' + stream: + type: boolean + required: + - model + - messages + type: object + ChatCompletionResponse: + additionalProperties: false + properties: + completion_message: + $ref: '#/components/schemas/CompletionMessage' + logprobs: + items: + $ref: '#/components/schemas/TokenLogProbs' + type: array + required: + - completion_message + type: object + ChatCompletionResponseEvent: + additionalProperties: false + properties: + delta: + oneOf: + - type: string + - $ref: '#/components/schemas/ToolCallDelta' + event_type: + $ref: '#/components/schemas/ChatCompletionResponseEventType' + logprobs: + items: + $ref: '#/components/schemas/TokenLogProbs' + type: array + stop_reason: + $ref: '#/components/schemas/StopReason' + required: + - event_type + - delta + title: Chat completion response event. + type: object + ChatCompletionResponseEventType: + enum: + - start + - complete + - progress + type: string + ChatCompletionResponseStreamChunk: + additionalProperties: false + properties: + event: + $ref: '#/components/schemas/ChatCompletionResponseEvent' + required: + - event + title: SSE-stream of these events. + type: object + CompletionMessage: + additionalProperties: false + properties: + content: + oneOf: + - type: string + - $ref: '#/components/schemas/Attachment' + - items: + oneOf: + - type: string + - $ref: '#/components/schemas/Attachment' + type: array + role: + const: assistant + type: string + stop_reason: + $ref: '#/components/schemas/StopReason' + tool_calls: + items: + $ref: '#/components/schemas/ToolCall' + type: array + required: + - role + - content + - stop_reason + - tool_calls + type: object + CompletionRequest: + additionalProperties: false + properties: + content: + oneOf: + - type: string + - $ref: '#/components/schemas/Attachment' + - items: + oneOf: + - type: string + - $ref: '#/components/schemas/Attachment' + type: array + logprobs: + additionalProperties: false + properties: + top_k: + type: integer + type: object + model: + $ref: '#/components/schemas/PretrainedModel' + quantization_config: + oneOf: + - $ref: '#/components/schemas/Bf16QuantizationConfig' + - $ref: '#/components/schemas/Fp8QuantizationConfig' + sampling_params: + $ref: '#/components/schemas/SamplingParams' + stream: + type: boolean + required: + - model + - content + type: object + CompletionResponse: + additionalProperties: false + properties: + completion_message: + $ref: '#/components/schemas/CompletionMessage' + logprobs: + items: + $ref: '#/components/schemas/TokenLogProbs' + type: array + required: + - completion_message + type: object + CompletionResponseStreamChunk: + additionalProperties: false + properties: + delta: + type: string + logprobs: + items: + $ref: '#/components/schemas/TokenLogProbs' + type: array + stop_reason: + $ref: '#/components/schemas/StopReason' + required: + - delta + title: streamed completion response. + type: object + CreateDatasetRequest: + additionalProperties: false + properties: + dataset: + $ref: '#/components/schemas/TrainEvalDataset' + uuid: + type: string + required: + - uuid + - dataset + title: Request to create a dataset. + type: object + DPOAlignmentConfig: + additionalProperties: false + properties: + epsilon: + type: number + gamma: + type: number + reward_clip: + type: number + reward_scale: + type: number + required: + - reward_scale + - reward_clip + - epsilon + - gamma + type: object + DialogGenerations: + additionalProperties: false + properties: + dialog: + items: + oneOf: + - $ref: '#/components/schemas/UserMessage' + - $ref: '#/components/schemas/SystemMessage' + - $ref: '#/components/schemas/ToolResponseMessage' + - $ref: '#/components/schemas/CompletionMessage' + type: array + sampled_generations: + items: + oneOf: + - $ref: '#/components/schemas/UserMessage' + - $ref: '#/components/schemas/SystemMessage' + - $ref: '#/components/schemas/ToolResponseMessage' + - $ref: '#/components/schemas/CompletionMessage' + type: array + required: + - dialog + - sampled_generations + type: object + DoraFinetuningConfig: + additionalProperties: false + properties: + alpha: + type: integer + apply_lora_to_mlp: + type: boolean + apply_lora_to_output: + type: boolean + lora_attn_modules: + items: + type: string + type: array + rank: + type: integer + required: + - lora_attn_modules + - apply_lora_to_mlp + - apply_lora_to_output + - rank + - alpha + type: object + EvaluateQuestionAnsweringRequest: + additionalProperties: false + properties: + checkpoint: + additionalProperties: false + properties: + iters: + type: integer + path: + $ref: '#/components/schemas/URL' + required: + - iters + - path + type: object + dataset: + $ref: '#/components/schemas/TrainEvalDataset' + job_uuid: + type: string + metrics: + items: + enum: + - em + - f1 + type: string + type: array + sampling_params: + $ref: '#/components/schemas/SamplingParams' + required: + - job_uuid + - dataset + - checkpoint + - sampling_params + - metrics + title: Request to evaluate question answering. + type: object + EvaluateSummarizationRequest: + additionalProperties: false + properties: + checkpoint: + additionalProperties: false + properties: + iters: + type: integer + path: + $ref: '#/components/schemas/URL' + required: + - iters + - path + type: object + dataset: + $ref: '#/components/schemas/TrainEvalDataset' + job_uuid: + type: string + metrics: + items: + enum: + - rouge + - bleu + type: string + type: array + sampling_params: + $ref: '#/components/schemas/SamplingParams' + required: + - job_uuid + - dataset + - checkpoint + - sampling_params + - metrics + title: Request to evaluate summarization. + type: object + EvaluateTextGenerationRequest: + additionalProperties: false + properties: + checkpoint: + additionalProperties: false + properties: + iters: + type: integer + path: + $ref: '#/components/schemas/URL' + required: + - iters + - path + type: object + dataset: + $ref: '#/components/schemas/TrainEvalDataset' + job_uuid: + type: string + metrics: + items: + enum: + - perplexity + - rouge + - bleu + type: string + type: array + sampling_params: + $ref: '#/components/schemas/SamplingParams' + required: + - job_uuid + - dataset + - checkpoint + - sampling_params + - metrics + title: Request to evaluate text generation. + type: object + EvaluationJob: + additionalProperties: false + properties: + job_uuid: + type: string + required: + - job_uuid + type: object + EvaluationJobArtifactsResponse: + additionalProperties: false + properties: + job_uuid: + type: string + required: + - job_uuid + title: Artifacts of a evaluation job. + type: object + EvaluationJobLogStream: + additionalProperties: false + properties: + job_uuid: + type: string + required: + - job_uuid + type: object + EvaluationJobStatusResponse: + additionalProperties: false + properties: + job_uuid: + type: string + required: + - job_uuid + type: object + FinetuningAlgorithm: + enum: + - full + - lora + - qlora + - dora + type: string + Fp8QuantizationConfig: + additionalProperties: false + properties: + quantization_type: + const: fp8 + type: string + required: + - quantization_type + type: object + InstructModel: + enum: + - llama3_8b_chat + - llama3_70b_chat + type: string + LoraFinetuningConfig: + additionalProperties: false + properties: + alpha: + type: integer + apply_lora_to_mlp: + type: boolean + apply_lora_to_output: + type: boolean + lora_attn_modules: + items: + type: string + type: array + rank: + type: integer + required: + - lora_attn_modules + - apply_lora_to_mlp + - apply_lora_to_output + - rank + - alpha + type: object + MemoryBank: + additionalProperties: false + properties: + memory_bank_id: + type: string + memory_bank_name: + type: string + required: + - memory_bank_id + - memory_bank_name + type: object + MemoryBankDocument: + additionalProperties: false + properties: + content: + contentEncoding: base64 + type: string + document_id: + type: string + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + mime_type: + type: string + required: + - document_id + - content + - metadata + - mime_type + type: object + MemoryRetrievalStep: + additionalProperties: false + properties: + completed_at: + format: date-time + type: string + documents: + items: + $ref: '#/components/schemas/MemoryBankDocument' + type: array + memory_bank_ids: + items: + type: string + type: array + scores: + items: + type: number + type: array + started_at: + format: date-time + type: string + step_id: + type: string + step_type: + const: memory_retrieval + type: string + turn_id: + type: string + required: + - turn_id + - step_id + - step_type + - memory_bank_ids + - documents + - scores + type: object + ModelInferenceStep: + additionalProperties: false + properties: + completed_at: + format: date-time + type: string + model_response: + $ref: '#/components/schemas/CompletionMessage' + started_at: + format: date-time + type: string + step_id: + type: string + step_type: + const: model_inference + type: string + turn_id: + type: string + required: + - turn_id + - step_id + - step_type + - model_response + type: object + OnViolationAction: + enum: + - 0 + - 1 + - 2 + type: integer + OptimizerConfig: + additionalProperties: false + properties: + lr: + type: number + lr_min: + type: number + optimizer_type: + enum: + - adam + - adamw + - sgd + type: string + weight_decay: + type: number + required: + - optimizer_type + - lr + - lr_min + - weight_decay + type: object + PostTrainingJob: + additionalProperties: false + properties: + job_uuid: + type: string + required: + - job_uuid + type: object + PostTrainingJobArtifactsResponse: + additionalProperties: false + properties: + checkpoints: + items: + additionalProperties: false + properties: + iters: + type: integer + path: + $ref: '#/components/schemas/URL' + required: + - iters + - path + type: object + type: array + job_uuid: + type: string + required: + - job_uuid + - checkpoints + title: Artifacts of a finetuning job. + type: object + PostTrainingJobLogStream: + additionalProperties: false + properties: + job_uuid: + type: string + log_lines: + items: + type: string + type: array + required: + - job_uuid + - log_lines + title: Stream of logs from a finetuning job. + type: object + PostTrainingJobStatus: + enum: + - running + - completed + - failed + - scheduled + type: string + PostTrainingJobStatusResponse: + additionalProperties: false + properties: + checkpoints: + items: + additionalProperties: false + properties: + iters: + type: integer + path: + $ref: '#/components/schemas/URL' + required: + - iters + - path + type: object + type: array + completed_at: + format: date-time + type: string + job_uuid: + type: string + resources_allocated: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + scheduled_at: + format: date-time + type: string + started_at: + format: date-time + type: string + status: + $ref: '#/components/schemas/PostTrainingJobStatus' + required: + - job_uuid + - status + - checkpoints + title: Status of a finetuning job. + type: object + PostTrainingRLHFRequest: + additionalProperties: false + properties: + algorithm: + $ref: '#/components/schemas/RLHFAlgorithm' + algorithm_config: + $ref: '#/components/schemas/DPOAlignmentConfig' + dataset: + $ref: '#/components/schemas/TrainEvalDataset' + finetuned_model: + $ref: '#/components/schemas/URL' + hyperparam_search_config: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + job_uuid: + type: string + logger_config: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + optimizer_config: + $ref: '#/components/schemas/OptimizerConfig' + training_config: + $ref: '#/components/schemas/TrainingConfig' + validation_dataset: + $ref: '#/components/schemas/TrainEvalDataset' + required: + - job_uuid + - finetuned_model + - dataset + - validation_dataset + - algorithm + - algorithm_config + - optimizer_config + - training_config + - hyperparam_search_config + - logger_config + title: Request to finetune a model. + type: object + PostTrainingSFTRequest: + additionalProperties: false + properties: + algorithm: + $ref: '#/components/schemas/FinetuningAlgorithm' + algorithm_config: + oneOf: + - $ref: '#/components/schemas/LoraFinetuningConfig' + - $ref: '#/components/schemas/QLoraFinetuningConfig' + - $ref: '#/components/schemas/DoraFinetuningConfig' + dataset: + $ref: '#/components/schemas/TrainEvalDataset' + hyperparam_search_config: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + job_uuid: + type: string + logger_config: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + model: + $ref: '#/components/schemas/PretrainedModel' + optimizer_config: + $ref: '#/components/schemas/OptimizerConfig' + training_config: + $ref: '#/components/schemas/TrainingConfig' + validation_dataset: + $ref: '#/components/schemas/TrainEvalDataset' + required: + - job_uuid + - model + - dataset + - validation_dataset + - algorithm + - algorithm_config + - optimizer_config + - training_config + - hyperparam_search_config + - logger_config + title: Request to finetune a model. + type: object + PretrainedModel: + enum: + - llama3_8b + - llama3_70b + type: string + QLoraFinetuningConfig: + additionalProperties: false + properties: + alpha: + type: integer + apply_lora_to_mlp: + type: boolean + apply_lora_to_output: + type: boolean + lora_attn_modules: + items: + type: string + type: array + rank: + type: integer + required: + - lora_attn_modules + - apply_lora_to_mlp + - apply_lora_to_output + - rank + - alpha + type: object + RLHFAlgorithm: + enum: + - dpo + type: string + RestAPIExecutionConfig: + additionalProperties: false + properties: + body: + additionalProperties: + type: string + type: object + headers: + additionalProperties: + type: string + type: object + method: + $ref: '#/components/schemas/RestAPIMethod' + params: + additionalProperties: + type: string + type: object + url: + $ref: '#/components/schemas/URL' + required: + - url + - method + type: object + RestAPIMethod: + enum: + - GET + - POST + - PUT + - DELETE + type: string + RewardModel: + enum: + - llama3_70b_reward + - llama3_405b_reward + type: string + RewardScoringRequest: + additionalProperties: false + properties: + dialog_generations: + items: + $ref: '#/components/schemas/DialogGenerations' + type: array + model: + $ref: '#/components/schemas/RewardModel' + required: + - dialog_generations + - model + title: Request to score a reward function. A list of prompts and a list of responses + per prompt. + type: object + RewardScoringResponse: + additionalProperties: false + properties: + scored_generations: + items: + $ref: '#/components/schemas/ScoredDialogGenerations' + type: array + required: + - scored_generations + title: Response from the reward scoring. Batch of (prompt, response, score) + tuples that pass the threshold. + type: object + SamplingParams: + additionalProperties: false + properties: + max_tokens: + type: integer + repetition_penalty: + type: number + strategy: + $ref: '#/components/schemas/SamplingStrategy' + temperature: + type: number + top_k: + type: integer + top_p: + type: number + required: + - strategy + type: object + SamplingStrategy: + enum: + - greedy + - top_p + - top_k + type: string + ScoredDialogGenerations: + additionalProperties: false + properties: + dialog: + items: + oneOf: + - $ref: '#/components/schemas/UserMessage' + - $ref: '#/components/schemas/SystemMessage' + - $ref: '#/components/schemas/ToolResponseMessage' + - $ref: '#/components/schemas/CompletionMessage' + type: array + scored_generations: + items: + $ref: '#/components/schemas/ScoredMessage' + type: array + required: + - dialog + - scored_generations + type: object + ScoredMessage: + additionalProperties: false + properties: + message: + oneOf: + - $ref: '#/components/schemas/UserMessage' + - $ref: '#/components/schemas/SystemMessage' + - $ref: '#/components/schemas/ToolResponseMessage' + - $ref: '#/components/schemas/CompletionMessage' + score: + type: number + required: + - message + - score + type: object + Session: + additionalProperties: false + properties: + session_id: + type: string + session_name: + type: string + started_at: + format: date-time + type: string + turns: + items: + $ref: '#/components/schemas/Turn' + type: array + required: + - session_id + - session_name + - turns + - started_at + title: A single session of an interaction with an Agentic System. + type: object + ShieldCallStep: + additionalProperties: false + properties: + completed_at: + format: date-time + type: string + response: + $ref: '#/components/schemas/ShieldResponse' + started_at: + format: date-time + type: string + step_id: + type: string + step_type: + const: shield_call + type: string + turn_id: + type: string + required: + - turn_id + - step_id + - step_type + - response + type: object + ShieldDefinition: + additionalProperties: false + properties: + description: + type: string + execution_config: + $ref: '#/components/schemas/RestAPIExecutionConfig' + on_violation_action: + $ref: '#/components/schemas/OnViolationAction' + parameters: + additionalProperties: + $ref: '#/components/schemas/ToolParamDefinition' + type: object + shield_type: + oneOf: + - $ref: '#/components/schemas/BuiltinShield' + - type: string + required: + - shield_type + - on_violation_action + type: object + ShieldResponse: + additionalProperties: false + properties: + is_violation: + type: boolean + shield_type: + oneOf: + - $ref: '#/components/schemas/BuiltinShield' + - type: string + violation_return_message: + type: string + violation_type: + type: string + required: + - shield_type + - is_violation + type: object + StopReason: + enum: + - end_of_turn + - end_of_message + - out_of_tokens + type: string + SyntheticDataGenerationRequest: + additionalProperties: false + properties: + dialogs: + items: + oneOf: + - $ref: '#/components/schemas/UserMessage' + - $ref: '#/components/schemas/SystemMessage' + - $ref: '#/components/schemas/ToolResponseMessage' + - $ref: '#/components/schemas/CompletionMessage' + type: array + filtering_function: + enum: + - none + - random + - top_k + - top_p + - top_k_top_p + - sigmoid + title: The type of filtering function. + type: string + model: + $ref: '#/components/schemas/RewardModel' + required: + - dialogs + - filtering_function + title: Request to generate synthetic data. A small batch of prompts and a filtering + function + type: object + SyntheticDataGenerationResponse: + additionalProperties: false + properties: + statistics: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + synthetic_data: + items: + $ref: '#/components/schemas/ScoredDialogGenerations' + type: array + required: + - synthetic_data + title: Response from the synthetic data generation. Batch of (prompt, response, + score) tuples that pass the threshold. + type: object + SystemMessage: + additionalProperties: false + properties: + content: + oneOf: + - type: string + - $ref: '#/components/schemas/Attachment' + - items: + oneOf: + - type: string + - $ref: '#/components/schemas/Attachment' + type: array + role: + const: system + type: string + required: + - role + - content + type: object + TokenLogProbs: + additionalProperties: false + properties: + logprobs_by_token: + additionalProperties: + type: number + type: object + required: + - logprobs_by_token + type: object + ToolCall: + additionalProperties: false + properties: + arguments: + additionalProperties: + oneOf: + - type: string + - type: integer + - type: number + - type: boolean + - type: 'null' + - items: + oneOf: + - type: string + - type: integer + - type: number + - type: boolean + - type: 'null' + type: array + - additionalProperties: + oneOf: + - type: string + - type: integer + - type: number + - type: boolean + - type: 'null' + type: object + type: object + call_id: + type: string + tool_name: + oneOf: + - enum: + - brave_search + - wolfram_alpha + - photogen + - code_interpreter + type: string + - type: string + required: + - call_id + - tool_name + - arguments + type: object + ToolCallDelta: + additionalProperties: false + properties: + content: + oneOf: + - type: string + - $ref: '#/components/schemas/ToolCall' + parse_status: + $ref: '#/components/schemas/ToolCallParseStatus' + required: + - content + - parse_status + type: object + ToolCallParseStatus: + enum: + - start + - in_progress + - failure + - success + type: string + ToolDefinition: + additionalProperties: false + properties: + description: + type: string + parameters: + additionalProperties: + $ref: '#/components/schemas/ToolParamDefinition' + type: object + tool_name: + oneOf: + - enum: + - brave_search + - wolfram_alpha + - photogen + - code_interpreter + type: string + - type: string + required: + - tool_name + type: object + ToolExecutionStep: + additionalProperties: false + properties: + completed_at: + format: date-time + type: string + started_at: + format: date-time + type: string + step_id: + type: string + step_type: + const: tool_execution + type: string + tool_calls: + items: + $ref: '#/components/schemas/ToolCall' + type: array + tool_responses: + items: + $ref: '#/components/schemas/ToolResponse' + type: array + turn_id: + type: string + required: + - turn_id + - step_id + - step_type + - tool_calls + - tool_responses + type: object + ToolParamDefinition: + additionalProperties: false + properties: + description: + type: string + param_type: + type: string + required: + type: boolean + required: + - param_type + type: object + ToolResponse: + additionalProperties: false + properties: + call_id: + type: string + content: + oneOf: + - type: string + - $ref: '#/components/schemas/Attachment' + - items: + oneOf: + - type: string + - $ref: '#/components/schemas/Attachment' + type: array + tool_name: + oneOf: + - enum: + - brave_search + - wolfram_alpha + - photogen + - code_interpreter + type: string + - type: string + required: + - call_id + - tool_name + - content + type: object + ToolResponseMessage: + additionalProperties: false + properties: + call_id: + type: string + content: + oneOf: + - type: string + - $ref: '#/components/schemas/Attachment' + - items: + oneOf: + - type: string + - $ref: '#/components/schemas/Attachment' + type: array + role: + const: ipython + type: string + tool_name: + oneOf: + - enum: + - brave_search + - wolfram_alpha + - photogen + - code_interpreter + type: string + - type: string + required: + - role + - call_id + - tool_name + - content + type: object + TrainEvalDataset: + additionalProperties: false + properties: + columns: + additionalProperties: + $ref: '#/components/schemas/TrainEvalDatasetColumnType' + type: object + content_url: + $ref: '#/components/schemas/URL' + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + required: + - columns + - content_url + title: Dataset to be used for training or evaluating language models. + type: object + TrainEvalDatasetColumnType: + enum: + - dialog + - text + - media + - number + - json + type: string + TrainingConfig: + additionalProperties: false + properties: + batch_size: + type: integer + enable_activation_checkpointing: + type: boolean + fsdp_cpu_offload: + type: boolean + memory_efficient_fsdp_wrap: + type: boolean + n_epochs: + type: integer + n_iters: + type: integer + shuffle: + type: boolean + required: + - n_epochs + - batch_size + - shuffle + - n_iters + - enable_activation_checkpointing + - memory_efficient_fsdp_wrap + - fsdp_cpu_offload + type: object + Turn: + additionalProperties: false + properties: + completed_at: + format: date-time + type: string + input_messages: + items: + oneOf: + - $ref: '#/components/schemas/UserMessage' + - $ref: '#/components/schemas/ToolResponseMessage' + type: array + output_message: + $ref: '#/components/schemas/CompletionMessage' + session_id: + type: string + started_at: + format: date-time + type: string + steps: + items: + oneOf: + - $ref: '#/components/schemas/ModelInferenceStep' + - $ref: '#/components/schemas/ToolExecutionStep' + - $ref: '#/components/schemas/ShieldCallStep' + - $ref: '#/components/schemas/MemoryRetrievalStep' + type: array + turn_id: + type: string + required: + - turn_id + - session_id + - input_messages + - steps + - output_message + - started_at + title: A single turn in an interaction with an Agentic System. + type: object + URL: + format: uri + pattern: ^(https?://|file://|data:) + type: string + UserMessage: + additionalProperties: false + properties: + content: + oneOf: + - type: string + - $ref: '#/components/schemas/Attachment' + - items: + oneOf: + - type: string + - $ref: '#/components/schemas/Attachment' + type: array + role: + const: user + type: string + required: + - role + - content + type: object +info: + description: "This is the specification of the llama stack that provides\n \ + \ a set of endpoints and their corresponding interfaces that are tailored\ + \ to\n best leverage Llama Models. The specification is still in\ + \ draft and subject to change.\n Generated at 2024-07-19 11:49:56.794897" + title: '[DRAFT] Llama Stack Specification' + version: 0.0.1 +jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema +openapi: 3.1.0 +paths: + /agentic_system/create: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/AgenticSystemCreateRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/AgenticSystemCreateResponse' + description: OK + tags: + - AgenticSystem + /agentic_system/delete: + delete: + parameters: + - in: query + name: agent_id + required: true + schema: + type: string + responses: + '200': + description: OK + tags: + - AgenticSystem + /agentic_system/memory_bank/attach: + post: + parameters: + - in: query + name: agent_id + required: true + schema: + type: string + - in: query + name: session_id + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + items: + type: string + type: array + required: true + responses: + '200': + description: OK + tags: + - AgenticSystem + /agentic_system/memory_bank/detach: + post: + parameters: + - in: query + name: agent_id + required: true + schema: + type: string + - in: query + name: session_id + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + items: + type: string + type: array + required: true + responses: + '200': + description: OK + tags: + - AgenticSystem + /agentic_system/session/create: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/AgenticSystemSessionCreateRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/AgenticSystemSessionCreateResponse' + description: OK + tags: + - AgenticSystem + /agentic_system/session/get: + post: + parameters: + - in: query + name: agent_id + required: true + schema: + type: string + - in: query + name: session_id + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + oneOf: + - items: + type: string + type: array + - type: 'null' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Session' + description: OK + tags: + - AgenticSystem + /agentic_system/turn/create: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/AgenticSystemTurnCreateRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/AgenticSystemTurnResponseStreamChunk' + description: OK + tags: + - AgenticSystem + /agentic_system/turn/get: + get: + parameters: + - in: query + name: agent_id + required: true + schema: + type: string + - in: query + name: turn_id + required: true + schema: + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Turn' + description: OK + tags: + - AgenticSystem + /datasets/create: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/CreateDatasetRequest' + required: true + responses: + '200': + description: OK + tags: + - Datasets + /datasets/delete: + delete: + parameters: + - in: query + name: dataset_uuid + required: true + schema: + type: string + responses: + '200': + description: OK + tags: + - Datasets + /datasets/get: + get: + parameters: + - in: query + name: dataset_uuid + required: true + schema: + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/TrainEvalDataset' + description: OK + tags: + - Datasets + /evaluate/job/artifacts: + get: + parameters: + - in: query + name: job_uuid + required: true + schema: + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/EvaluationJobArtifactsResponse' + description: OK + tags: + - Evaluations + /evaluate/job/cancel: + get: + parameters: + - in: query + name: job_uuid + required: true + schema: + type: string + responses: + '200': + description: OK + tags: + - Evaluations + /evaluate/job/logs: + get: + parameters: + - in: query + name: job_uuid + required: true + schema: + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/EvaluationJobLogStream' + description: OK + tags: + - Evaluations + /evaluate/job/status: + get: + parameters: + - in: query + name: job_uuid + required: true + schema: + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/EvaluationJobStatusResponse' + description: OK + tags: + - Evaluations + /evaluate/jobs: + get: + parameters: [] + responses: + '200': + content: + application/jsonl: + schema: + $ref: '#/components/schemas/EvaluationJob' + description: OK + tags: + - Evaluations + /evaluate/question_answering/: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/EvaluateQuestionAnsweringRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/EvaluationJob' + description: OK + tags: + - Evaluations + /evaluate/summarization/: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/EvaluateSummarizationRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/EvaluationJob' + description: OK + tags: + - Evaluations + /evaluate/text_generation/: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/EvaluateTextGenerationRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/EvaluationJob' + description: OK + tags: + - Evaluations + /inference/batch_chat_completion: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/BatchChatCompletionRequest' + required: true + responses: + '200': + content: + application/jsonl: + schema: + $ref: '#/components/schemas/ChatCompletionResponse' + description: OK + tags: + - ModelInference + /inference/batch_completion: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/BatchCompletionRequest' + required: true + responses: + '200': + content: + application/jsonl: + schema: + $ref: '#/components/schemas/CompletionResponse' + description: OK + tags: + - ModelInference + /inference/chat_completion: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ChatCompletionRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/ChatCompletionResponseStreamChunk' + description: SSE-stream of these events. + tags: + - ModelInference + /inference/completion: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/CompletionRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/CompletionResponseStreamChunk' + description: streamed completion response. + tags: + - ModelInference + /memory_bank/delete: + post: + parameters: + - in: query + name: bank_id + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + items: + type: string + type: array + required: true + responses: + '200': + content: + application/jsonl: + schema: + type: string + description: OK + tags: + - MemoryBanks + /memory_bank/get: + post: + parameters: + - in: query + name: bank_id + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + items: + type: string + type: array + required: true + responses: + '200': + content: + application/jsonl: + schema: + $ref: '#/components/schemas/MemoryBankDocument' + description: OK + tags: + - MemoryBanks + /memory_bank/insert: + post: + parameters: + - in: query + name: bank_id + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + items: + $ref: '#/components/schemas/MemoryBankDocument' + type: array + required: true + responses: + '200': + description: OK + tags: + - MemoryBanks + /memory_bank/update: + post: + parameters: + - in: query + name: bank_id + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + items: + $ref: '#/components/schemas/MemoryBankDocument' + type: array + required: true + responses: + '200': + description: OK + tags: + - MemoryBanks + /memory_banks/create: + post: + parameters: + - in: query + name: bank_id + required: true + schema: + type: string + - in: query + name: bank_name + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + items: + $ref: '#/components/schemas/MemoryBankDocument' + type: array + required: true + responses: + '200': + description: OK + tags: + - MemoryBanks + /memory_banks/drop: + delete: + parameters: + - in: query + name: bank_id + required: true + schema: + type: string + responses: + '200': + content: + application/json: + schema: + type: string + description: OK + tags: + - MemoryBanks + /memory_banks/get: + get: + parameters: + - in: query + name: bank_id + required: true + schema: + type: string + responses: + '200': + content: + application/jsonl: + schema: + $ref: '#/components/schemas/MemoryBank' + description: OK + tags: + - MemoryBanks + /memory_banks/list: + get: + parameters: [] + responses: + '200': + content: + application/jsonl: + schema: + $ref: '#/components/schemas/MemoryBank' + description: OK + tags: + - MemoryBanks + /post_training/job/artifacts: + get: + parameters: + - in: query + name: job_uuid + required: true + schema: + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/PostTrainingJobArtifactsResponse' + description: OK + tags: + - PostTraining + /post_training/job/cancel: + get: + parameters: + - in: query + name: job_uuid + required: true + schema: + type: string + responses: + '200': + description: OK + tags: + - PostTraining + /post_training/job/logs: + get: + parameters: + - in: query + name: job_uuid + required: true + schema: + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/PostTrainingJobLogStream' + description: OK + tags: + - PostTraining + /post_training/job/status: + get: + parameters: + - in: query + name: job_uuid + required: true + schema: + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/PostTrainingJobStatusResponse' + description: OK + tags: + - PostTraining + /post_training/jobs: + get: + parameters: [] + responses: + '200': + content: + application/jsonl: + schema: + $ref: '#/components/schemas/PostTrainingJob' + description: OK + tags: + - PostTraining + /post_training/preference_optimize: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PostTrainingRLHFRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/PostTrainingJob' + description: OK + tags: + - PostTraining + /post_training/supervised_fine_tune: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PostTrainingSFTRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/PostTrainingJob' + description: OK + tags: + - PostTraining + /reward_scoring/score: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RewardScoringRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/RewardScoringResponse' + description: OK + tags: + - RewardScoring + /synthetic_data_generation/generate: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/SyntheticDataGenerationRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/SyntheticDataGenerationResponse' + description: OK + tags: + - SyntheticDataGeneration +security: +- Default: [] +servers: +- url: http://any-hosted-llama-stack.com +tags: +- name: AgenticSystem +- name: Datasets +- name: ModelInference +- name: SyntheticDataGeneration +- name: MemoryBanks +- name: PostTraining +- name: Evaluations +- name: RewardScoring +- description: + name: Attachment +- description: + name: BatchChatCompletionRequest +- description: + name: Bf16QuantizationConfig +- description: + name: CompletionMessage +- description: + name: Fp8QuantizationConfig +- description: + name: InstructModel +- description: + name: SamplingParams +- description: + name: SamplingStrategy +- description: + name: StopReason +- description: + name: SystemMessage +- description: + name: ToolCall +- description: + name: ToolDefinition +- description: + name: ToolParamDefinition +- description: + name: ToolResponseMessage +- description: + name: URL +- description: + name: UserMessage +- description: + name: ChatCompletionResponse +- description: + name: TokenLogProbs +- description: + name: BatchCompletionRequest +- description: + name: PretrainedModel +- description: + name: CompletionResponse +- description: + name: ChatCompletionRequest +- description: 'Chat completion response event. + + + ' + name: ChatCompletionResponseEvent +- description: + name: ChatCompletionResponseEventType +- description: 'SSE-stream of these events. + + + ' + name: ChatCompletionResponseStreamChunk +- description: + name: ToolCallDelta +- description: + name: ToolCallParseStatus +- description: + name: CompletionRequest +- description: 'streamed completion response. + + + ' + name: CompletionResponseStreamChunk +- description: + name: AgenticSystemCreateRequest +- description: + name: AgenticSystemInstanceConfig +- description: + name: AgenticSystemToolDefinition +- description: + name: BuiltinShield +- description: + name: OnViolationAction +- description: + name: RestAPIExecutionConfig +- description: + name: RestAPIMethod +- description: + name: ShieldDefinition +- description: + name: AgenticSystemCreateResponse +- description: + name: AgenticSystemSessionCreateRequest +- description: + name: AgenticSystemSessionCreateResponse +- description: + name: AgenticSystemTurnCreateRequest +- description: 'Streamed agent execution response. + + + ' + name: AgenticSystemTurnResponseEvent +- description: + name: AgenticSystemTurnResponseStepCompletePayload +- description: + name: AgenticSystemTurnResponseStepProgressPayload +- description: + name: AgenticSystemTurnResponseStepStartPayload +- description: + name: AgenticSystemTurnResponseStreamChunk +- description: + name: AgenticSystemTurnResponseTurnCompletePayload +- description: + name: AgenticSystemTurnResponseTurnStartPayload +- description: + name: MemoryBankDocument +- description: + name: MemoryRetrievalStep +- description: + name: ModelInferenceStep +- description: + name: ShieldCallStep +- description: + name: ShieldResponse +- description: + name: ToolExecutionStep +- description: + name: ToolResponse +- description: 'A single turn in an interaction with an Agentic System. + + + ' + name: Turn +- description: 'Request to create a dataset. + + + ' + name: CreateDatasetRequest +- description: 'Dataset to be used for training or evaluating language models. + + + ' + name: TrainEvalDataset +- description: + name: TrainEvalDatasetColumnType +- description: 'A single session of an interaction with an Agentic System. + + + ' + name: Session +- description: 'Artifacts of a evaluation job. + + + ' + name: EvaluationJobArtifactsResponse +- description: + name: EvaluationJobLogStream +- description: + name: EvaluationJobStatusResponse +- description: + name: EvaluationJob +- description: + name: MemoryBank +- description: 'Artifacts of a finetuning job. + + + ' + name: PostTrainingJobArtifactsResponse +- description: 'Stream of logs from a finetuning job. + + + ' + name: PostTrainingJobLogStream +- description: + name: PostTrainingJobStatus +- description: 'Status of a finetuning job. + + + ' + name: PostTrainingJobStatusResponse +- description: + name: PostTrainingJob +- description: 'Request to evaluate question answering. + + + ' + name: EvaluateQuestionAnsweringRequest +- description: 'Request to evaluate summarization. + + + ' + name: EvaluateSummarizationRequest +- description: 'Request to evaluate text generation. + + + ' + name: EvaluateTextGenerationRequest +- description: + name: RewardModel +- description: 'Request to generate synthetic data. A small batch of prompts and a + filtering function + + + ' + name: SyntheticDataGenerationRequest +- description: + name: ScoredDialogGenerations +- description: + name: ScoredMessage +- description: 'Response from the synthetic data generation. Batch of (prompt, response, + score) tuples that pass the threshold. + + + ' + name: SyntheticDataGenerationResponse +- description: + name: DPOAlignmentConfig +- description: + name: OptimizerConfig +- description: 'Request to finetune a model. + + + ' + name: PostTrainingRLHFRequest +- description: + name: RLHFAlgorithm +- description: + name: TrainingConfig +- description: + name: DialogGenerations +- description: 'Request to score a reward function. A list of prompts and a list of + responses per prompt. + + + ' + name: RewardScoringRequest +- description: 'Response from the reward scoring. Batch of (prompt, response, score) + tuples that pass the threshold. + + + ' + name: RewardScoringResponse +- description: + name: DoraFinetuningConfig +- description: + name: FinetuningAlgorithm +- description: + name: LoraFinetuningConfig +- description: 'Request to finetune a model. + + + ' + name: PostTrainingSFTRequest +- description: + name: QLoraFinetuningConfig +x-tagGroups: +- name: Operations + tags: + - AgenticSystem + - Datasets + - Evaluations + - MemoryBanks + - ModelInference + - PostTraining + - RewardScoring + - SyntheticDataGeneration +- name: Types + tags: + - AgenticSystemCreateRequest + - AgenticSystemCreateResponse + - AgenticSystemInstanceConfig + - AgenticSystemSessionCreateRequest + - AgenticSystemSessionCreateResponse + - AgenticSystemToolDefinition + - AgenticSystemTurnCreateRequest + - AgenticSystemTurnResponseEvent + - AgenticSystemTurnResponseStepCompletePayload + - AgenticSystemTurnResponseStepProgressPayload + - AgenticSystemTurnResponseStepStartPayload + - AgenticSystemTurnResponseStreamChunk + - AgenticSystemTurnResponseTurnCompletePayload + - AgenticSystemTurnResponseTurnStartPayload + - Attachment + - BatchChatCompletionRequest + - BatchCompletionRequest + - Bf16QuantizationConfig + - BuiltinShield + - ChatCompletionRequest + - ChatCompletionResponse + - ChatCompletionResponseEvent + - ChatCompletionResponseEventType + - ChatCompletionResponseStreamChunk + - CompletionMessage + - CompletionRequest + - CompletionResponse + - CompletionResponseStreamChunk + - CreateDatasetRequest + - DPOAlignmentConfig + - DialogGenerations + - DoraFinetuningConfig + - EvaluateQuestionAnsweringRequest + - EvaluateSummarizationRequest + - EvaluateTextGenerationRequest + - EvaluationJob + - EvaluationJobArtifactsResponse + - EvaluationJobLogStream + - EvaluationJobStatusResponse + - FinetuningAlgorithm + - Fp8QuantizationConfig + - InstructModel + - LoraFinetuningConfig + - MemoryBank + - MemoryBankDocument + - MemoryRetrievalStep + - ModelInferenceStep + - OnViolationAction + - OptimizerConfig + - PostTrainingJob + - PostTrainingJobArtifactsResponse + - PostTrainingJobLogStream + - PostTrainingJobStatus + - PostTrainingJobStatusResponse + - PostTrainingRLHFRequest + - PostTrainingSFTRequest + - PretrainedModel + - QLoraFinetuningConfig + - RLHFAlgorithm + - RestAPIExecutionConfig + - RestAPIMethod + - RewardModel + - RewardScoringRequest + - RewardScoringResponse + - SamplingParams + - SamplingStrategy + - ScoredDialogGenerations + - ScoredMessage + - Session + - ShieldCallStep + - ShieldDefinition + - ShieldResponse + - StopReason + - SyntheticDataGenerationRequest + - SyntheticDataGenerationResponse + - SystemMessage + - TokenLogProbs + - ToolCall + - ToolCallDelta + - ToolCallParseStatus + - ToolDefinition + - ToolExecutionStep + - ToolParamDefinition + - ToolResponse + - ToolResponseMessage + - TrainEvalDataset + - TrainEvalDatasetColumnType + - TrainingConfig + - Turn + - URL + - UserMessage diff --git a/toolchain/spec/package.sh b/toolchain/spec/package.sh new file mode 100644 index 000000000..856d619ba --- /dev/null +++ b/toolchain/spec/package.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +set -euo pipefail + +TMPDIR=$(mktemp -d) +echo "Using temporary directory: $TMPDIR" + +rootdir=$(git rev-parse --show-toplevel) + +files_to_copy=("toolchain/spec/openapi*" "models/llama3/datatypes.py" "toolchain/inference/api/*.py" "agentic_system/api/*.py" "toolchain/common/*.py" "toolchain/dataset/api/*.py" "toolchain/evaluations/api/*.py" "toolchain/reward_scoring/api/*.py" "toolchain/post_training/api/*.py" "toolchain/safety/api/*.py") +for file in "${files_to_copy[@]}"; do + relpath="$file" + set -x + mkdir -p "$TMPDIR/$(dirname $relpath)" + eval cp "$rootdir/$relpath" "$TMPDIR/$(dirname $relpath)" + set +x +done + +cd "$TMPDIR" +zip -r output.zip . + +echo "Zip at: $TMPDIR/output.zip" diff --git a/toolchain/spec/post_training_types.py b/toolchain/spec/post_training_types.py new file mode 100644 index 000000000..180a38bd6 --- /dev/null +++ b/toolchain/spec/post_training_types.py @@ -0,0 +1,105 @@ +from enum import Enum +from typing import Any, Dict, List + +from models.llama3.datatypes import URL +from pydantic import BaseModel, Field + +from strong_typing.schema import json_schema_type + + +class TrainEvalDatasetColumnType(Enum): + dialog = "dialog" + text = "text" + media = "media" + number = "number" + json = "json" + + +@json_schema_type +class TrainEvalDataset(BaseModel): + """Dataset to be used for training or evaluating language models.""" + + # TODO(ashwin): figure out if we need to add an enum for a "dataset type" + + columns: Dict[str, TrainEvalDatasetColumnType] + content_url: URL + metadata: Dict[str, Any] = Field(default_factory=dict) + + +class OptimizerType(Enum): + adam = "adam" + adamw = "adamw" + sgd = "sgd" + + +@json_schema_type +class OptimizerConfig(BaseModel): + optimizer_type: OptimizerType + lr: float + lr_min: float + weight_decay: float + + +@json_schema_type +class TrainingConfig(BaseModel): + n_epochs: int + batch_size: int + shuffle: bool + n_iters: int + + enable_activation_checkpointing: bool + memory_efficient_fsdp_wrap: bool + fsdp_cpu_offload: bool + + +class FinetuningAlgorithm(Enum): + full = "full" + lora = "lora" + qlora = "qlora" + dora = "dora" + + +@json_schema_type +class LoraFinetuningConfig(BaseModel): + lora_attn_modules: List[str] + apply_lora_to_mlp: bool + apply_lora_to_output: bool + rank: int + alpha: int + + +@json_schema_type +class QLoraFinetuningConfig(LoraFinetuningConfig): + pass + + +@json_schema_type +class DoraFinetuningConfig(LoraFinetuningConfig): + pass + + +@json_schema_type +class PostTrainingJobLogStream(BaseModel): + """Stream of logs from a finetuning job.""" + + job_uuid: str + log_lines: List[str] + + +class PostTrainingJobStatus(Enum): + running = "running" + completed = "completed" + failed = "failed" + scheduled = "scheduled" + + +class RLHFAlgorithm(Enum): + dpo = "dpo" + + +@json_schema_type +class DPOAlignmentConfig(BaseModel): + reward_scale: float + reward_clip: float + epsilon: float + gamma: float diff --git a/toolchain/spec/run_openapi_generator.sh b/toolchain/spec/run_openapi_generator.sh new file mode 100644 index 000000000..7ebe54253 --- /dev/null +++ b/toolchain/spec/run_openapi_generator.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +set -x + +PYTHONPATH=../../../oss-ops:../.. python3 -m toolchain.spec.generate diff --git a/toolchain/synthetic_data_generation/api/__init__.py b/toolchain/synthetic_data_generation/api/__init__.py new file mode 100644 index 000000000..38413ff60 --- /dev/null +++ b/toolchain/synthetic_data_generation/api/__init__.py @@ -0,0 +1,2 @@ +from .datatypes import * # noqa: F401 F403 +from .endpoints import * # noqa: F401 F403 diff --git a/toolchain/synthetic_data_generation/api/datatypes.py b/toolchain/synthetic_data_generation/api/datatypes.py new file mode 100644 index 000000000..fd53a74a3 --- /dev/null +++ b/toolchain/synthetic_data_generation/api/datatypes.py @@ -0,0 +1,12 @@ +from enum import Enum + + +class FilteringFunction(Enum): + """The type of filtering function.""" + + none = "none" + random = "random" + top_k = "top_k" + top_p = "top_p" + top_k_top_p = "top_k_top_p" + sigmoid = "sigmoid" diff --git a/toolchain/synthetic_data_generation/api/endpoints.py b/toolchain/synthetic_data_generation/api/endpoints.py new file mode 100644 index 000000000..424177a3a --- /dev/null +++ b/toolchain/synthetic_data_generation/api/endpoints.py @@ -0,0 +1,35 @@ +from typing import Any, Dict, List, Optional, Protocol + +from pydantic import BaseModel + +from pyopenapi import webmethod +from strong_typing.schema import json_schema_type + +from models.llama3.datatypes import * # noqa: F403 +from toolchain.reward_scoring.api.datatypes import * # noqa: F403 +from .datatypes import * # noqa: F403 + + +@json_schema_type +class SyntheticDataGenerationRequest(BaseModel): + """Request to generate synthetic data. A small batch of prompts and a filtering function""" + + dialogs: List[Message] + filtering_function: FilteringFunction = FilteringFunction.none + model: Optional[RewardModel] = None + + +@json_schema_type +class SyntheticDataGenerationResponse(BaseModel): + """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.""" + + synthetic_data: List[ScoredDialogGenerations] + statistics: Optional[Dict[str, Any]] = None + + +class SyntheticDataGeneration(Protocol): + @webmethod(route="/synthetic_data_generation/generate") + def post_generate( + self, + request: SyntheticDataGenerationRequest, + ) -> Union[SyntheticDataGenerationResponse]: ... diff --git a/toolchain/utils.py b/toolchain/utils.py new file mode 100644 index 000000000..b8c91f529 --- /dev/null +++ b/toolchain/utils.py @@ -0,0 +1,55 @@ +import getpass +import os +from typing import Optional + +from hydra import compose, initialize, MissingConfigException +from hydra.core.global_hydra import GlobalHydra + +from omegaconf import OmegaConf + + +def get_root_directory(): + current_dir = os.path.dirname(os.path.abspath(__file__)) + while os.path.isfile(os.path.join(current_dir, "__init__.py")): + current_dir = os.path.dirname(current_dir) + + return current_dir + + +def get_config_dir(): + return os.path.join(get_root_directory(), "toolchain", "configs") + + +def parse_config(config_dir: str, config_path: Optional[str] = None) -> str: + # Configs can be + # 1. relative paths in {config_dir}/ + # 2. or default to file {config_dir}/{user}.yaml + # 3. or ultimate default to {config_dir}/default.yaml + + # Get the relative path from the current file to the config directory + current_file_directory = os.path.dirname(os.path.abspath(__file__)) + relative_path = os.path.relpath(config_dir, current_file_directory) + + GlobalHydra.instance().clear() + initialize(config_path=relative_path) + + if config_path is None: + try: + user = getpass.getuser() + config_name = user + except MissingConfigException: + print(f"No user-specific {user}.yaml, using default") + config_name = "default" + else: + config_name = config_path + + config_abs_path = os.path.abspath(os.path.join(config_dir, f"{config_name}.yaml")) + print(f"Loading config from : {config_abs_path}") + config = compose(config_name=config_name) + + print("Yaml config:") + print("------------------------") + print(OmegaConf.to_yaml(config, resolve=True)) + print("------------------------") + + return config