Rename Distribution -> DistributionSpec, simplify RemoteProviders

This commit is contained in:
Ashwin Bharambe 2024-08-06 10:45:06 -07:00
parent 0a67f3d3e6
commit 7cc0445517
9 changed files with 181 additions and 147 deletions

View file

@ -8,7 +8,7 @@ import argparse
import json
import shlex
from pathlib import Path
from typing import Any, Dict
import yaml
from termcolor import cprint
@ -32,70 +32,67 @@ class DistributionConfigure(Subcommand):
self.parser.set_defaults(func=self._run_distribution_configure_cmd)
def _add_arguments(self):
from llama_toolchain.distribution.registry import available_distributions
self.parser.add_argument(
"--name",
type=str,
help="Name of the distribution to configure",
default="local-source",
choices=[d.name for d in available_distributions()],
required=True,
)
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.distribution.registry import resolve_distribution
from llama_toolchain.distribution.registry import resolve_distribution_spec
dist = resolve_distribution(args.name)
if dist is None:
self.parser.error(f"Could not find distribution {args.name}")
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
if not config_file.exists():
self.parser.error(
f"Could not find {config_file}. Please run `llama distribution install` first"
)
return
env_file = DISTRIBS_BASE_DIR / dist.name / "conda.env"
# read this file to get the conda env name
assert env_file.exists(), f"Could not find conda env file {env_file}"
with open(env_file, "r") as f:
conda_env = f.read().strip()
# we need to find the spec from the name
with open(config_file, "r") as f:
config = yaml.safe_load(f)
configure_llama_distribution(dist, conda_env)
dist = resolve_distribution_spec(config["spec"])
if dist is None:
raise ValueError(f"Could not find any registered spec `{config['spec']}`")
configure_llama_distribution(dist, config)
def configure_llama_distribution(dist: "Distribution", conda_env: str):
def configure_llama_distribution(dist: "Distribution", config: Dict[str, Any]):
from llama_toolchain.common.exec import run_command
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.distribution.datatypes import RemoteProviderSpec
from llama_toolchain.distribution.dynamic import instantiate_class_type
python_exe = run_command(shlex.split("which python"))
# simple check
conda_env = config["conda_env"]
if conda_env not in python_exe:
raise ValueError(
f"Please re-run configure by activating the `{conda_env}` conda environment"
)
existing_config = None
config_path = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
if config_path.exists():
existing_config = config
if "providers" in existing_config:
cprint(
f"Configuration already exists for {dist.name}. Will overwrite...",
f"Configuration already exists for {config['name']}. Will overwrite...",
"yellow",
attrs=["bold"],
)
with open(config_path, "r") as fp:
existing_config = yaml.safe_load(fp)
provider_configs = {}
for api, provider_spec in dist.provider_specs.items():
if isinstance(provider_spec, RemoteProviderSpec):
provider_configs[api.value] = provider_spec.dict()
else:
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
config_type = instantiate_class_type(provider_spec.config_class)
config = prompt_for_config(
config_type,
(
config_type(**existing_config["providers"][api.value])
if existing_config and api.value in existing_config["providers"]
if existing_config
and "providers" in existing_config
and api.value in existing_config["providers"]
else None
),
)
@ -106,9 +103,10 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
dist_config = {
"providers": provider_configs,
"conda_env": conda_env,
**existing_config,
}
config_path = DISTRIBS_BASE_DIR / existing_config["name"] / "config.yaml"
with open(config_path, "w") as fp:
dist_config = json.loads(json.dumps(dist_config, cls=EnumEncoder))
fp.write(yaml.dump(dist_config, sort_keys=False))

View file

@ -34,9 +34,9 @@ class DistributionCreate(Subcommand):
# wants to pick and then ask for their configuration.
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.distribution.registry import resolve_distribution
from llama_toolchain.distribution.registry import resolve_distribution_spec
dist = resolve_distribution(args.name)
dist = resolve_distribution_spec(args.name)
if dist is not None:
self.parser.error(f"Distribution with name {args.name} already exists")
return

View file

@ -10,6 +10,7 @@ import shlex
import textwrap
import pkg_resources
import yaml
from termcolor import cprint
@ -32,26 +33,31 @@ class DistributionInstall(Subcommand):
self.parser.set_defaults(func=self._run_distribution_install_cmd)
def _add_arguments(self):
from llama_toolchain.distribution.registry import available_distributions
from llama_toolchain.distribution.registry import available_distribution_specs
self.parser.add_argument(
"--spec",
type=str,
help="Distribution spec to install (try ollama-inline)",
required=True,
choices=[d.spec_id for d in available_distribution_specs()],
)
self.parser.add_argument(
"--name",
type=str,
help="Name of the distribution to install -- (try local-ollama)",
help="What should the installation be called locally?",
required=True,
choices=[d.name for d in available_distributions()],
)
self.parser.add_argument(
"--conda-env",
type=str,
help="Specify the name of the conda environment you would like to create or update",
required=True,
help="conda env in which this distribution will run (default = distribution name)",
)
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import resolve_distribution
from llama_toolchain.distribution.registry import resolve_distribution_spec
os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
script = pkg_resources.resource_filename(
@ -59,25 +65,36 @@ class DistributionInstall(Subcommand):
"distribution/install_distribution.sh",
)
dist = resolve_distribution(args.name)
dist = resolve_distribution_spec(args.spec)
if dist is None:
self.parser.error(f"Could not find distribution {args.name}")
self.parser.error(f"Could not find distribution {args.spec}")
return
os.makedirs(DISTRIBS_BASE_DIR / dist.name, exist_ok=True)
distrib_dir = DISTRIBS_BASE_DIR / args.name
os.makedirs(distrib_dir, exist_ok=True)
deps = distribution_dependencies(dist)
return_code = run_with_pty([script, args.conda_env, " ".join(deps)])
if not args.conda_env:
print(f"Using {args.name} as the Conda environment for this distribution")
conda_env = args.conda_env or args.name
return_code = run_with_pty([script, conda_env, " ".join(deps)])
assert return_code == 0, cprint(
f"Failed to install distribution {dist.name}", color="red"
f"Failed to install distribution {dist.spec_id}", color="red"
)
with open(DISTRIBS_BASE_DIR / dist.name / "conda.env", "w") as f:
f.write(f"{args.conda_env}\n")
config_file = distrib_dir / "config.yaml"
with open(config_file, "w") as f:
c = {
"conda_env": conda_env,
"spec": dist.spec_id,
"name": args.name,
}
f.write(yaml.dump(c))
cprint(
f"Distribution `{dist.name}` has been installed successfully!",
f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
color="green",
)
print(
@ -85,8 +102,7 @@ class DistributionInstall(Subcommand):
f"""
Update your conda environment and configure this distribution by running:
conda deactivate && conda activate {args.conda_env}
llama distribution configure --name {dist.name}
conda deactivate && conda activate {conda_env}
llama distribution configure --name {args.name}
"""
)
)
))

View file

@ -28,23 +28,23 @@ class DistributionList(Subcommand):
def _run_distribution_list_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.cli.table import print_table
from llama_toolchain.distribution.registry import available_distributions
from llama_toolchain.distribution.registry import available_distribution_specs
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
"Name",
"Spec ID",
"ProviderSpecs",
"Description",
]
rows = []
for dist in available_distributions():
providers = {k.value: v.provider_id for k, v in dist.provider_specs.items()}
for spec in available_distribution_specs():
providers = {k.value: v.provider_id for k, v in spec.provider_specs.items()}
rows.append(
[
dist.name,
spec.spec_id,
json.dumps(providers, indent=2),
dist.description,
spec.description,
]
)
print_table(

View file

@ -6,7 +6,6 @@
import argparse
import shlex
from pathlib import Path
import yaml
@ -49,22 +48,23 @@ class DistributionStart(Subcommand):
def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_command
from llama_toolchain.distribution.registry import resolve_distribution
from llama_toolchain.distribution.registry import resolve_distribution_spec
from llama_toolchain.distribution.server import main as distribution_server_init
dist = resolve_distribution(args.name)
if dist is None:
self.parser.error(f"Distribution with name {args.name} not found")
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
if not config_file.exists():
self.parser.error(
f"Could not find {config_file}. Please run `llama distribution install` first"
)
return
config_yaml = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
if not config_yaml.exists():
raise ValueError(
f"Configuration {config_yaml} does not exist. Please run `llama distribution install` or `llama distribution configure` first"
)
# we need to find the spec from the name
with open(config_file, "r") as f:
config = yaml.safe_load(f)
with open(config_yaml, "r") as fp:
config = yaml.safe_load(fp)
dist = resolve_distribution_spec(config["spec"])
if dist is None:
raise ValueError(f"Could not find any registered spec `{config['spec']}`")
conda_env = config["conda_env"]
@ -76,8 +76,7 @@ class DistributionStart(Subcommand):
)
distribution_server_init(
dist.name,
config_yaml,
config_file,
args.port,
disable_ipv6=args.disable_ipv6,
)

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Dict, List
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type
@ -29,6 +29,10 @@ class ApiEndpoint(BaseModel):
class ProviderSpec(BaseModel):
api: Api
provider_id: str
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this provider",
)
@json_schema_type
@ -45,23 +49,21 @@ Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this provider",
)
api_dependencies: List[Api] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
class RemoteProviderConfig(BaseModel):
base_url: str = Field(..., description="The base URL for the llama stack provider")
api_key: Optional[str] = Field(
..., description="API key, if needed, for the provider"
)
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
base_url: str = Field(..., description="The base URL for the llama stack provider")
headers: Dict[str, str] = Field(
default_factory=dict,
description="Headers (e.g., authorization) to send with the request",
)
module: str = Field(
...,
description="""
@ -69,10 +71,12 @@ Fully-qualified name of the module to import. The module is expected to have:
- `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
""",
)
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
class Distribution(BaseModel):
name: str
@json_schema_type
class DistributionSpec(BaseModel):
spec_id: str
description: str
provider_specs: Dict[Api, ProviderSpec] = Field(
@ -84,3 +88,12 @@ class Distribution(BaseModel):
default_factory=list,
description="Additional pip packages beyond those required by the providers",
)
@json_schema_type
class InstalledDistribution(BaseModel):
"""References to a installed / configured DistributionSpec"""
name: str
spec_id: str
# This is the class which represents the configs written by `configure`

View file

@ -8,13 +8,22 @@ import inspect
from typing import Dict, List
from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.inference.providers import available_inference_providers
from llama_toolchain.safety.api.endpoints import Safety
from llama_toolchain.safety.providers import available_safety_providers
from .datatypes import Api, ApiEndpoint, Distribution, InlineProviderSpec
from .datatypes import (
Api,
ApiEndpoint,
DistributionSpec,
InlineProviderSpec,
ProviderSpec,
)
def distribution_dependencies(distribution: Distribution) -> List[str]:
def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
# only consider InlineProviderSpecs when calculating dependencies
return [
dep
@ -51,3 +60,19 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis[api] = endpoints
return apis
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
inference_providers_by_id = {
a.provider_id: a for a in available_inference_providers()
}
safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
agentic_system_providers_by_id = {
a.provider_id: a for a in available_agentic_system_providers()
}
return {
Api.inference: inference_providers_by_id,
Api.safety: safety_providers_by_id,
Api.agentic_system: agentic_system_providers_by_id,
}

View file

@ -7,12 +7,8 @@
from functools import lru_cache
from typing import List, Optional
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
from llama_toolchain.inference.providers import available_inference_providers
from llama_toolchain.safety.providers import available_safety_providers
from .datatypes import Api, Distribution, RemoteProviderSpec
from .datatypes import Api, DistributionSpec, RemoteProviderSpec
from .distribution import api_providers
# This is currently duplicated from `requirements.txt` with a few minor changes
# dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
@ -49,39 +45,30 @@ def client_module(api: Api) -> str:
return f"llama_toolchain.{api.value}.client"
def remote(api: Api, port: int) -> RemoteProviderSpec:
def remote_spec(api: Api) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api,
provider_id=f"{api.value}-remote",
base_url=f"http://localhost:{port}",
module=client_module(api),
)
@lru_cache()
def available_distributions() -> List[Distribution]:
inference_providers_by_id = {
a.provider_id: a for a in available_inference_providers()
}
safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
agentic_system_providers_by_id = {
a.provider_id: a for a in available_agentic_system_providers()
}
def available_distribution_specs() -> List[DistributionSpec]:
providers = api_providers()
return [
Distribution(
name="local-inline",
DistributionSpec(
spec_id="inline",
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
additional_pip_packages=COMMON_DEPENDENCIES,
provider_specs={
Api.inference: inference_providers_by_id["meta-reference"],
Api.safety: safety_providers_by_id["meta-reference"],
Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
Api.inference: providers[Api.inference]["meta-reference"],
Api.safety: providers[Api.safety]["meta-reference"],
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
},
),
# NOTE: this hardcodes the ports to which things point to
Distribution(
name="full-remote",
DistributionSpec(
spec_id="remote",
description="Point to remote services for all llama stack APIs",
additional_pip_packages=[
"python-dotenv",
@ -97,28 +84,24 @@ def available_distributions() -> List[Distribution]:
"pydantic_core==2.18.2",
"uvicorn",
],
provider_specs={
Api.inference: remote(Api.inference, 5001),
Api.safety: remote(Api.safety, 5001),
Api.agentic_system: remote(Api.agentic_system, 5001),
},
provider_specs={x: remote_spec(x) for x in providers},
),
Distribution(
name="local-ollama",
DistributionSpec(
spec_id="ollama-inline",
description="Like local-source, but use ollama for running LLM inference",
additional_pip_packages=COMMON_DEPENDENCIES,
provider_specs={
Api.inference: inference_providers_by_id["meta-ollama"],
Api.safety: safety_providers_by_id["meta-reference"],
Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
Api.inference: providers[Api.inference]["meta-ollama"],
Api.safety: providers[Api.safety]["meta-reference"],
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
},
),
]
@lru_cache()
def resolve_distribution(name: str) -> Optional[Distribution]:
for dist in available_distributions():
if dist.name == name:
return dist
def resolve_distribution_spec(spec_id: str) -> Optional[DistributionSpec]:
for spec in available_distribution_specs():
if spec.spec_id == spec_id:
return spec
return None

View file

@ -36,11 +36,11 @@ from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from .datatypes import Api, Distribution, ProviderSpec, RemoteProviderSpec
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints
from .dynamic import instantiate_client, instantiate_provider
from .registry import resolve_distribution
from .registry import resolve_distribution_spec
load_dotenv()
@ -250,7 +250,7 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
return [by_id[x] for x in stack]
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
provider_configs = config["providers"]
provider_specs = topological_sort(dist.provider_specs.values())
@ -265,7 +265,7 @@ def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
provider_config = provider_configs[api.value]
if isinstance(provider_spec, RemoteProviderSpec):
impls[api] = instantiate_client(
provider_spec, provider_spec.base_url.rstrip("/")
provider_spec, provider_config.base_url.rstrip("/")
)
else:
deps = {api: impls[api] for api in provider_spec.api_dependencies}
@ -275,16 +275,15 @@ def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
return impls
def main(
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
):
dist = resolve_distribution(dist_name)
if dist is None:
raise ValueError(f"Could not find distribution {dist_name}")
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp:
config = yaml.safe_load(fp)
spec = config["spec"]
dist = resolve_distribution_spec(spec)
if dist is None:
raise ValueError(f"Could not find distribution specification `{spec}`")
app = FastAPI()
all_endpoints = api_endpoints()
@ -293,14 +292,15 @@ def main(
for provider_spec in dist.provider_specs.values():
api = provider_spec.api
endpoints = all_endpoints[api]
impl = impls[api]
if isinstance(provider_spec, RemoteProviderSpec):
for endpoint in endpoints:
url = provider_spec.base_url.rstrip("/") + endpoint.route
url = impl.base_url + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
else:
impl = impls[api]
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already