From 2ed2881a21cb7f62ec04a2c4c6736e2d8ceacc65 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Fri, 19 Jul 2024 17:42:14 -0700 Subject: [PATCH] fixed imports models.llama3. --> models.llama3_1.api. --- toolchain/common/deployment_types.py | 2 +- toolchain/common/training_types.py | 2 +- toolchain/dataset/api/datatypes.py | 2 +- toolchain/evaluations/api/endpoints.py | 2 +- toolchain/inference/api/datatypes.py | 2 +- toolchain/inference/generation.py | 10 +++++----- toolchain/inference/inference.py | 2 +- toolchain/inference/model_parallel.py | 6 +++--- toolchain/post_training/api/endpoints.py | 2 +- toolchain/reward_scoring/api/datatypes.py | 2 +- toolchain/safety/api/datatypes.py | 2 +- toolchain/safety/shields/base.py | 2 +- toolchain/safety/shields/contrib/third_party_shield.py | 2 +- toolchain/safety/shields/llama_guard.py | 2 +- toolchain/safety/shields/prompt_guard.py | 2 +- toolchain/safety/shields/shield_runner.py | 2 +- toolchain/spec/generate.py | 2 +- toolchain/spec/package.sh | 2 +- toolchain/spec/post_training_types.py | 2 +- toolchain/synthetic_data_generation/api/endpoints.py | 2 +- 20 files changed, 26 insertions(+), 26 deletions(-) diff --git a/toolchain/common/deployment_types.py b/toolchain/common/deployment_types.py index 88831f41c..3a5d8bf4d 100644 --- a/toolchain/common/deployment_types.py +++ b/toolchain/common/deployment_types.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Dict, Optional -from models.llama3.datatypes import URL +from models.llama3_1.api.datatypes import URL from pydantic import BaseModel diff --git a/toolchain/common/training_types.py b/toolchain/common/training_types.py index d5e756bba..8c521ed02 100644 --- a/toolchain/common/training_types.py +++ b/toolchain/common/training_types.py @@ -1,4 +1,4 @@ -from models.llama3.datatypes import URL +from models.llama3_1.api.datatypes import URL from pydantic import BaseModel diff --git a/toolchain/dataset/api/datatypes.py b/toolchain/dataset/api/datatypes.py index 4fed2a0a3..c0ec21e20 100644 --- a/toolchain/dataset/api/datatypes.py +++ b/toolchain/dataset/api/datatypes.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Optional from pydantic import BaseModel from strong_typing.schema import json_schema_type -from models.llama3.datatypes import URL +from models.llama3_1.api.datatypes import URL @json_schema_type diff --git a/toolchain/evaluations/api/endpoints.py b/toolchain/evaluations/api/endpoints.py index 7fef5d9da..b33594970 100644 --- a/toolchain/evaluations/api/endpoints.py +++ b/toolchain/evaluations/api/endpoints.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from pyopenapi import webmethod -from models.llama3.datatypes import * # noqa: F403 +from models.llama3_1.api.datatypes import * # noqa: F403 from .datatypes import * # noqa: F403 from toolchain.dataset.api.datatypes import * # noqa: F403 from toolchain.common.training_types import * # noqa: F403 diff --git a/toolchain/inference/api/datatypes.py b/toolchain/inference/api/datatypes.py index ef54d869d..cd7c8a432 100644 --- a/toolchain/inference/api/datatypes.py +++ b/toolchain/inference/api/datatypes.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field from strong_typing.schema import json_schema_type from typing_extensions import Annotated -from models.llama3.datatypes import * # noqa: F403 +from models.llama3_1.api.datatypes import * # noqa: F403 class LogProbConfig(BaseModel): diff --git a/toolchain/inference/generation.py b/toolchain/inference/generation.py index 1a067071b..d9febf19e 100644 --- a/toolchain/inference/generation.py +++ b/toolchain/inference/generation.py @@ -16,11 +16,11 @@ from fairscale.nn.model_parallel.initialize import ( initialize_model_parallel, model_parallel_is_initialized, ) -from models.llama3.args import ModelArgs -from models.llama3.chat_format import ChatFormat, ModelInput -from models.llama3.datatypes import Message -from models.llama3.model import Transformer -from models.llama3.tokenizer import Tokenizer +from models.llama3_1.api.args import ModelArgs +from models.llama3_1.api.chat_format import ChatFormat, ModelInput +from models.llama3_1.api.datatypes import Message +from models.llama3_1.api.model import Transformer +from models.llama3_1.api.tokenizer import Tokenizer from termcolor import cprint diff --git a/toolchain/inference/inference.py b/toolchain/inference/inference.py index 7ef32e93b..5a117eb09 100644 --- a/toolchain/inference/inference.py +++ b/toolchain/inference/inference.py @@ -1,6 +1,6 @@ from typing import AsyncGenerator -from models.llama3.datatypes import StopReason +from models.llama3_1.api.datatypes import StopReason from .api.config import CheckpointType, GeneratorArgs, InlineImplConfig from .api.datatypes import ( diff --git a/toolchain/inference/model_parallel.py b/toolchain/inference/model_parallel.py index 8be991681..2a7fcf781 100644 --- a/toolchain/inference/model_parallel.py +++ b/toolchain/inference/model_parallel.py @@ -2,9 +2,9 @@ from dataclasses import dataclass from functools import partial from typing import Generator, List, Optional -from models.llama3.chat_format import ChatFormat -from models.llama3.datatypes import Message -from models.llama3.tokenizer import Tokenizer +from models.llama3_1.api.chat_format import ChatFormat +from models.llama3_1.api.datatypes import Message +from models.llama3_1.api.tokenizer import Tokenizer from .api.config import GeneratorArgs from .generation import Llama diff --git a/toolchain/post_training/api/endpoints.py b/toolchain/post_training/api/endpoints.py index 8bcaafc3b..0114bb296 100644 --- a/toolchain/post_training/api/endpoints.py +++ b/toolchain/post_training/api/endpoints.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, Field from pyopenapi import webmethod from strong_typing.schema import json_schema_type -from models.llama3.datatypes import * # noqa: F403 +from models.llama3_1.api.datatypes import * # noqa: F403 from toolchain.dataset.api.datatypes import * # noqa: F403 from toolchain.common.training_types import * # noqa: F403 from .datatypes import * # noqa: F403 diff --git a/toolchain/reward_scoring/api/datatypes.py b/toolchain/reward_scoring/api/datatypes.py index c7acdb1a3..36041d61d 100644 --- a/toolchain/reward_scoring/api/datatypes.py +++ b/toolchain/reward_scoring/api/datatypes.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from strong_typing.schema import json_schema_type -from models.llama3.datatypes import * # noqa: F403 +from models.llama3_1.api.datatypes import * # noqa: F403 @json_schema_type diff --git a/toolchain/safety/api/datatypes.py b/toolchain/safety/api/datatypes.py index 45866d026..72541d10f 100644 --- a/toolchain/safety/api/datatypes.py +++ b/toolchain/safety/api/datatypes.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Dict, Optional, Union -from models.llama3.datatypes import ToolParamDefinition +from models.llama3_1.api.datatypes import ToolParamDefinition from pydantic import BaseModel diff --git a/toolchain/safety/shields/base.py b/toolchain/safety/shields/base.py index dc7a04879..dab64f6f9 100644 --- a/toolchain/safety/shields/base.py +++ b/toolchain/safety/shields/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import List, Union -from models.llama3.datatypes import Attachment, Message +from models.llama3_1.api.datatypes import Attachment, Message from toolchain.safety.api.datatypes import * # noqa: F403 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" diff --git a/toolchain/safety/shields/contrib/third_party_shield.py b/toolchain/safety/shields/contrib/third_party_shield.py index 13ee9556b..d783a02b5 100644 --- a/toolchain/safety/shields/contrib/third_party_shield.py +++ b/toolchain/safety/shields/contrib/third_party_shield.py @@ -1,7 +1,7 @@ import sys from typing import List -from models.llama3.datatypes import Message +from models.llama3_1.api.datatypes import Message parent_dir = "../.." sys.path.append(parent_dir) diff --git a/toolchain/safety/shields/llama_guard.py b/toolchain/safety/shields/llama_guard.py index 4d8f7c95b..d8ae79ed8 100644 --- a/toolchain/safety/shields/llama_guard.py +++ b/toolchain/safety/shields/llama_guard.py @@ -4,7 +4,7 @@ from string import Template from typing import List, Optional import torch -from models.llama3.datatypes import Message +from models.llama3_1.api.datatypes import Message from termcolor import cprint from transformers import AutoModelForCausalLM, AutoTokenizer diff --git a/toolchain/safety/shields/prompt_guard.py b/toolchain/safety/shields/prompt_guard.py index e5082ff8c..4cab03b36 100644 --- a/toolchain/safety/shields/prompt_guard.py +++ b/toolchain/safety/shields/prompt_guard.py @@ -3,7 +3,7 @@ from typing import List import torch -from models.llama3.datatypes import Message +from models.llama3_1.api.datatypes import Message from termcolor import cprint from transformers import AutoModelForSequenceClassification, AutoTokenizer diff --git a/toolchain/safety/shields/shield_runner.py b/toolchain/safety/shields/shield_runner.py index cb0c23302..b775092b7 100644 --- a/toolchain/safety/shields/shield_runner.py +++ b/toolchain/safety/shields/shield_runner.py @@ -1,7 +1,7 @@ import asyncio from typing import List -from models.llama3.datatypes import Message, Role +from models.llama3_1.api.datatypes import Message, Role from .base import OnViolationAction, ShieldBase, ShieldResponse diff --git a/toolchain/spec/generate.py b/toolchain/spec/generate.py index 8afb77cf1..5b4bd9e04 100644 --- a/toolchain/spec/generate.py +++ b/toolchain/spec/generate.py @@ -4,7 +4,7 @@ import yaml from pyopenapi import Info, Options, Server, Specification -from models.llama3.datatypes import * # noqa: F403 +from models.llama3_1.api.datatypes import * # noqa: F403 from toolchain.dataset.api import * # noqa: F403 from toolchain.evaluations.api import * # noqa: F403 from toolchain.inference.api import * # noqa: F403 diff --git a/toolchain/spec/package.sh b/toolchain/spec/package.sh index 856d619ba..a5272cef2 100644 --- a/toolchain/spec/package.sh +++ b/toolchain/spec/package.sh @@ -7,7 +7,7 @@ echo "Using temporary directory: $TMPDIR" rootdir=$(git rev-parse --show-toplevel) -files_to_copy=("toolchain/spec/openapi*" "models/llama3/datatypes.py" "toolchain/inference/api/*.py" "agentic_system/api/*.py" "toolchain/common/*.py" "toolchain/dataset/api/*.py" "toolchain/evaluations/api/*.py" "toolchain/reward_scoring/api/*.py" "toolchain/post_training/api/*.py" "toolchain/safety/api/*.py") +files_to_copy=("toolchain/spec/openapi*" "models.llama3_1.api.datatypes.py" "toolchain/inference/api/*.py" "agentic_system/api/*.py" "toolchain/common/*.py" "toolchain/dataset/api/*.py" "toolchain/evaluations/api/*.py" "toolchain/reward_scoring/api/*.py" "toolchain/post_training/api/*.py" "toolchain/safety/api/*.py") for file in "${files_to_copy[@]}"; do relpath="$file" set -x diff --git a/toolchain/spec/post_training_types.py b/toolchain/spec/post_training_types.py index 180a38bd6..6d55d1537 100644 --- a/toolchain/spec/post_training_types.py +++ b/toolchain/spec/post_training_types.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Any, Dict, List -from models.llama3.datatypes import URL +from models.llama3_1.api.datatypes import URL from pydantic import BaseModel, Field from strong_typing.schema import json_schema_type diff --git a/toolchain/synthetic_data_generation/api/endpoints.py b/toolchain/synthetic_data_generation/api/endpoints.py index 424177a3a..82c71fb7e 100644 --- a/toolchain/synthetic_data_generation/api/endpoints.py +++ b/toolchain/synthetic_data_generation/api/endpoints.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from pyopenapi import webmethod from strong_typing.schema import json_schema_type -from models.llama3.datatypes import * # noqa: F403 +from models.llama3_1.api.datatypes import * # noqa: F403 from toolchain.reward_scoring.api.datatypes import * # noqa: F403 from .datatypes import * # noqa: F403