From c64b8cba22a248268a9c7a5e77532ebd5783e829 Mon Sep 17 00:00:00 2001
From: Hardik Shah
Date: Sun, 21 Jul 2024 19:07:02 -0700
Subject: [PATCH] from models.llama3_1 --> from llama_models.llama3_1

---
 requirements.txt | 2 ++
 toolchain/common/deployment_types.py | 2 +-
 toolchain/common/training_types.py | 2 +-
 toolchain/dataset/api/datatypes.py | 2 +-
 toolchain/evaluations/api/endpoints.py | 2 +-
 toolchain/inference/api/datatypes.py | 2 +-
 toolchain/inference/generation.py | 10 +++++-----
 toolchain/inference/inference.py | 2 +-
 toolchain/inference/model_parallel.py | 6 +++---
 toolchain/inference/quantization/loader.py | 2 +-
 toolchain/post_training/api/endpoints.py | 2 +-
 toolchain/reward_scoring/api/datatypes.py | 2 +-
 toolchain/safety/api/datatypes.py | 2 +-
 toolchain/safety/shields/base.py | 2 +-
 toolchain/safety/shields/contrib/third_party_shield.py | 2 +-
 toolchain/safety/shields/llama_guard.py | 2 +-
 toolchain/safety/shields/prompt_guard.py | 2 +-
 toolchain/safety/shields/shield_runner.py | 2 +-
 toolchain/spec/generate.py | 2 +-
 toolchain/spec/package.sh | 2 +-
 toolchain/spec/post_training_types.py | 2 +-
 toolchain/synthetic_data_generation/api/endpoints.py | 2 +-
 22 files changed, 29 insertions(+), 27 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 0d7081365..514c3ffa0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,3 +27,5 @@ ufmt==2.7.0
 usort==1.0.8
 uvicorn
 zmq
+
+llama_models[llama3_1] @ git+https://github.com/meta-llama/llama-models.git
diff --git a/toolchain/common/deployment_types.py b/toolchain/common/deployment_types.py
index 3a5d8bf4d..5abd7d991 100644
--- a/toolchain/common/deployment_types.py
+++ b/toolchain/common/deployment_types.py
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import Dict, Optional
 
-from models.llama3_1.api.datatypes import URL
+from llama_models.llama3_1.api.datatypes import URL
 from pydantic import BaseModel
diff --git a/toolchain/common/training_types.py b/toolchain/common/training_types.py
index 8c521ed02..c500bc91c 100644
--- a/toolchain/common/training_types.py
+++ b/toolchain/common/training_types.py
@@ -1,4 +1,4 @@
-from models.llama3_1.api.datatypes import URL
+from llama_models.llama3_1.api.datatypes import URL
 from pydantic import BaseModel
diff --git a/toolchain/dataset/api/datatypes.py b/toolchain/dataset/api/datatypes.py
index cd4dd3726..260e68acb 100644
--- a/toolchain/dataset/api/datatypes.py
+++ b/toolchain/dataset/api/datatypes.py
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import Any, Dict, Optional
 
-from models.llama3_1.api.datatypes import URL
+from llama_models.llama3_1.api.datatypes import URL
 from pydantic import BaseModel
diff --git a/toolchain/evaluations/api/endpoints.py b/toolchain/evaluations/api/endpoints.py
index ff69baa02..bf3012635 100644
--- a/toolchain/evaluations/api/endpoints.py
+++ b/toolchain/evaluations/api/endpoints.py
@@ -4,7 +4,7 @@ from pydantic import BaseModel
 from pyopenapi import webmethod
 
-from models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3_1.api.datatypes import *  # noqa: F403
 from .datatypes import *  # noqa: F403
 from toolchain.dataset.api.datatypes import *  # noqa: F403
 from toolchain.common.training_types import *  # noqa: F403
diff --git a/toolchain/inference/api/datatypes.py b/toolchain/inference/api/datatypes.py
index 90b9dfe73..3141a108e 100644
--- a/toolchain/inference/api/datatypes.py
+++ b/toolchain/inference/api/datatypes.py
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
 from strong_typing.schema import json_schema_type
 from typing_extensions import Annotated
 
-from models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3_1.api.datatypes import *  # noqa: F403
 
 
 class LogProbConfig(BaseModel):
diff --git a/toolchain/inference/generation.py b/toolchain/inference/generation.py
index f714760ec..968c0e4d7 100644
--- a/toolchain/inference/generation.py
+++ b/toolchain/inference/generation.py
@@ -16,11 +16,11 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from models.llama3_1.api.args import ModelArgs
-from models.llama3_1.api.chat_format import ChatFormat, ModelInput
-from models.llama3_1.api.datatypes import Message
-from models.llama3_1.api.model import Transformer
-from models.llama3_1.api.tokenizer import Tokenizer
+from llama_models.llama3_1.api.args import ModelArgs
+from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
+from llama_models.llama3_1.api.datatypes import Message
+from llama_models.llama3_1.api.model import Transformer
+from llama_models.llama3_1.api.tokenizer import Tokenizer
 from termcolor import cprint
 
 from .api.config import CheckpointType, InlineImplConfig
diff --git a/toolchain/inference/inference.py b/toolchain/inference/inference.py
index 94228ac7b..48d15cea1 100644
--- a/toolchain/inference/inference.py
+++ b/toolchain/inference/inference.py
@@ -1,6 +1,6 @@
 from typing import AsyncGenerator
 
-from models.llama3_1.api.datatypes import StopReason
+from llama_models.llama3_1.api.datatypes import StopReason
 
 from .api.config import (
     CheckpointQuantizationFormat,
diff --git a/toolchain/inference/model_parallel.py b/toolchain/inference/model_parallel.py
index 2ffbe2fb0..2d9737a9c 100644
--- a/toolchain/inference/model_parallel.py
+++ b/toolchain/inference/model_parallel.py
@@ -2,9 +2,9 @@ from dataclasses import dataclass
 from functools import partial
 from typing import Generator, List, Optional
 
-from models.llama3_1.api.chat_format import ChatFormat
-from models.llama3_1.api.datatypes import Message
-from models.llama3_1.api.tokenizer import Tokenizer
+from llama_models.llama3_1.api.chat_format import ChatFormat
+from llama_models.llama3_1.api.datatypes import Message
+from llama_models.llama3_1.api.tokenizer import Tokenizer
 
 from .api.config import InlineImplConfig
 from .generation import Llama
diff --git a/toolchain/inference/quantization/loader.py b/toolchain/inference/quantization/loader.py
index 66b6b2ecc..d6d9fc89f 100644
--- a/toolchain/inference/quantization/loader.py
+++ b/toolchain/inference/quantization/loader.py
@@ -7,7 +7,7 @@ from typing import Optional
 
 import torch
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-from models.llama3_1.api.model import Transformer, TransformerBlock
+from llama_models.llama3_1.api.model import Transformer, TransformerBlock
 from termcolor import cprint
diff --git a/toolchain/post_training/api/endpoints.py b/toolchain/post_training/api/endpoints.py
index 0114bb296..443a8027f 100644
--- a/toolchain/post_training/api/endpoints.py
+++ b/toolchain/post_training/api/endpoints.py
@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
 from pyopenapi import webmethod
 from strong_typing.schema import json_schema_type
 
-from models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3_1.api.datatypes import *  # noqa: F403
 from toolchain.dataset.api.datatypes import *  # noqa: F403
 from toolchain.common.training_types import *  # noqa: F403
 from .datatypes import *  # noqa: F403
diff --git a/toolchain/reward_scoring/api/datatypes.py b/toolchain/reward_scoring/api/datatypes.py
index 36041d61d..f53d22861 100644
--- a/toolchain/reward_scoring/api/datatypes.py
+++ b/toolchain/reward_scoring/api/datatypes.py
@@ -4,7 +4,7 @@ from pydantic import BaseModel
 from strong_typing.schema import json_schema_type
 
-from models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3_1.api.datatypes import *  # noqa: F403
 
 
 @json_schema_type
diff --git a/toolchain/safety/api/datatypes.py b/toolchain/safety/api/datatypes.py
index 51a8f88c3..fcb665aa3 100644
--- a/toolchain/safety/api/datatypes.py
+++ b/toolchain/safety/api/datatypes.py
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import Dict, Optional, Union
 
-from models.llama3_1.api.datatypes import ToolParamDefinition
+from llama_models.llama3_1.api.datatypes import ToolParamDefinition
 from pydantic import BaseModel
diff --git a/toolchain/safety/shields/base.py b/toolchain/safety/shields/base.py
index dab64f6f9..33b5ba94c 100644
--- a/toolchain/safety/shields/base.py
+++ b/toolchain/safety/shields/base.py
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import List, Union
 
-from models.llama3_1.api.datatypes import Attachment, Message
+from llama_models.llama3_1.api.datatypes import Attachment, Message
 from toolchain.safety.api.datatypes import *  # noqa: F403
 
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
diff --git a/toolchain/safety/shields/contrib/third_party_shield.py b/toolchain/safety/shields/contrib/third_party_shield.py
index d783a02b5..48e7414c5 100644
--- a/toolchain/safety/shields/contrib/third_party_shield.py
+++ b/toolchain/safety/shields/contrib/third_party_shield.py
@@ -1,7 +1,7 @@
 import sys
 from typing import List
 
-from models.llama3_1.api.datatypes import Message
+from llama_models.llama3_1.api.datatypes import Message
 
 parent_dir = "../.."
 sys.path.append(parent_dir)
diff --git a/toolchain/safety/shields/llama_guard.py b/toolchain/safety/shields/llama_guard.py
index d8ae79ed8..8a6866601 100644
--- a/toolchain/safety/shields/llama_guard.py
+++ b/toolchain/safety/shields/llama_guard.py
@@ -4,7 +4,7 @@ from string import Template
 from typing import List, Optional
 
 import torch
-from models.llama3_1.api.datatypes import Message
+from llama_models.llama3_1.api.datatypes import Message
 from termcolor import cprint
 from transformers import AutoModelForCausalLM, AutoTokenizer
diff --git a/toolchain/safety/shields/prompt_guard.py b/toolchain/safety/shields/prompt_guard.py
index f67ee71c1..ddb539688 100644
--- a/toolchain/safety/shields/prompt_guard.py
+++ b/toolchain/safety/shields/prompt_guard.py
@@ -3,7 +3,7 @@ from typing import List
 
 import torch
-from models.llama3_1.api.datatypes import Message
+from llama_models.llama3_1.api.datatypes import Message
 from termcolor import cprint
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
diff --git a/toolchain/safety/shields/shield_runner.py b/toolchain/safety/shields/shield_runner.py
index b775092b7..27070d424 100644
--- a/toolchain/safety/shields/shield_runner.py
+++ b/toolchain/safety/shields/shield_runner.py
@@ -1,7 +1,7 @@
 import asyncio
 from typing import List
 
-from models.llama3_1.api.datatypes import Message, Role
+from llama_models.llama3_1.api.datatypes import Message, Role
 
 from .base import OnViolationAction, ShieldBase, ShieldResponse
diff --git a/toolchain/spec/generate.py b/toolchain/spec/generate.py
index 974885b2b..6d5952038 100644
--- a/toolchain/spec/generate.py
+++ b/toolchain/spec/generate.py
@@ -4,7 +4,7 @@ import yaml
 from pyopenapi import Info, Options, Server, Specification
 
-from models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3_1.api.datatypes import *  # noqa: F403
 from toolchain.dataset.api import *  # noqa: F403
 from toolchain.evaluations.api import *  # noqa: F403
 from toolchain.inference.api import *  # noqa: F403
diff --git a/toolchain/spec/package.sh b/toolchain/spec/package.sh
index d6f2e3fb3..854af6ffd 100644
--- a/toolchain/spec/package.sh
+++ b/toolchain/spec/package.sh
@@ -7,7 +7,7 @@ echo "Using temporary directory: $TMPDIR"
 
 rootdir=$(git rev-parse --show-toplevel)
 
-files_to_copy=("toolchain/spec/openapi*" "models.llama3_1.api.datatypes.py" "toolchain/inference/api/*.py" "agentic_system/api/*.py" "toolchain/common/*.py" "toolchain/dataset/api/*.py" "toolchain/evaluations/api/*.py" "toolchain/reward_scoring/api/*.py" "toolchain/post_training/api/*.py" "toolchain/safety/api/*.py")
+files_to_copy=("toolchain/spec/openapi*" "llama_models.llama3_1.api.datatypes.py" "toolchain/inference/api/*.py" "agentic_system/api/*.py" "toolchain/common/*.py" "toolchain/dataset/api/*.py" "toolchain/evaluations/api/*.py" "toolchain/reward_scoring/api/*.py" "toolchain/post_training/api/*.py" "toolchain/safety/api/*.py")
 for file in "${files_to_copy[@]}"; do
   relpath="$file"
   set -x
diff --git a/toolchain/spec/post_training_types.py b/toolchain/spec/post_training_types.py
index 6d55d1537..fc7d963cf 100644
--- a/toolchain/spec/post_training_types.py
+++ b/toolchain/spec/post_training_types.py
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import Any, Dict, List
 
-from models.llama3_1.api.datatypes import URL
+from llama_models.llama3_1.api.datatypes import URL
 from pydantic import BaseModel, Field
 from strong_typing.schema import json_schema_type
diff --git a/toolchain/synthetic_data_generation/api/endpoints.py b/toolchain/synthetic_data_generation/api/endpoints.py
index 82c71fb7e..9c3c5cccc 100644
--- a/toolchain/synthetic_data_generation/api/endpoints.py
+++ b/toolchain/synthetic_data_generation/api/endpoints.py
@@ -5,7 +5,7 @@ from pydantic import BaseModel
 from pyopenapi import webmethod
 from strong_typing.schema import json_schema_type
 
-from models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3_1.api.datatypes import *  # noqa: F403
 from toolchain.reward_scoring.api.datatypes import *  # noqa: F403
 from .datatypes import *  # noqa: F403