diff --git a/toolchain/cli/download.py b/toolchain/cli/download.py
index 10889992e..ab8e96b7f 100644
--- a/toolchain/cli/download.py
+++ b/toolchain/cli/download.py
@@ -7,9 +7,10 @@
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
 
 from toolchain.cli.subcommand import Subcommand
+from toolchain.utils import DEFAULT_DUMP_DIR
 
-DEFAULT_CHECKPOINT_DIR = f"{os.path.expanduser('~')}/.llama/checkpoints/"
+DEFAULT_CHECKPOINT_DIR = os.path.join(DEFAULT_DUMP_DIR, "checkpoints")
 
 
 class Download(Subcommand):
@@ -61,10 +62,8 @@ class Download(Subcommand):
 
     def _run_download_cmd(self, args: argparse.Namespace):
         model_name = args.repo_id.split("/")[-1]
-
+        output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model_name
         os.makedirs(output_dir, exist_ok=True)
-
-        output_dir = Path(output_dir) / model_name
         try:
             true_output_dir = snapshot_download(
                 args.repo_id,
diff --git a/toolchain/cli/inference/configure.py b/toolchain/cli/inference/configure.py
index 9c728f6d8..df97ebf04 100644
--- a/toolchain/cli/inference/configure.py
+++ b/toolchain/cli/inference/configure.py
@@ -5,9 +5,11 @@
 import textwrap
 from pathlib import Path
 
 from toolchain.cli.subcommand import Subcommand
+from toolchain.utils import DEFAULT_DUMP_DIR
 
-CONFIGS_BASE_DIR = f"{os.path.expanduser('~')}/.llama/configs/"
+CONFIGS_BASE_DIR = os.path.join(DEFAULT_DUMP_DIR, "configs")
+
 
 class InferenceConfigure(Subcommand):
     """Llama cli for configuring llama toolchain configs"""
diff --git a/toolchain/cli/inference/start.py b/toolchain/cli/inference/start.py
index ab447b644..c105d4198 100644
--- a/toolchain/cli/inference/start.py
+++ b/toolchain/cli/inference/start.py
@@ -43,6 +43,7 @@ class InferenceStart(Subcommand):
             "--config",
             type=str,
             help="Path to config file",
+            default="inference",
         )
 
     def _run_inference_start_cmd(self, args: argparse.Namespace) -> None:
diff --git a/toolchain/inference/server.py b/toolchain/inference/server.py
index 52aac3dda..a2846f136 100644
--- a/toolchain/inference/server.py
+++ b/toolchain/inference/server.py
@@ -10,7 +10,7 @@
 from fastapi.responses import StreamingResponse
 from omegaconf import OmegaConf
 
-from toolchain.utils import get_config_dir, parse_config
+from toolchain.utils import get_default_config_dir, parse_config
 
 from .api.config import ModelInferenceHydraConfig
 from .api.endpoints import ChatCompletionRequest, ChatCompletionResponseStreamChunk
@@ -100,7 +100,7 @@ def chat_completion(request: Request, exec_request: ChatCompletionRequest):
 
 def main(config_path: str, port: int = 5000, disable_ipv6: bool = False):
     global GLOBAL_CONFIG
-    config_dir = get_config_dir()
+    config_dir = get_default_config_dir()
     GLOBAL_CONFIG = parse_config(config_dir, config_path)
 
     signal.signal(signal.SIGINT, handle_sigint)
diff --git a/toolchain/utils.py b/toolchain/utils.py
index b8c91f529..dbc72b0c8 100644
--- a/toolchain/utils.py
+++ b/toolchain/utils.py
@@ -8,6 +8,9 @@
 from hydra.core.global_hydra import GlobalHydra
 from omegaconf import OmegaConf
 
+DEFAULT_DUMP_DIR = os.path.expanduser("~/.llama/")
+
+
 def get_root_directory():
     current_dir = os.path.dirname(os.path.abspath(__file__))
     while os.path.isfile(os.path.join(current_dir, "__init__.py")):
@@ -16,8 +19,8 @@
     return current_dir
 
 
-def get_config_dir():
-    return os.path.join(get_root_directory(), "toolchain", "configs")
+def get_default_config_dir():
+    return os.path.join(DEFAULT_DUMP_DIR, "configs")
 
 
 def parse_config(config_dir: str, config_path: Optional[str] = None) -> str:
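
Note: a minimal sketch (not part of the patch) of how the relocated defaults resolve at runtime, assuming only the constants the diff introduces; the repo id below is hypothetical. It also illustrates the reordering in _run_download_cmd: the full model-specific path is now built before os.makedirs runs, so the leaf directory itself gets created rather than only its parent.

# Sketch only -- mirrors the constants introduced in toolchain/utils.py
# and the CLI modules to show where artifacts land after this change.
import os
from pathlib import Path

DEFAULT_DUMP_DIR = os.path.expanduser("~/.llama/")
DEFAULT_CHECKPOINT_DIR = os.path.join(DEFAULT_DUMP_DIR, "checkpoints")
CONFIGS_BASE_DIR = os.path.join(DEFAULT_DUMP_DIR, "configs")

repo_id = "meta-llama/some-model"    # hypothetical repo id
model_name = repo_id.split("/")[-1]  # -> "some-model"

# New behavior: compute the model-specific path first, then create it,
# so the directory exists before snapshot_download writes into it.
output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model_name
os.makedirs(output_dir, exist_ok=True)

print(output_dir)        # e.g. /home/<user>/.llama/checkpoints/some-model
print(CONFIGS_BASE_DIR)  # e.g. /home/<user>/.llama/configs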