diff --git a/llama_toolchain/cli/distribution/configure.py b/llama_toolchain/cli/distribution/configure.py
index 1e0712b4a..5a0ff4be9 100644
--- a/llama_toolchain/cli/distribution/configure.py
+++ b/llama_toolchain/cli/distribution/configure.py
@@ -5,18 +5,24 @@
 # the root directory of this source tree.
 
 import argparse
-import os
+import importlib
+import inspect
+import shlex
 from pathlib import Path
+from typing import Annotated, get_args, get_origin, Literal, Union
 
-import pkg_resources
+import yaml
+from pydantic import BaseModel
+from termcolor import cprint
 
 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.distribution.registry import all_registered_distributions
-from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR
+from llama_toolchain.distribution.datatypes import Distribution, PassthroughApiAdapter
+from llama_toolchain.distribution.registry import available_distributions
+from llama_toolchain.utils import DISTRIBS_BASE_DIR
+from .utils import run_command
 
-
-CONFIGS_BASE_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")
+DISTRIBS = available_distributions()
 
 
 class DistributionConfigure(Subcommand):
@@ -34,59 +40,198 @@ class DistributionConfigure(Subcommand):
         self.parser.set_defaults(func=self._run_distribution_configure_cmd)
 
     def _add_arguments(self):
-        distribs = all_registered_distributions()
         self.parser.add_argument(
             "--name",
             type=str,
            help="Mame of the distribution to configure",
             default="local-source",
-            choices=[d.name for d in distribs],
+            choices=[d.name for d in available_distributions()],
         )
 
-    def read_user_inputs(self):
-        checkpoint_dir = input(
-            "Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
-        )
-        model_parallel_size = input(
-            "Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
-        )
-        assert model_parallel_size.isdigit() and int(model_parallel_size) in {
-            1,
-            8,
-        }, "model parallel size must be 1 or 8"
-
-        return checkpoint_dir, model_parallel_size
-
-    def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
-        default_conf_path = pkg_resources.resource_filename(
-            "llama_toolchain", "data/default_distribution_config.yaml"
-        )
-        with open(default_conf_path, "r") as f:
-            yaml_content = f.read()
-
-        yaml_content = yaml_content.format(
-            checkpoint_dir=checkpoint_dir,
-            model_parallel_size=model_parallel_size,
-        )
-
-        with open(yaml_output_path, "w") as yaml_file:
-            yaml_file.write(yaml_content.strip())
-
-        print(f"YAML configuration has been written to {yaml_output_path}")
-
     def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
-        checkpoint_dir, model_parallel_size = self.read_user_inputs()
-        checkpoint_dir = os.path.expanduser(checkpoint_dir)
+        dist = None
+        for d in DISTRIBS:
+            if d.name == args.name:
+                dist = d
+                break
 
-        assert (
-            Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
-        ), f"{checkpoint_dir} does not exist or it not a directory"
+        if dist is None:
+            self.parser.error(f"Could not find distribution {args.name}")
+            return
 
-        os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
-        yaml_output_path = Path(CONFIGS_BASE_DIR) / "distribution.yaml"
+        env_file = DISTRIBS_BASE_DIR / dist.name / "conda.env"
+        # read this file to get the conda env name
+        assert env_file.exists(), f"Could not find conda env file {env_file}"
+        with open(env_file, "r") as f:
+            conda_env = f.read().strip()
 
-        self.write_output_yaml(
-            checkpoint_dir,
-            model_parallel_size,
-            yaml_output_path,
+        configure_llama_distribution(dist, conda_env)
+
+
+def configure_llama_distribution(dist: Distribution, conda_env: str):
+    python_exe = run_command(shlex.split("which python"))
+    # simple check
+    if conda_env not in python_exe:
+        raise ValueError(
+            f"Please re-run configure by activating the `{conda_env}` conda environment"
         )
+
+    adapter_configs = {}
+    for api_surface, adapter in dist.adapters.items():
+        if isinstance(adapter, PassthroughApiAdapter):
+            adapter_configs[api_surface.value] = adapter.dict()
+        else:
+            cprint(
+                f"Configuring API surface: {api_surface.value}", "white", attrs=["bold"]
+            )
+            config_type = instantiate_class_type(adapter.config_class)
+            # TODO: when we are re-configuring, we should read existing values
+            config = prompt_for_config(config_type)
+            adapter_configs[api_surface.value] = config.dict()
+
+    dist_config = {
+        "adapters": adapter_configs,
+        "conda_env": conda_env,
+    }
+
+    yaml_output_path = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
+    with open(yaml_output_path, "w") as fp:
+        fp.write(yaml.dump(dist_config, sort_keys=False))
+
+    print(f"YAML configuration has been written to {yaml_output_path}")
+
+
+def instantiate_class_type(fully_qualified_name):
+    module_name, class_name = fully_qualified_name.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    return getattr(module, class_name)
+
+
+def get_literal_values(field):
+    """Extract literal values from a field if it's a Literal type."""
+    if get_origin(field.annotation) is Literal:
+        return get_args(field.annotation)
+    return None
+
+
+def is_optional(field_type):
+    """Check if a field type is Optional."""
+    return get_origin(field_type) is Union and type(None) in get_args(field_type)
+
+
+def get_non_none_type(field_type):
+    """Get the non-None type from an Optional type."""
+    return next(arg for arg in get_args(field_type) if arg is not type(None))
+
+
+def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
+    """
+    Recursively prompt the user for configuration values based on a Pydantic BaseModel.
+
+    Args:
+        config_type: A Pydantic BaseModel class representing the configuration structure.
+
+    Returns:
+        An instance of the config_type with user-provided values.
+    """
+    config_data = {}
+
+    for field_name, field in config_type.__fields__.items():
+        field_type = field.annotation
+        default_value = (
+            field.default if not isinstance(field.default, type(Ellipsis)) else None
+        )
+        is_required = field.required
+
+        # Skip fields with Literal type
+        if get_origin(field_type) is Literal:
+            continue
+
+        # Check if the field is a discriminated union
+        if get_origin(field_type) is Annotated:
+            inner_type = get_args(field_type)[0]
+            if get_origin(inner_type) is Union:
+                discriminator = field.field_info.discriminator
+                if discriminator:
+                    union_types = get_args(inner_type)
+                    # Find the discriminator field in each union type
+                    type_map = {}
+                    for t in union_types:
+                        disc_field = t.__fields__[discriminator]
+                        literal_values = get_literal_values(disc_field)
+                        if literal_values:
+                            for value in literal_values:
+                                type_map[value] = t
+
+                    while True:
+                        discriminator_value = input(
+                            f"Enter the {discriminator} (options: {', '.join(type_map.keys())}): "
+                        )
+                        if discriminator_value in type_map:
+                            chosen_type = type_map[discriminator_value]
+                            print(f"\nConfiguring {chosen_type.__name__}:")
+                            sub_config = prompt_for_config(chosen_type)
+                            config_data[field_name] = sub_config
+                            # Set the discriminator field in the sub-config
+                            setattr(sub_config, discriminator, discriminator_value)
+                            break
+                        else:
+                            print(f"Invalid {discriminator}. Please try again.")
+                    continue
+
+        if inspect.isclass(field_type) and issubclass(field_type, BaseModel):
+            print(f"\nEntering sub-configuration for {field_name}:")
+            config_data[field_name] = prompt_for_config(field_type)
+        else:
+            prompt = f"Enter value for {field_name}"
+            if default_value is not None:
+                prompt += f" (default: {default_value})"
+            if is_optional(field_type):
+                prompt += " (optional)"
+            elif is_required:
+                prompt += " (required)"
+            prompt += ": "
+
+            while True:
+                user_input = input(prompt)
+                if user_input == "":
+                    if default_value is not None:
+                        config_data[field_name] = default_value
+                        break
+                    elif is_optional(field_type):
+                        config_data[field_name] = None
+                        break
+                    elif not is_required:
+                        config_data[field_name] = None
+                        break
+                    else:
+                        print("This field is required. Please provide a value.")
+                        continue
+
+                try:
+                    # Handle Optional types
+                    if is_optional(field_type):
+                        if user_input.lower() == "none":
+                            config_data[field_name] = None
+                            break
+                        field_type = get_non_none_type(field_type)
+
+                    # Convert the input to the correct type
+                    if inspect.isclass(field_type) and issubclass(
+                        field_type, BaseModel
+                    ):
+                        # For nested BaseModels, we assume a dictionary-like string input
+                        import ast
+
+                        config_data[field_name] = field_type(
+                            **ast.literal_eval(user_input)
+                        )
+                    else:
+                        config_data[field_name] = field_type(user_input)
+                    break
+                except ValueError:
+                    print(
+                        f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
+                    )
+
+    return config_type(**config_data)
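
A minimal sketch of how the new configuration helpers are meant to be driven, mirroring the per-adapter loop in `configure_llama_distribution`. It assumes the `meta-ollama` adapter's `config_class` from this change; the exact prompts depend on the fields of the resolved config model.

```python
# Sketch only: mirrors what configure_llama_distribution does for each
# non-passthrough adapter.
from llama_toolchain.cli.distribution.configure import (
    instantiate_class_type,
    prompt_for_config,
)

# Resolve the adapter's config class from its fully-qualified name ...
config_type = instantiate_class_type(
    "llama_toolchain.inference.ollama.OllamaImplConfig"
)

# ... then walk its fields interactively. Enter accepts a field's default,
# "none" clears an Optional field, and invalid values are re-prompted.
config = prompt_for_config(config_type)
print(config.dict())
```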
diff --git a/llama_toolchain/cli/distribution/distribution.py b/llama_toolchain/cli/distribution/distribution.py
index 02a0b8caf..c553dcf3b 100644
--- a/llama_toolchain/cli/distribution/distribution.py
+++ b/llama_toolchain/cli/distribution/distribution.py
@@ -7,6 +7,7 @@
 import argparse
 
 from llama_toolchain.cli.subcommand import Subcommand
+
 from .configure import DistributionConfigure
 from .create import DistributionCreate
 from .install import DistributionInstall
diff --git a/llama_toolchain/cli/distribution/install.py b/llama_toolchain/cli/distribution/install.py
index d8bbcb599..df60e7ad1 100644
--- a/llama_toolchain/cli/distribution/install.py
+++ b/llama_toolchain/cli/distribution/install.py
@@ -7,20 +7,16 @@
 import argparse
 import os
 import shlex
-import subprocess
-
-from pathlib import Path
 
 import pkg_resources
 
 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.distribution.registry import all_registered_distributions
-from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR
+from llama_toolchain.distribution.datatypes import distribution_dependencies
+from llama_toolchain.distribution.registry import available_distributions
+from llama_toolchain.utils import DISTRIBS_BASE_DIR
+from .utils import run_command, run_with_pty
 
-
-DISTRIBS_BASE_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "distributions"
-
-DISTRIBS = all_registered_distributions()
+DISTRIBS = available_distributions()
 
 
 class DistributionInstall(Subcommand):
@@ -70,13 +66,19 @@ class DistributionInstall(Subcommand):
             return
 
         os.makedirs(DISTRIBS_BASE_DIR / dist.name, exist_ok=True)
-        run_shell_script(script, args.conda_env, " ".join(dist.pip_packages))
+
+        deps = distribution_dependencies(dist)
+        run_command([script, args.conda_env, " ".join(deps)])
 
         with open(DISTRIBS_BASE_DIR / dist.name / "conda.env", "w") as f:
             f.write(f"{args.conda_env}\n")
 
-
-def run_shell_script(script_path, *args):
-    command_string = f"{script_path} {' '.join(shlex.quote(str(arg)) for arg in args)}"
-    command_list = shlex.split(command_string)
-    print(f"Running command: {command_list}")
-    subprocess.run(command_list, check=True, text=True)
+        # we need to run configure _within_ the conda environment and need to run with
+        # a pty since configure is interactive
+        python_exe = run_command(
+            shlex.split(f"conda run -n {args.conda_env} which python")
+        ).strip()
+        run_with_pty(
+            shlex.split(
+                f"{python_exe} -m llama_toolchain.cli.llama distribution configure --name {dist.name}"
+            )
+        )
diff --git a/llama_toolchain/cli/distribution/list.py b/llama_toolchain/cli/distribution/list.py
index 4cf26980b..39e93d8ec 100644
--- a/llama_toolchain/cli/distribution/list.py
+++ b/llama_toolchain/cli/distribution/list.py
@@ -9,7 +9,8 @@ import argparse
 
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.cli.table import print_table
-from llama_toolchain.distribution.registry import all_registered_distributions
+from llama_toolchain.distribution.datatypes import distribution_dependencies
+from llama_toolchain.distribution.registry import available_distributions
 
 
 class DistributionList(Subcommand):
@@ -37,12 +38,13 @@ class DistributionList(Subcommand):
         ]
 
         rows = []
-        for dist in all_registered_distributions():
+        for dist in available_distributions():
+            deps = distribution_dependencies(dist)
             rows.append(
                 [
                     dist.name,
                     dist.description,
-                    ", ".join(dist.pip_packages),
+                    ", ".join(deps),
                 ]
             )
         print_table(
diff --git a/llama_toolchain/cli/distribution/utils.py b/llama_toolchain/cli/distribution/utils.py
new file mode 100644
index 000000000..94ed1b0bb
--- /dev/null
+++ b/llama_toolchain/cli/distribution/utils.py
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import errno
+import os
+import pty
+import select
+import subprocess
+import sys
+import termios
+import tty
+
+
+def run_with_pty(command):
+    old_settings = termios.tcgetattr(sys.stdin)
+
+    # Create a new pseudo-terminal
+    master, slave = pty.openpty()
+
+    try:
+        # ensure the terminal does not echo input
+        tty.setraw(sys.stdin.fileno())
+
+        process = subprocess.Popen(
+            command,
+            stdin=slave,
+            stdout=slave,
+            stderr=slave,
+            universal_newlines=True,
+        )
+
+        # Close the slave file descriptor as it's now owned by the subprocess
+        os.close(slave)
+
+        def handle_io():
+            while True:
+                rlist, _, _ = select.select([sys.stdin, master], [], [])
+
+                if sys.stdin in rlist:
+                    data = os.read(sys.stdin.fileno(), 1024)
+                    if not data:  # EOF
+                        break
+                    os.write(master, data)
+
+                if master in rlist:
+                    data = os.read(master, 1024)
+                    if not data:
+                        break
+                    os.write(sys.stdout.fileno(), data)
+
+        handle_io()
+    except (EOFError, KeyboardInterrupt):
+        pass
+    except OSError as e:
+        if e.errno != errno.EIO:
+            raise
+    finally:
+        # Restore original terminal settings
+        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
+
+    process.wait()
+    os.close(master)
+
+    return process.returncode
+
+
+def run_command(command):
+    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    output, error = process.communicate()
+    if process.returncode != 0:
+        print(f"Error: {error.decode('utf-8')}")
+        sys.exit(1)
+    return output.decode("utf-8")
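
A short usage sketch for the two helpers above: `run_command` captures the output of a non-interactive command and exits on failure, while `run_with_pty` proxies stdin/stdout through a pseudo-terminal so the child process can prompt the user. The commands below are placeholders, not part of this change, and the pty example assumes it is run from a real terminal.

```python
# Placeholder commands, purely to illustrate the helpers' contract.
import shlex

from llama_toolchain.cli.distribution.utils import run_command, run_with_pty

# run_command: capture stdout; a non-zero exit prints stderr and exits(1).
python_version = run_command(shlex.split("python --version")).strip()
print(f"interpreter: {python_version}")

# run_with_pty: needed when the child is interactive, e.g. the
# `llama distribution configure` invocation launched by `install`.
returncode = run_with_pty(shlex.split("python -c 'input(\"press enter: \")'"))
print(f"child exited with {returncode}")
```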
diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py
index 5cef08dcc..4209e6d26 100644
--- a/llama_toolchain/distribution/datatypes.py
+++ b/llama_toolchain/distribution/datatypes.py
@@ -32,15 +32,7 @@ class PassthroughApiAdapterConfig(BaseModel):
 @json_schema_type
 class PythonImplAdapterConfig(BaseModel):
     type: Literal[AdapterType.python_impl.value] = AdapterType.python_impl.value
-    pip_packages: List[str] = Field(
-        default_factory=list,
-        description="The pip dependencies needed for this implementation",
-    )
-    module: str = Field(..., description="The name of the module to import")
-    entrypoint: str = Field(
-        ...,
-        description="The name of the entrypoint function which creates the implementation for the API",
-    )
+    adapter_id: str
     kwargs: Dict[str, Any] = Field(
         default_factory=dict, description="kwargs to pass to the entrypoint"
     )
@@ -63,21 +55,43 @@ AdapterConfig = Annotated[
 ]
 
 
-class DistributionConfig(BaseModel):
-    inference: AdapterConfig
-    safety: AdapterConfig
-
-    # configs for each API that the stack provides, e.g.
-    # agentic_system: AdapterConfig
-    # post_training: AdapterConfig
+@json_schema_type
+class ApiSurface(Enum):
+    inference = "inference"
+    safety = "safety"
 
 
-class DistributionConfigDefaults(BaseModel):
-    inference: Dict[str, Any] = Field(
-        default_factory=dict, description="Default kwargs for the inference adapter"
+@json_schema_type
+class Adapter(BaseModel):
+    api_surface: ApiSurface
+    adapter_id: str
+
+
+@json_schema_type
+class SourceAdapter(Adapter):
+    pip_packages: List[str] = Field(
+        default_factory=list,
+        description="The pip dependencies needed for this implementation",
     )
-    safety: Dict[str, Any] = Field(
-        default_factory=dict, description="Default kwargs for the safety adapter"
+    module: str = Field(
+        ...,
+        description="""
+Fully-qualified name of the module to import. The module is expected to have
+a `get_adapter_instance()` method which will be passed a validated config object
+of type `config_class`.""",
+    )
+    config_class: str = Field(
+        ...,
+        description="Fully-qualified classname of the config for this adapter",
+    )
+
+
+@json_schema_type
+class PassthroughApiAdapter(Adapter):
+    base_url: str = Field(..., description="The base URL for the llama stack provider")
+    headers: Dict[str, str] = Field(
+        default_factory=dict,
+        description="Headers (e.g., authorization) to send with the request",
     )
 
 
@@ -85,9 +99,22 @@ class Distribution(BaseModel):
     name: str
     description: str
 
-    # you must install the packages to get the functionality needed.
-    # later, we may have a docker image be the main artifact of
-    # a distribution.
-    pip_packages: List[str]
+    adapters: Dict[ApiSurface, Adapter] = Field(
+        default_factory=dict,
+        description="The API surfaces provided by this distribution",
+    )
 
-    config_defaults: DistributionConfigDefaults
+    additional_pip_packages: List[str] = Field(
+        default_factory=list,
+        description="Additional pip packages beyond those required by the adapters",
+    )
+
+
+def distribution_dependencies(distribution: Distribution) -> List[str]:
+    # only consider SourceAdapters when calculating dependencies
+    return [
+        dep
+        for adapter in distribution.adapters.values()
+        if isinstance(adapter, SourceAdapter)
+        for dep in adapter.pip_packages
+    ] + distribution.additional_pip_packages
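
To make the new dependency rule concrete: only `SourceAdapter`s contribute `pip_packages`, a `PassthroughApiAdapter` contributes nothing, and `additional_pip_packages` is always appended. All values below are invented for illustration.

```python
# Invented example values; only the shape matches the new datatypes.
from llama_toolchain.distribution.datatypes import (
    ApiSurface,
    Distribution,
    PassthroughApiAdapter,
    SourceAdapter,
    distribution_dependencies,
)

dist = Distribution(
    name="example",
    description="illustrative only",
    adapters={
        ApiSurface.inference: SourceAdapter(
            api_surface=ApiSurface.inference,
            adapter_id="example-inference",
            pip_packages=["torch"],
            module="example.inference",
            config_class="example.inference.ExampleImplConfig",
        ),
        # Passthrough adapters are served remotely, so they add no pip deps.
        ApiSurface.safety: PassthroughApiAdapter(
            api_surface=ApiSurface.safety,
            adapter_id="example-safety",
            base_url="http://localhost:5000",
        ),
    },
    additional_pip_packages=["fastapi"],
)

print(distribution_dependencies(dist))  # -> ["torch", "fastapi"]
```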
diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py
index a1f9a7a55..1dec45861 100644
--- a/llama_toolchain/distribution/registry.py
+++ b/llama_toolchain/distribution/registry.py
@@ -6,30 +6,60 @@
 
 from typing import List
 
-from .datatypes import Distribution, DistributionConfigDefaults
+from llama_toolchain.inference.adapters import available_inference_adapters
+
+from .datatypes import ApiSurface, Distribution
+
+# This is currently duplicated from `requirements.txt` with a few minor changes
+# dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
+# are moved to the appropriate distribution.
+COMMON_DEPENDENCIES = [
+    "accelerate",
+    "black==24.4.2",
+    "blobfile",
+    "codeshield",
+    "fairscale",
+    "fastapi",
+    "fire",
+    "flake8",
+    "httpx",
+    "huggingface-hub",
+    "hydra-core",
+    "hydra-zen",
+    "json-strong-typing",
+    "llama-models",
+    "omegaconf",
+    "pandas",
+    "Pillow",
+    "pydantic==1.10.13",
+    "pydantic_core==2.18.2",
+    "python-openapi",
+    "requests",
+    "tiktoken",
+    "torch",
+    "transformers",
+    "uvicorn",
+]
 
 
-def all_registered_distributions() -> List[Distribution]:
+def available_distributions() -> List[Distribution]:
+    inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()}
+
     return [
         Distribution(
             name="local-source",
             description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
-            pip_packages=[],
-            config_defaults=DistributionConfigDefaults(
-                inference={
-                    "max_seq_len": 4096,
-                    "max_batch_size": 1,
-                },
-                safety={},
-            ),
+            additional_pip_packages=COMMON_DEPENDENCIES,
+            adapters={
+                ApiSurface.inference: inference_adapters_by_id["meta-reference"],
+            },
         ),
         Distribution(
             name="local-ollama",
             description="Like local-source, but use ollama for running LLM inference",
-            pip_packages=["ollama"],
-            config_defaults=DistributionConfigDefaults(
-                inference={},
-                safety={},
-            ),
+            additional_pip_packages=COMMON_DEPENDENCIES,
+            adapters={
+                ApiSurface.inference: inference_adapters_by_id["meta-ollama"],
+            },
         ),
     ]
diff --git a/llama_toolchain/inference/adapters.py b/llama_toolchain/inference/adapters.py
new file mode 100644
index 000000000..5b1b5c873
--- /dev/null
+++ b/llama_toolchain/inference/adapters.py
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter
+
+
+def available_inference_adapters() -> List[Adapter]:
+    return [
+        SourceAdapter(
+            api_surface=ApiSurface.inference,
+            adapter_id="meta-reference",
+            pip_packages=[
+                "torch",
+                "zmq",
+            ],
+            module="llama_toolchain.inference.inference",
+            config_class="llama_toolchain.inference.inference.InlineImplConfig",
+        ),
+        SourceAdapter(
+            api_surface=ApiSurface.inference,
+            adapter_id="meta-ollama",
+            pip_packages=[
+                "ollama",
+            ],
+            module="llama_toolchain.inference.ollama",
+            config_class="llama_toolchain.inference.ollama.OllamaImplConfig",
+        ),
+    ]
diff --git a/llama_toolchain/inference/inference.py b/llama_toolchain/inference/inference.py
index d7211ae65..61742a509 100644
--- a/llama_toolchain/inference/inference.py
+++ b/llama_toolchain/inference/inference.py
@@ -16,8 +16,8 @@ from .api.datatypes import (
     ToolCallParseStatus,
 )
 from .api.endpoints import (
-    ChatCompletionResponse,
     ChatCompletionRequest,
+    ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     Inference,
@@ -25,6 +25,13 @@ from .api.endpoints import (
 from .model_parallel import LlamaModelParallelGenerator
 
 
+def get_adapter_impl(config: InlineImplConfig) -> Inference:
+    assert isinstance(
+        config, InlineImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    return InferenceImpl(config)
+
+
 class InferenceImpl(Inference):
 
     def __init__(self, config: InlineImplConfig) -> None:
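
`available_inference_adapters` plus the `get_adapter_impl` entry points suggest how a distribution server would eventually instantiate an adapter from its `module` and `config_class`. That loader is not part of this diff, so the sketch below is only an assumption about the intended wiring.

```python
# Assumed wiring, not included in this change: resolve a SourceAdapter's module
# and config class by name, validate the raw config, and build the implementation.
import importlib

from llama_toolchain.distribution.datatypes import SourceAdapter


def build_adapter_impl(adapter: SourceAdapter, raw_config: dict):
    module = importlib.import_module(adapter.module)

    config_module, config_name = adapter.config_class.rsplit(".", 1)
    config_class = getattr(importlib.import_module(config_module), config_name)
    config = config_class(**raw_config)  # e.g. OllamaImplConfig(url=...)

    # inference.py and ollama.py both expose get_adapter_impl(); whether the
    # eventual loader calls it under this name is an assumption here.
    return module.get_adapter_impl(config)
```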
diff --git a/llama_toolchain/inference/ollama.py b/llama_toolchain/inference/ollama.py
index 91727fd62..d07d71829 100644
--- a/llama_toolchain/inference/ollama.py
+++ b/llama_toolchain/inference/ollama.py
@@ -1,9 +1,14 @@
-import httpx
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
 import uuid
 
 from typing import AsyncGenerator
 
-from ollama import AsyncClient
+import httpx
 
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
@@ -14,6 +19,8 @@ from llama_models.llama3_1.api.datatypes import (
 )
 from llama_models.llama3_1.api.tool_utils import ToolUtils
 
+from ollama import AsyncClient
+
 from .api.config import OllamaImplConfig
 from .api.datatypes import (
     ChatCompletionResponseEvent,
@@ -22,14 +29,20 @@ from .api.datatypes import (
     ToolCallParseStatus,
 )
 from .api.endpoints import (
-    ChatCompletionResponse,
     ChatCompletionRequest,
+    ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     Inference,
 )
 
 
+def get_adapter_impl(config: OllamaImplConfig) -> Inference:
+    assert isinstance(
+        config, OllamaImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    return OllamaInference(config)
+
 
 class OllamaInference(Inference):
 
@@ -41,9 +54,13 @@ class OllamaInference(Inference):
         self.client = AsyncClient(host=self.config.url)
         try:
             status = await self.client.pull(self.model)
-            assert status['status'] == 'success', f"Failed to pull model {self.model} in ollama"
+            assert (
+                status["status"] == "success"
+            ), f"Failed to pull model {self.model} in ollama"
         except httpx.ConnectError:
-            print("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
+            print(
+                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
+            )
             raise
 
     async def shutdown(self) -> None:
@@ -55,9 +72,7 @@ class OllamaInference(Inference):
     def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
         ollama_messages = []
         for message in messages:
-            ollama_messages.append(
-                {"role": message.role, "content": message.content}
-            )
+            ollama_messages.append({"role": message.role, "content": message.content})
 
         return ollama_messages
 
@@ -67,16 +82,16 @@ class OllamaInference(Inference):
             model=self.model,
             messages=self._messages_to_ollama_messages(request.messages),
             stream=False,
-            #TODO: add support for options like temp, top_p, max_seq_length, etc
+            # TODO: add support for options like temp, top_p, max_seq_length, etc
         )
-        if r['done']:
-            if r['done_reason'] == 'stop':
+        if r["done"]:
+            if r["done_reason"] == "stop":
                 stop_reason = StopReason.end_of_turn
-            elif r['done_reason'] == 'length':
+            elif r["done_reason"] == "length":
                 stop_reason = StopReason.out_of_tokens
 
         completion_message = decode_assistant_message_from_content(
-            r['message']['content'],
+            r["message"]["content"],
             stop_reason,
         )
         yield ChatCompletionResponse(
@@ -94,7 +109,7 @@ class OllamaInference(Inference):
         stream = await self.client.chat(
             model=self.model,
             messages=self._messages_to_ollama_messages(request.messages),
-            stream=True
+            stream=True,
         )
 
         buffer = ""
@@ -103,14 +118,14 @@
 
         async for chunk in stream:
             # check if ollama is done
-            if chunk['done']:
-                if chunk['done_reason'] == 'stop':
+            if chunk["done"]:
+                if chunk["done_reason"] == "stop":
                     stop_reason = StopReason.end_of_turn
-                elif chunk['done_reason'] == 'length':
+                elif chunk["done_reason"] == "length":
                     stop_reason = StopReason.out_of_tokens
                 break
 
-            text = chunk['message']['content']
+            text = chunk["message"]["content"]
 
             # check if its a tool call ( aka starts with <|python_tag|> )
             if not ipython and text.startswith("<|python_tag|>"):
@@ -197,7 +212,7 @@
     )
 
 
-#TODO: Consolidate this with impl in llama-models
+# TODO: Consolidate this with impl in llama-models
 def decode_assistant_message_from_content(
     content: str,
     stop_reason: StopReason,
diff --git a/llama_toolchain/utils.py b/llama_toolchain/utils.py
index 2bf3be4e3..d0805d901 100644
--- a/llama_toolchain/utils.py
+++ b/llama_toolchain/utils.py
@@ -6,6 +6,7 @@
 
 import getpass
 import os
+from pathlib import Path
 from typing import Optional
 
 from hydra import compose, initialize, MissingConfigException
@@ -16,6 +17,8 @@ from omegaconf import OmegaConf
 
 LLAMA_STACK_CONFIG_DIR = os.path.expanduser("~/.llama/")
 
+DISTRIBS_BASE_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "distributions"
+
 
 def get_root_directory():
     current_dir = os.path.dirname(os.path.abspath(__file__))
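
Taken together, `DISTRIBS_BASE_DIR` and the files written by the `install` and `configure` subcommands imply a small per-distribution layout under `~/.llama/distributions/`. The snippet below only inspects that expected layout; it is not part of the change.

```python
# Expected on-disk layout implied by this change (illustrative):
#   ~/.llama/distributions/<name>/conda.env    written by `llama distribution install`
#   ~/.llama/distributions/<name>/config.yaml  written by `llama distribution configure`
from llama_toolchain.utils import DISTRIBS_BASE_DIR

for dist_dir in sorted(DISTRIBS_BASE_DIR.glob("*")):
    print(
        dist_dir.name,
        (dist_dir / "conda.env").exists(),
        (dist_dir / "config.yaml").exists(),
    )
```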