getting closer to a distro definition, distro install + configure works

This commit is contained in:
Ashwin Bharambe 2024-08-01 22:59:11 -07:00
parent dac2b5a1ed
commit 041cafbee3
11 changed files with 471 additions and 130 deletions

View file

@@ -5,18 +5,24 @@
# the root directory of this source tree.
import argparse
import os
import importlib
import inspect
import shlex
from pathlib import Path
from typing import Annotated, get_args, get_origin, Literal, Union
import pkg_resources
import yaml
from pydantic import BaseModel
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.registry import all_registered_distributions
from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR
from llama_toolchain.distribution.datatypes import Distribution, PassthroughApiAdapter
from llama_toolchain.distribution.registry import available_distributions
from llama_toolchain.utils import DISTRIBS_BASE_DIR
from .utils import run_command
CONFIGS_BASE_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")
DISTRIBS = available_distributions()
class DistributionConfigure(Subcommand):
@@ -34,59 +40,198 @@ class DistributionConfigure(Subcommand):
self.parser.set_defaults(func=self._run_distribution_configure_cmd)
def _add_arguments(self):
distribs = all_registered_distributions()
self.parser.add_argument(
"--name",
type=str,
help="Mame of the distribution to configure",
default="local-source",
choices=[d.name for d in distribs],
choices=[d.name for d in available_distributions()],
)
def read_user_inputs(self):
checkpoint_dir = input(
"Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
)
model_parallel_size = input(
"Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
)
assert model_parallel_size.isdigit() and int(model_parallel_size) in {
1,
8,
}, "model parallel size must be 1 or 8"
return checkpoint_dir, model_parallel_size
def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
default_conf_path = pkg_resources.resource_filename(
"llama_toolchain", "data/default_distribution_config.yaml"
)
with open(default_conf_path, "r") as f:
yaml_content = f.read()
yaml_content = yaml_content.format(
checkpoint_dir=checkpoint_dir,
model_parallel_size=model_parallel_size,
)
with open(yaml_output_path, "w") as yaml_file:
yaml_file.write(yaml_content.strip())
print(f"YAML configuration has been written to {yaml_output_path}")
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
checkpoint_dir, model_parallel_size = self.read_user_inputs()
checkpoint_dir = os.path.expanduser(checkpoint_dir)
dist = None
for d in DISTRIBS:
if d.name == args.name:
dist = d
break
assert (
Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
), f"{checkpoint_dir} does not exist or it not a directory"
if dist is None:
self.parser.error(f"Could not find distribution {args.name}")
return
os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
yaml_output_path = Path(CONFIGS_BASE_DIR) / "distribution.yaml"
env_file = DISTRIBS_BASE_DIR / dist.name / "conda.env"
# read this file to get the conda env name
assert env_file.exists(), f"Could not find conda env file {env_file}"
with open(env_file, "r") as f:
conda_env = f.read().strip()
self.write_output_yaml(
checkpoint_dir,
model_parallel_size,
yaml_output_path,
configure_llama_distribution(dist, conda_env)
def configure_llama_distribution(dist: Distribution, conda_env: str):
python_exe = run_command(shlex.split("which python"))
# simple check
if conda_env not in python_exe:
raise ValueError(
f"Please re-run configure by activating the `{conda_env}` conda environment"
)
adapter_configs = {}
for api_surface, adapter in dist.adapters.items():
if isinstance(adapter, PassthroughApiAdapter):
adapter_configs[api_surface.value] = adapter.dict()
else:
cprint(
f"Configuring API surface: {api_surface.value}", "white", attrs=["bold"]
)
config_type = instantiate_class_type(adapter.config_class)
# TODO: when we are re-configuring, we should read existing values
config = prompt_for_config(config_type)
adapter_configs[api_surface.value] = config.dict()
dist_config = {
"adapters": adapter_configs,
"conda_env": conda_env,
}
yaml_output_path = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
with open(yaml_output_path, "w") as fp:
fp.write(yaml.dump(dist_config, sort_keys=False))
print(f"YAML configuration has been written to {yaml_output_path}")
def instantiate_class_type(fully_qualified_name):
module_name, class_name = fully_qualified_name.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
def get_literal_values(field):
"""Extract literal values from a field if it's a Literal type."""
if get_origin(field.annotation) is Literal:
return get_args(field.annotation)
return None
def is_optional(field_type):
"""Check if a field type is Optional."""
return get_origin(field_type) is Union and type(None) in get_args(field_type)
def get_non_none_type(field_type):
"""Get the non-None type from an Optional type."""
return next(arg for arg in get_args(field_type) if arg is not type(None))
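For orientation, a quick sanity sketch (not part of this commit) of what these helpers return; `pathlib.Path` is just a stdlib stand-in for a real `config_class` string, and the helpers above are assumed to be in scope.
from pathlib import Path
from typing import Optional

# dynamic import by dotted name, as used for `config_class` strings
assert instantiate_class_type("pathlib.Path") is Path
# Optional[str] is Union[str, None], so it counts as optional
assert is_optional(Optional[str])
# strip the NoneType member to get the concrete type to convert input to
assert get_non_none_type(Optional[int]) is int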
def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
"""
Recursively prompt the user for configuration values based on a Pydantic BaseModel.
Args:
config_type: A Pydantic BaseModel class representing the configuration structure.
Returns:
An instance of the config_type with user-provided values.
"""
config_data = {}
for field_name, field in config_type.__fields__.items():
field_type = field.annotation
default_value = (
field.default if not isinstance(field.default, type(Ellipsis)) else None
)
is_required = field.required
# Skip fields with Literal type
if get_origin(field_type) is Literal:
continue
# Check if the field is a discriminated union
if get_origin(field_type) is Annotated:
inner_type = get_args(field_type)[0]
if get_origin(inner_type) is Union:
discriminator = field.field_info.discriminator
if discriminator:
union_types = get_args(inner_type)
# Find the discriminator field in each union type
type_map = {}
for t in union_types:
disc_field = t.__fields__[discriminator]
literal_values = get_literal_values(disc_field)
if literal_values:
for value in literal_values:
type_map[value] = t
while True:
discriminator_value = input(
f"Enter the {discriminator} (options: {', '.join(type_map.keys())}): "
)
if discriminator_value in type_map:
chosen_type = type_map[discriminator_value]
print(f"\nConfiguring {chosen_type.__name__}:")
sub_config = prompt_for_config(chosen_type)
config_data[field_name] = sub_config
# Set the discriminator field in the sub-config
setattr(sub_config, discriminator, discriminator_value)
break
else:
print(f"Invalid {discriminator}. Please try again.")
continue
if inspect.isclass(field_type) and issubclass(field_type, BaseModel):
print(f"\nEntering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(field_type)
else:
prompt = f"Enter value for {field_name}"
if default_value is not None:
prompt += f" (default: {default_value})"
if is_optional(field_type):
prompt += " (optional)"
elif is_required:
prompt += " (required)"
prompt += ": "
while True:
user_input = input(prompt)
if user_input == "":
if default_value is not None:
config_data[field_name] = default_value
break
elif is_optional(field_type):
config_data[field_name] = None
break
elif not is_required:
config_data[field_name] = None
break
else:
print("This field is required. Please provide a value.")
continue
try:
# Handle Optional types
if is_optional(field_type):
if user_input.lower() == "none":
config_data[field_name] = None
break
field_type = get_non_none_type(field_type)
# Convert the input to the correct type
if inspect.isclass(field_type) and issubclass(
field_type, BaseModel
):
# For nested BaseModels, we assume a dictionary-like string input
import ast
config_data[field_name] = field_type(
**ast.literal_eval(user_input)
)
else:
config_data[field_name] = field_type(user_input)
break
except ValueError:
print(
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
)
return config_type(**config_data)
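A minimal usage sketch (not part of this commit) of `prompt_for_config` with a hand-rolled model; `ServerConfig` and its fields are purely illustrative, not types from this repository.
from typing import Optional
from pydantic import BaseModel, Field

class ServerConfig(BaseModel):
    host: str = "localhost"                      # has a default, accepted by pressing Enter
    port: int = Field(..., description="Port")   # required, re-prompted until input parses as int
    api_key: Optional[str] = None                # optional; empty input or "none" yields None

config = prompt_for_config(ServerConfig)
print(config.dict())   # e.g. {'host': 'localhost', 'port': 5000, 'api_key': None}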

View file

@@ -7,6 +7,7 @@
import argparse
from llama_toolchain.cli.subcommand import Subcommand
from .configure import DistributionConfigure
from .create import DistributionCreate
from .install import DistributionInstall

View file

@@ -7,20 +7,16 @@
import argparse
import os
import shlex
import subprocess
from pathlib import Path
import pkg_resources
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.registry import all_registered_distributions
from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR
from llama_toolchain.distribution.datatypes import distribution_dependencies
from llama_toolchain.distribution.registry import available_distributions
from llama_toolchain.utils import DISTRIBS_BASE_DIR
from .utils import run_command, run_with_pty
DISTRIBS_BASE_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "distributions"
DISTRIBS = all_registered_distributions()
DISTRIBS = available_distributions()
class DistributionInstall(Subcommand):
@@ -70,13 +66,19 @@ class DistributionInstall(Subcommand):
return
os.makedirs(DISTRIBS_BASE_DIR / dist.name, exist_ok=True)
run_shell_script(script, args.conda_env, " ".join(dist.pip_packages))
deps = distribution_dependencies(dist)
run_command([script, args.conda_env, " ".join(deps)])
with open(DISTRIBS_BASE_DIR / dist.name / "conda.env", "w") as f:
f.write(f"{args.conda_env}\n")
def run_shell_script(script_path, *args):
command_string = f"{script_path} {' '.join(shlex.quote(str(arg)) for arg in args)}"
command_list = shlex.split(command_string)
print(f"Running command: {command_list}")
subprocess.run(command_list, check=True, text=True)
# we need to run configure _within_ the conda environment and need to run with
# a pty since configure is interactive and prompts the user for input
python_exe = run_command(
shlex.split(f"conda run -n {args.conda_env} which python")
).strip()
run_with_pty(
shlex.split(
f"{python_exe} -m llama_toolchain.cli.llama distribution configure --name {dist.name}"
)
)

View file

@@ -9,7 +9,8 @@ import argparse
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.cli.table import print_table
from llama_toolchain.distribution.registry import all_registered_distributions
from llama_toolchain.distribution.datatypes import distribution_dependencies
from llama_toolchain.distribution.registry import available_distributions
class DistributionList(Subcommand):
@@ -37,12 +38,13 @@
]
rows = []
for dist in all_registered_distributions():
for dist in available_distributions():
deps = distribution_dependencies(dist)
rows.append(
[
dist.name,
dist.description,
", ".join(dist.pip_packages),
", ".join(deps),
]
)
print_table(

View file

@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import errno
import os
import pty
import select
import subprocess
import sys
import termios
import tty
def run_with_pty(command):
old_settings = termios.tcgetattr(sys.stdin)
# Create a new pseudo-terminal
master, slave = pty.openpty()
try:
# ensure the terminal does not echo input
tty.setraw(sys.stdin.fileno())
process = subprocess.Popen(
command,
stdin=slave,
stdout=slave,
stderr=slave,
universal_newlines=True,
)
# Close the slave file descriptor as it's now owned by the subprocess
os.close(slave)
def handle_io():
while True:
rlist, _, _ = select.select([sys.stdin, master], [], [])
if sys.stdin in rlist:
data = os.read(sys.stdin.fileno(), 1024)
if not data: # EOF
break
os.write(master, data)
if master in rlist:
data = os.read(master, 1024)
if not data:
break
os.write(sys.stdout.fileno(), data)
handle_io()
except (EOFError, KeyboardInterrupt):
pass
except OSError as e:
if e.errno != errno.EIO:
raise
finally:
# Restore original terminal settings
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
process.wait()
os.close(master)
return process.returncode
def run_command(command):
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.communicate()
if process.returncode != 0:
print(f"Error: {error.decode('utf-8')}")
sys.exit(1)
return output.decode("utf-8")

View file

@@ -32,15 +32,7 @@ class PassthroughApiAdapterConfig(BaseModel):
@json_schema_type
class PythonImplAdapterConfig(BaseModel):
type: Literal[AdapterType.python_impl.value] = AdapterType.python_impl.value
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
module: str = Field(..., description="The name of the module to import")
entrypoint: str = Field(
...,
description="The name of the entrypoint function which creates the implementation for the API",
)
adapter_id: str
kwargs: Dict[str, Any] = Field(
default_factory=dict, description="kwargs to pass to the entrypoint"
)
@@ -63,21 +55,43 @@ AdapterConfig = Annotated[
]
class DistributionConfig(BaseModel):
inference: AdapterConfig
safety: AdapterConfig
# configs for each API that the stack provides, e.g.
# agentic_system: AdapterConfig
# post_training: AdapterConfig
@json_schema_type
class ApiSurface(Enum):
inference = "inference"
safety = "safety"
class DistributionConfigDefaults(BaseModel):
inference: Dict[str, Any] = Field(
default_factory=dict, description="Default kwargs for the inference adapter"
@json_schema_type
class Adapter(BaseModel):
api_surface: ApiSurface
adapter_id: str
@json_schema_type
class SourceAdapter(Adapter):
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
safety: Dict[str, Any] = Field(
default_factory=dict, description="Default kwargs for the safety adapter"
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have
a `get_adapter_instance()` method which will be passed a validated config object
of type `config_class`.""",
)
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this adapter",
)
@json_schema_type
class PassthroughApiAdapter(Adapter):
base_url: str = Field(..., description="The base URL for the llama stack provider")
headers: Dict[str, str] = Field(
default_factory=dict,
description="Headers (e.g., authorization) to send with the request",
)
@@ -85,9 +99,22 @@ class Distribution(BaseModel):
name: str
description: str
# you must install the packages to get the functionality needed.
# later, we may have a docker image be the main artifact of
# a distribution.
pip_packages: List[str]
adapters: Dict[ApiSurface, Adapter] = Field(
default_factory=dict,
description="The API surfaces provided by this distribution",
)
config_defaults: DistributionConfigDefaults
additional_pip_packages: List[str] = Field(
default_factory=list,
description="Additional pip packages beyond those required by the adapters",
)
def distribution_dependencies(distribution: Distribution) -> List[str]:
# only consider SourceAdapters when calculating dependencies
return [
dep
for adapter in distribution.adapters.values()
if isinstance(adapter, SourceAdapter)
for dep in adapter.pip_packages
] + distribution.additional_pip_packages

View file

@@ -6,30 +6,60 @@
from typing import List
from .datatypes import Distribution, DistributionConfigDefaults
from llama_toolchain.inference.adapters import available_inference_adapters
from .datatypes import ApiSurface, Distribution
# This is currently duplicated from `requirements.txt` with a few minor changes
# dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
# are moved to the appropriate distribution.
COMMON_DEPENDENCIES = [
"accelerate",
"black==24.4.2",
"blobfile",
"codeshield",
"fairscale",
"fastapi",
"fire",
"flake8",
"httpx",
"huggingface-hub",
"hydra-core",
"hydra-zen",
"json-strong-typing",
"llama-models",
"omegaconf",
"pandas",
"Pillow",
"pydantic==1.10.13",
"pydantic_core==2.18.2",
"python-openapi",
"requests",
"tiktoken",
"torch",
"transformers",
"uvicorn",
]
def all_registered_distributions() -> List[Distribution]:
def available_distributions() -> List[Distribution]:
inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()}
return [
Distribution(
name="local-source",
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
pip_packages=[],
config_defaults=DistributionConfigDefaults(
inference={
"max_seq_len": 4096,
"max_batch_size": 1,
},
safety={},
),
additional_pip_packages=COMMON_DEPENDENCIES,
adapters={
ApiSurface.inference: inference_adapters_by_id["meta-reference"],
},
),
Distribution(
name="local-ollama",
description="Like local-source, but use ollama for running LLM inference",
pip_packages=["ollama"],
config_defaults=DistributionConfigDefaults(
inference={},
safety={},
),
additional_pip_packages=COMMON_DEPENDENCIES,
adapters={
ApiSurface.inference: inference_adapters_by_id["meta-ollama"],
},
),
]
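Putting the registry and the dependency helper together, a short sketch (not part of this commit) of the same information that `llama distribution list` renders as a table:
from llama_toolchain.distribution.datatypes import distribution_dependencies
from llama_toolchain.distribution.registry import available_distributions

for dist in available_distributions():
    deps = distribution_dependencies(dist)
    surfaces = [surface.value for surface in dist.adapters]
    print(f"{dist.name}: adapters={surfaces}, {len(deps)} pip packages")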

View file

@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter
def available_inference_adapters() -> List[Adapter]:
return [
SourceAdapter(
api_surface=ApiSurface.inference,
adapter_id="meta-reference",
pip_packages=[
"torch",
"zmq",
],
module="llama_toolchain.inference.inference",
config_class="llama_toolchain.inference.inference.InlineImplConfig",
),
SourceAdapter(
api_surface=ApiSurface.inference,
adapter_id="meta-ollama",
pip_packages=[
"ollama",
],
module="llama_toolchain.inference.ollama",
config_class="llama_toolchain.inference.ollama.OllamaImplConfig",
),
]
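The registry above looks these adapters up by `adapter_id`; a quick sketch (not in this commit) of the same pattern:
from llama_toolchain.inference.adapters import available_inference_adapters

adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()}
ollama = adapters_by_id["meta-ollama"]
print(ollama.module)        # llama_toolchain.inference.ollama
print(ollama.config_class)  # llama_toolchain.inference.ollama.OllamaImplConfig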

View file

@@ -16,8 +16,8 @@ from .api.datatypes import (
ToolCallParseStatus,
)
from .api.endpoints import (
ChatCompletionResponse,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
@@ -25,6 +25,13 @@ from .api.endpoints import (
from .model_parallel import LlamaModelParallelGenerator
def get_adapter_impl(config: InlineImplConfig) -> Inference:
assert isinstance(
config, InlineImplConfig
), f"Unexpected config type: {type(config)}"
return InferenceImpl(config)
class InferenceImpl(Inference):
def __init__(self, config: InlineImplConfig) -> None:

View file

@@ -1,9 +1,14 @@
import httpx
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import AsyncGenerator
from ollama import AsyncClient
import httpx
from llama_models.llama3_1.api.datatypes import (
BuiltinTool,
@@ -14,6 +19,8 @@ from llama_models.llama3_1.api.datatypes import (
)
from llama_models.llama3_1.api.tool_utils import ToolUtils
from ollama import AsyncClient
from .api.config import OllamaImplConfig
from .api.datatypes import (
ChatCompletionResponseEvent,
@@ -22,14 +29,20 @@ from .api.datatypes import (
ToolCallParseStatus,
)
from .api.endpoints import (
ChatCompletionResponse,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
)
def get_adapter_impl(config: OllamaImplConfig) -> Inference:
assert isinstance(
config, OllamaImplConfig
), f"Unexpected config type: {type(config)}"
return OllamaInference(config)
class OllamaInference(Inference):
@@ -41,9 +54,13 @@ class OllamaInference(Inference):
self.client = AsyncClient(host=self.config.url)
try:
status = await self.client.pull(self.model)
assert status['status'] == 'success', f"Failed to pull model {self.model} in ollama"
assert (
status["status"] == "success"
), f"Failed to pull model {self.model} in ollama"
except httpx.ConnectError:
print("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
print(
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
)
raise
async def shutdown(self) -> None:
@@ -55,9 +72,7 @@ class OllamaInference(Inference):
def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
ollama_messages = []
for message in messages:
ollama_messages.append(
{"role": message.role, "content": message.content}
)
ollama_messages.append({"role": message.role, "content": message.content})
return ollama_messages
@@ -67,16 +82,16 @@
model=self.model,
messages=self._messages_to_ollama_messages(request.messages),
stream=False,
#TODO: add support for options like temp, top_p, max_seq_length, etc
# TODO: add support for options like temp, top_p, max_seq_length, etc
)
if r['done']:
if r['done_reason'] == 'stop':
if r["done"]:
if r["done_reason"] == "stop":
stop_reason = StopReason.end_of_turn
elif r['done_reason'] == 'length':
elif r["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
completion_message = decode_assistant_message_from_content(
r['message']['content'],
r["message"]["content"],
stop_reason,
)
yield ChatCompletionResponse(
@@ -94,7 +109,7 @@ class OllamaInference(Inference):
stream = await self.client.chat(
model=self.model,
messages=self._messages_to_ollama_messages(request.messages),
stream=True
stream=True,
)
buffer = ""
@@ -103,14 +118,14 @@
async for chunk in stream:
# check if ollama is done
if chunk['done']:
if chunk['done_reason'] == 'stop':
if chunk["done"]:
if chunk["done_reason"] == "stop":
stop_reason = StopReason.end_of_turn
elif chunk['done_reason'] == 'length':
elif chunk["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
break
text = chunk['message']['content']
text = chunk["message"]["content"]
# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
@@ -197,7 +212,7 @@ class OllamaInference(Inference):
)
#TODO: Consolidate this with impl in llama-models
# TODO: Consolidate this with impl in llama-models
def decode_assistant_message_from_content(
content: str,
stop_reason: StopReason,

View file

@@ -6,6 +6,7 @@
import getpass
import os
from pathlib import Path
from typing import Optional
from hydra import compose, initialize, MissingConfigException
@@ -16,6 +17,8 @@ from omegaconf import OmegaConf
LLAMA_STACK_CONFIG_DIR = os.path.expanduser("~/.llama/")
DISTRIBS_BASE_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "distributions"
def get_root_directory():
current_dir = os.path.dirname(os.path.abspath(__file__))