from models.llama3_1 --> from llama_models.llama3_1

This commit is contained in:
Hardik Shah 2024-07-21 19:07:02 -07:00
parent c6ef16f6bd
commit c64b8cba22
22 changed files with 29 additions and 27 deletions

View file

@ -27,3 +27,5 @@ ufmt==2.7.0
usort==1.0.8
uvicorn
zmq
llama_models[llama3_1] @ git+https://github.com/meta-llama/llama-models.git

View file

@ -1,7 +1,7 @@
from enum import Enum
from typing import Dict, Optional
from models.llama3_1.api.datatypes import URL
from llama_models.llama3_1.api.datatypes import URL
from pydantic import BaseModel

View file

@ -1,4 +1,4 @@
from models.llama3_1.api.datatypes import URL
from llama_models.llama3_1.api.datatypes import URL
from pydantic import BaseModel

View file

@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Dict, Optional
from models.llama3_1.api.datatypes import URL
from llama_models.llama3_1.api.datatypes import URL
from pydantic import BaseModel

View file

@ -4,7 +4,7 @@ from pydantic import BaseModel
from pyopenapi import webmethod
from models.llama3_1.api.datatypes import * # noqa: F403
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from .datatypes import * # noqa: F403
from toolchain.dataset.api.datatypes import * # noqa: F403
from toolchain.common.training_types import * # noqa: F403

View file

@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type
from typing_extensions import Annotated
from models.llama3_1.api.datatypes import * # noqa: F403
from llama_models.llama3_1.api.datatypes import * # noqa: F403
class LogProbConfig(BaseModel):

View file

@ -16,11 +16,11 @@ from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel,
model_parallel_is_initialized,
)
from models.llama3_1.api.args import ModelArgs
from models.llama3_1.api.chat_format import ChatFormat, ModelInput
from models.llama3_1.api.datatypes import Message
from models.llama3_1.api.model import Transformer
from models.llama3_1.api.tokenizer import Tokenizer
from llama_models.llama3_1.api.args import ModelArgs
from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.model import Transformer
from llama_models.llama3_1.api.tokenizer import Tokenizer
from termcolor import cprint
from .api.config import CheckpointType, InlineImplConfig

View file

@ -1,6 +1,6 @@
from typing import AsyncGenerator
from models.llama3_1.api.datatypes import StopReason
from llama_models.llama3_1.api.datatypes import StopReason
from .api.config import (
CheckpointQuantizationFormat,

View file

@ -2,9 +2,9 @@ from dataclasses import dataclass
from functools import partial
from typing import Generator, List, Optional
from models.llama3_1.api.chat_format import ChatFormat
from models.llama3_1.api.datatypes import Message
from models.llama3_1.api.tokenizer import Tokenizer
from llama_models.llama3_1.api.chat_format import ChatFormat
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.tokenizer import Tokenizer
from .api.config import InlineImplConfig
from .generation import Llama

View file

@ -7,7 +7,7 @@ from typing import Optional
import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from models.llama3_1.api.model import Transformer, TransformerBlock
from llama_models.llama3_1.api.model import Transformer, TransformerBlock
from termcolor import cprint

View file

@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
from pyopenapi import webmethod
from strong_typing.schema import json_schema_type
from models.llama3_1.api.datatypes import * # noqa: F403
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from toolchain.dataset.api.datatypes import * # noqa: F403
from toolchain.common.training_types import * # noqa: F403
from .datatypes import * # noqa: F403

View file

@ -4,7 +4,7 @@ from pydantic import BaseModel
from strong_typing.schema import json_schema_type
from models.llama3_1.api.datatypes import * # noqa: F403
from llama_models.llama3_1.api.datatypes import * # noqa: F403
@json_schema_type

View file

@ -1,7 +1,7 @@
from enum import Enum
from typing import Dict, Optional, Union
from models.llama3_1.api.datatypes import ToolParamDefinition
from llama_models.llama3_1.api.datatypes import ToolParamDefinition
from pydantic import BaseModel

View file

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import List, Union
from models.llama3_1.api.datatypes import Attachment, Message
from llama_models.llama3_1.api.datatypes import Attachment, Message
from toolchain.safety.api.datatypes import * # noqa: F403
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

View file

@ -1,7 +1,7 @@
import sys
from typing import List
from models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.datatypes import Message
parent_dir = "../.."
sys.path.append(parent_dir)

View file

@ -4,7 +4,7 @@ from string import Template
from typing import List, Optional
import torch
from models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.datatypes import Message
from termcolor import cprint
from transformers import AutoModelForCausalLM, AutoTokenizer

View file

@ -3,7 +3,7 @@ from typing import List
import torch
from models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.datatypes import Message
from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer

View file

@ -1,7 +1,7 @@
import asyncio
from typing import List
from models.llama3_1.api.datatypes import Message, Role
from llama_models.llama3_1.api.datatypes import Message, Role
from .base import OnViolationAction, ShieldBase, ShieldResponse

View file

@ -4,7 +4,7 @@ import yaml
from pyopenapi import Info, Options, Server, Specification
from models.llama3_1.api.datatypes import * # noqa: F403
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from toolchain.dataset.api import * # noqa: F403
from toolchain.evaluations.api import * # noqa: F403
from toolchain.inference.api import * # noqa: F403

View file

@ -7,7 +7,7 @@ echo "Using temporary directory: $TMPDIR"
rootdir=$(git rev-parse --show-toplevel)
files_to_copy=("toolchain/spec/openapi*" "models.llama3_1.api.datatypes.py" "toolchain/inference/api/*.py" "agentic_system/api/*.py" "toolchain/common/*.py" "toolchain/dataset/api/*.py" "toolchain/evaluations/api/*.py" "toolchain/reward_scoring/api/*.py" "toolchain/post_training/api/*.py" "toolchain/safety/api/*.py")
files_to_copy=("toolchain/spec/openapi*" "llama_models.llama3_1.api.datatypes.py" "toolchain/inference/api/*.py" "agentic_system/api/*.py" "toolchain/common/*.py" "toolchain/dataset/api/*.py" "toolchain/evaluations/api/*.py" "toolchain/reward_scoring/api/*.py" "toolchain/post_training/api/*.py" "toolchain/safety/api/*.py")
for file in "${files_to_copy[@]}"; do
relpath="$file"
set -x

View file

@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Dict, List
from models.llama3_1.api.datatypes import URL
from llama_models.llama3_1.api.datatypes import URL
from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type

View file

@ -5,7 +5,7 @@ from pydantic import BaseModel
from pyopenapi import webmethod
from strong_typing.schema import json_schema_type
from models.llama3_1.api.datatypes import * # noqa: F403
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from toolchain.reward_scoring.api.datatypes import * # noqa: F403
from .datatypes import * # noqa: F403