mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
cli updates
This commit is contained in:
parent
23fe353e4a
commit
c9f33d8f68
5 changed files with 14 additions and 9 deletions
|
@ -7,9 +7,10 @@ from huggingface_hub import snapshot_download
|
||||||
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
||||||
|
|
||||||
from toolchain.cli.subcommand import Subcommand
|
from toolchain.cli.subcommand import Subcommand
|
||||||
|
from toolchain.utils import DEFAULT_DUMP_DIR
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CHECKPOINT_DIR = f"{os.path.expanduser('~')}/.llama/checkpoints/"
|
DEFAULT_CHECKPOINT_DIR = os.path.join(DEFAULT_DUMP_DIR, "checkpoints")
|
||||||
|
|
||||||
|
|
||||||
class Download(Subcommand):
|
class Download(Subcommand):
|
||||||
|
@ -61,10 +62,8 @@ class Download(Subcommand):
|
||||||
|
|
||||||
def _run_download_cmd(self, args: argparse.Namespace):
|
def _run_download_cmd(self, args: argparse.Namespace):
|
||||||
model_name = args.repo_id.split("/")[-1]
|
model_name = args.repo_id.split("/")[-1]
|
||||||
|
output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model_name
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
output_dir = Path(output_dir) / model_name
|
|
||||||
try:
|
try:
|
||||||
true_output_dir = snapshot_download(
|
true_output_dir = snapshot_download(
|
||||||
args.repo_id,
|
args.repo_id,
|
||||||
|
|
|
@ -5,9 +5,11 @@ import textwrap
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from toolchain.cli.subcommand import Subcommand
|
from toolchain.cli.subcommand import Subcommand
|
||||||
|
from toolchain.utils import DEFAULT_DUMP_DIR
|
||||||
|
|
||||||
|
|
||||||
CONFIGS_BASE_DIR = f"{os.path.expanduser('~')}/.llama/configs/"
|
CONFIGS_BASE_DIR = os.path.join(DEFAULT_DUMP_DIR, "configs")
|
||||||
|
|
||||||
|
|
||||||
class InferenceConfigure(Subcommand):
|
class InferenceConfigure(Subcommand):
|
||||||
"""Llama cli for configuring llama toolchain configs"""
|
"""Llama cli for configuring llama toolchain configs"""
|
||||||
|
|
|
@ -43,6 +43,7 @@ class InferenceStart(Subcommand):
|
||||||
"--config",
|
"--config",
|
||||||
type=str,
|
type=str,
|
||||||
help="Path to config file",
|
help="Path to config file",
|
||||||
|
default="inference"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_inference_start_cmd(self, args: argparse.Namespace) -> None:
|
def _run_inference_start_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
|
|
@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
from toolchain.utils import get_config_dir, parse_config
|
from toolchain.utils import get_default_config_dir, parse_config
|
||||||
from .api.config import ModelInferenceHydraConfig
|
from .api.config import ModelInferenceHydraConfig
|
||||||
from .api.endpoints import ChatCompletionRequest, ChatCompletionResponseStreamChunk
|
from .api.endpoints import ChatCompletionRequest, ChatCompletionResponseStreamChunk
|
||||||
|
|
||||||
|
@ -100,7 +100,7 @@ def chat_completion(request: Request, exec_request: ChatCompletionRequest):
|
||||||
|
|
||||||
def main(config_path: str, port: int = 5000, disable_ipv6: bool = False):
|
def main(config_path: str, port: int = 5000, disable_ipv6: bool = False):
|
||||||
global GLOBAL_CONFIG
|
global GLOBAL_CONFIG
|
||||||
config_dir = get_config_dir()
|
config_dir = get_default_config_dir()
|
||||||
GLOBAL_CONFIG = parse_config(config_dir, config_path)
|
GLOBAL_CONFIG = parse_config(config_dir, config_path)
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, handle_sigint)
|
signal.signal(signal.SIGINT, handle_sigint)
|
||||||
|
|
|
@ -8,6 +8,9 @@ from hydra.core.global_hydra import GlobalHydra
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_DUMP_DIR = os.path.expanduser("~/.llama/")
|
||||||
|
|
||||||
|
|
||||||
def get_root_directory():
|
def get_root_directory():
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
while os.path.isfile(os.path.join(current_dir, "__init__.py")):
|
while os.path.isfile(os.path.join(current_dir, "__init__.py")):
|
||||||
|
@ -16,8 +19,8 @@ def get_root_directory():
|
||||||
return current_dir
|
return current_dir
|
||||||
|
|
||||||
|
|
||||||
def get_config_dir():
|
def get_default_config_dir():
|
||||||
return os.path.join(get_root_directory(), "toolchain", "configs")
|
return os.path.join(DEFAULT_DUMP_DIR, "configs")
|
||||||
|
|
||||||
|
|
||||||
def parse_config(config_dir: str, config_path: Optional[str] = None) -> str:
|
def parse_config(config_dir: str, config_path: Optional[str] = None) -> str:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue