cli updates

Hardik Shah 2024-07-21 01:51:54 -07:00
parent 23fe353e4a
commit c9f33d8f68
5 changed files with 14 additions and 9 deletions

View file

@@ -7,9 +7,10 @@ from huggingface_hub import snapshot_download
 from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
 from toolchain.cli.subcommand import Subcommand
+from toolchain.utils import DEFAULT_DUMP_DIR

-DEFAULT_CHECKPOINT_DIR = f"{os.path.expanduser('~')}/.llama/checkpoints/"
+DEFAULT_CHECKPOINT_DIR = os.path.join(DEFAULT_DUMP_DIR, "checkpoints")

 class Download(Subcommand):
@@ -61,10 +62,8 @@ class Download(Subcommand):
     def _run_download_cmd(self, args: argparse.Namespace):
         model_name = args.repo_id.split("/")[-1]
+        output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model_name
         os.makedirs(output_dir, exist_ok=True)
-        output_dir = Path(output_dir) / model_name
         try:
             true_output_dir = snapshot_download(
                 args.repo_id,
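
Taken together, the two hunks above reorder the directory setup: the old code ran makedirs on the base directory and only then appended the model name, so the model-specific checkpoint directory did not exist before the download started. A minimal sketch of the new flow, assuming the local_dir keyword and the error handling (the diff truncates the snapshot_download call):

import os
from pathlib import Path

from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

DEFAULT_DUMP_DIR = os.path.expanduser("~/.llama/")
DEFAULT_CHECKPOINT_DIR = os.path.join(DEFAULT_DUMP_DIR, "checkpoints")

def download(repo_id: str) -> str:
    # e.g. "org/Model-7B" -> "Model-7B"
    model_name = repo_id.split("/")[-1]
    # Create the model-specific directory *before* downloading into it.
    output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model_name
    os.makedirs(output_dir, exist_ok=True)
    try:
        # local_dir is an assumption; the diff does not show the remaining arguments.
        return snapshot_download(repo_id, local_dir=output_dir)
    except GatedRepoError:
        raise RuntimeError(f"{repo_id} is gated; request access on the Hub first.")
    except RepositoryNotFoundError:
        raise RuntimeError(f"Repository {repo_id} was not found.")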

View file

@@ -5,9 +5,11 @@ import textwrap
 from pathlib import Path
 from toolchain.cli.subcommand import Subcommand
+from toolchain.utils import DEFAULT_DUMP_DIR

-CONFIGS_BASE_DIR = f"{os.path.expanduser('~')}/.llama/configs/"
+CONFIGS_BASE_DIR = os.path.join(DEFAULT_DUMP_DIR, "configs")

 class InferenceConfigure(Subcommand):
     """Llama cli for configuring llama toolchain configs"""

View file

@@ -43,6 +43,7 @@ class InferenceStart(Subcommand):
             "--config",
             type=str,
             help="Path to config file",
+            default="inference"
         )

     def _run_inference_start_cmd(self, args: argparse.Namespace) -> None:
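
With the added default, the start command no longer requires --config: argparse substitutes "inference" when the flag is omitted. A standalone illustration (the prog name here is illustrative, not from the diff):

import argparse

parser = argparse.ArgumentParser(prog="inference_start")  # illustrative name
parser.add_argument(
    "--config",
    type=str,
    help="Path to config file",
    default="inference",
)

args = parser.parse_args([])  # flag omitted
assert args.config == "inference"  # the default applies

args = parser.parse_args(["--config", "custom"])
assert args.config == "custom"  # an explicit value wins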

View file

@@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
 from omegaconf import OmegaConf

-from toolchain.utils import get_config_dir, parse_config
+from toolchain.utils import get_default_config_dir, parse_config

 from .api.config import ModelInferenceHydraConfig
 from .api.endpoints import ChatCompletionRequest, ChatCompletionResponseStreamChunk
@@ -100,7 +100,7 @@ def chat_completion(request: Request, exec_request: ChatCompletionRequest):
 def main(config_path: str, port: int = 5000, disable_ipv6: bool = False):
     global GLOBAL_CONFIG
-    config_dir = get_config_dir()
+    config_dir = get_default_config_dir()
     GLOBAL_CONFIG = parse_config(config_dir, config_path)

     signal.signal(signal.SIGINT, handle_sigint)

View file

@@ -8,6 +8,9 @@ from hydra.core.global_hydra import GlobalHydra
 from omegaconf import OmegaConf

+DEFAULT_DUMP_DIR = os.path.expanduser("~/.llama/")

 def get_root_directory():
     current_dir = os.path.dirname(os.path.abspath(__file__))
     while os.path.isfile(os.path.join(current_dir, "__init__.py")):
@@ -16,8 +19,8 @@ def get_root_directory():
     return current_dir

-def get_config_dir():
-    return os.path.join(get_root_directory(), "toolchain", "configs")
+def get_default_config_dir():
+    return os.path.join(DEFAULT_DUMP_DIR, "configs")

 def parse_config(config_dir: str, config_path: Optional[str] = None) -> str:
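
The rename makes the behavior explicit: configs are now read from the user-level dump directory (~/.llama/configs) rather than from a configs folder inside the source tree. Below is the new helper together with one plausible shape for parse_config; the Hydra compose-based body is purely an assumption, since the diff shows only the function's signature:

import os
from typing import Optional

from hydra import compose, initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf

DEFAULT_DUMP_DIR = os.path.expanduser("~/.llama/")

def get_default_config_dir():
    # Configs now live under ~/.llama/configs instead of the source tree.
    return os.path.join(DEFAULT_DUMP_DIR, "configs")

def parse_config(config_dir: str, config_path: Optional[str] = None) -> str:
    # Hypothetical body: compose the named config from config_dir via Hydra.
    GlobalHydra.instance().clear()  # allow re-initialization within one process
    with initialize_config_dir(config_dir=config_dir, version_base=None):
        cfg = compose(config_name=config_path or "inference")
    return OmegaConf.to_yaml(cfg)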