mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Rename Distribution -> DistributionSpec, simplify RemoteProviders
This commit is contained in:
parent
0a67f3d3e6
commit
7cc0445517
9 changed files with 181 additions and 147 deletions
|
@ -8,7 +8,7 @@ import argparse
|
||||||
import json
|
import json
|
||||||
import shlex
|
import shlex
|
||||||
|
|
||||||
from pathlib import Path
|
from typing import Any, Dict
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
@ -32,83 +32,81 @@ class DistributionConfigure(Subcommand):
|
||||||
self.parser.set_defaults(func=self._run_distribution_configure_cmd)
|
self.parser.set_defaults(func=self._run_distribution_configure_cmd)
|
||||||
|
|
||||||
def _add_arguments(self):
|
def _add_arguments(self):
|
||||||
from llama_toolchain.distribution.registry import available_distributions
|
|
||||||
|
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"--name",
|
"--name",
|
||||||
type=str,
|
type=str,
|
||||||
help="Name of the distribution to configure",
|
help="Name of the distribution to configure",
|
||||||
default="local-source",
|
required=True,
|
||||||
choices=[d.name for d in available_distributions()],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
|
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
|
||||||
from llama_toolchain.distribution.registry import resolve_distribution
|
from llama_toolchain.distribution.registry import resolve_distribution_spec
|
||||||
|
|
||||||
dist = resolve_distribution(args.name)
|
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
|
||||||
if dist is None:
|
if not config_file.exists():
|
||||||
self.parser.error(f"Could not find distribution {args.name}")
|
self.parser.error(
|
||||||
|
f"Could not find {config_file}. Please run `llama distribution install` first"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
env_file = DISTRIBS_BASE_DIR / dist.name / "conda.env"
|
# we need to find the spec from the name
|
||||||
# read this file to get the conda env name
|
with open(config_file, "r") as f:
|
||||||
assert env_file.exists(), f"Could not find conda env file {env_file}"
|
config = yaml.safe_load(f)
|
||||||
with open(env_file, "r") as f:
|
|
||||||
conda_env = f.read().strip()
|
|
||||||
|
|
||||||
configure_llama_distribution(dist, conda_env)
|
dist = resolve_distribution_spec(config["spec"])
|
||||||
|
if dist is None:
|
||||||
|
raise ValueError(f"Could not find any registered spec `{config['spec']}`")
|
||||||
|
|
||||||
|
configure_llama_distribution(dist, config)
|
||||||
|
|
||||||
|
|
||||||
def configure_llama_distribution(dist: "Distribution", conda_env: str):
|
def configure_llama_distribution(dist: "Distribution", config: Dict[str, Any]):
|
||||||
from llama_toolchain.common.exec import run_command
|
from llama_toolchain.common.exec import run_command
|
||||||
from llama_toolchain.common.prompt_for_config import prompt_for_config
|
from llama_toolchain.common.prompt_for_config import prompt_for_config
|
||||||
from llama_toolchain.common.serialize import EnumEncoder
|
from llama_toolchain.common.serialize import EnumEncoder
|
||||||
from llama_toolchain.distribution.datatypes import RemoteProviderSpec
|
|
||||||
from llama_toolchain.distribution.dynamic import instantiate_class_type
|
from llama_toolchain.distribution.dynamic import instantiate_class_type
|
||||||
|
|
||||||
python_exe = run_command(shlex.split("which python"))
|
python_exe = run_command(shlex.split("which python"))
|
||||||
# simple check
|
# simple check
|
||||||
|
conda_env = config["conda_env"]
|
||||||
if conda_env not in python_exe:
|
if conda_env not in python_exe:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Please re-run configure by activating the `{conda_env}` conda environment"
|
f"Please re-run configure by activating the `{conda_env}` conda environment"
|
||||||
)
|
)
|
||||||
|
|
||||||
existing_config = None
|
existing_config = config
|
||||||
config_path = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
|
if "providers" in existing_config:
|
||||||
if config_path.exists():
|
|
||||||
cprint(
|
cprint(
|
||||||
f"Configuration already exists for {dist.name}. Will overwrite...",
|
f"Configuration already exists for {config['name']}. Will overwrite...",
|
||||||
"yellow",
|
"yellow",
|
||||||
attrs=["bold"],
|
attrs=["bold"],
|
||||||
)
|
)
|
||||||
with open(config_path, "r") as fp:
|
|
||||||
existing_config = yaml.safe_load(fp)
|
|
||||||
|
|
||||||
provider_configs = {}
|
provider_configs = {}
|
||||||
for api, provider_spec in dist.provider_specs.items():
|
for api, provider_spec in dist.provider_specs.items():
|
||||||
if isinstance(provider_spec, RemoteProviderSpec):
|
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
|
||||||
provider_configs[api.value] = provider_spec.dict()
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
else:
|
config = prompt_for_config(
|
||||||
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
|
config_type,
|
||||||
config_type = instantiate_class_type(provider_spec.config_class)
|
(
|
||||||
config = prompt_for_config(
|
config_type(**existing_config["providers"][api.value])
|
||||||
config_type,
|
if existing_config
|
||||||
(
|
and "providers" in existing_config
|
||||||
config_type(**existing_config["providers"][api.value])
|
and api.value in existing_config["providers"]
|
||||||
if existing_config and api.value in existing_config["providers"]
|
else None
|
||||||
else None
|
),
|
||||||
),
|
)
|
||||||
)
|
provider_configs[api.value] = {
|
||||||
provider_configs[api.value] = {
|
"provider_id": provider_spec.provider_id,
|
||||||
"provider_id": provider_spec.provider_id,
|
**config.dict(),
|
||||||
**config.dict(),
|
}
|
||||||
}
|
|
||||||
|
|
||||||
dist_config = {
|
dist_config = {
|
||||||
"providers": provider_configs,
|
"providers": provider_configs,
|
||||||
"conda_env": conda_env,
|
**existing_config,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
config_path = DISTRIBS_BASE_DIR / existing_config["name"] / "config.yaml"
|
||||||
with open(config_path, "w") as fp:
|
with open(config_path, "w") as fp:
|
||||||
dist_config = json.loads(json.dumps(dist_config, cls=EnumEncoder))
|
dist_config = json.loads(json.dumps(dist_config, cls=EnumEncoder))
|
||||||
fp.write(yaml.dump(dist_config, sort_keys=False))
|
fp.write(yaml.dump(dist_config, sort_keys=False))
|
||||||
|
|
|
@ -34,9 +34,9 @@ class DistributionCreate(Subcommand):
|
||||||
# wants to pick and then ask for their configuration.
|
# wants to pick and then ask for their configuration.
|
||||||
|
|
||||||
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
|
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
|
||||||
from llama_toolchain.distribution.registry import resolve_distribution
|
from llama_toolchain.distribution.registry import resolve_distribution_spec
|
||||||
|
|
||||||
dist = resolve_distribution(args.name)
|
dist = resolve_distribution_spec(args.name)
|
||||||
if dist is not None:
|
if dist is not None:
|
||||||
self.parser.error(f"Distribution with name {args.name} already exists")
|
self.parser.error(f"Distribution with name {args.name} already exists")
|
||||||
return
|
return
|
||||||
|
|
|
@ -10,6 +10,7 @@ import shlex
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
import yaml
|
||||||
|
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
|
@ -32,26 +33,31 @@ class DistributionInstall(Subcommand):
|
||||||
self.parser.set_defaults(func=self._run_distribution_install_cmd)
|
self.parser.set_defaults(func=self._run_distribution_install_cmd)
|
||||||
|
|
||||||
def _add_arguments(self):
|
def _add_arguments(self):
|
||||||
from llama_toolchain.distribution.registry import available_distributions
|
from llama_toolchain.distribution.registry import available_distribution_specs
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--spec",
|
||||||
|
type=str,
|
||||||
|
help="Distribution spec to install (try ollama-inline)",
|
||||||
|
required=True,
|
||||||
|
choices=[d.spec_id for d in available_distribution_specs()],
|
||||||
|
)
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"--name",
|
"--name",
|
||||||
type=str,
|
type=str,
|
||||||
help="Name of the distribution to install -- (try local-ollama)",
|
help="What should the installation be called locally?",
|
||||||
required=True,
|
required=True,
|
||||||
choices=[d.name for d in available_distributions()],
|
|
||||||
)
|
)
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"--conda-env",
|
"--conda-env",
|
||||||
type=str,
|
type=str,
|
||||||
help="Specify the name of the conda environment you would like to create or update",
|
help="conda env in which this distribution will run (default = distribution name)",
|
||||||
required=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
|
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
|
||||||
from llama_toolchain.common.exec import run_with_pty
|
from llama_toolchain.common.exec import run_with_pty
|
||||||
from llama_toolchain.distribution.distribution import distribution_dependencies
|
from llama_toolchain.distribution.distribution import distribution_dependencies
|
||||||
from llama_toolchain.distribution.registry import resolve_distribution
|
from llama_toolchain.distribution.registry import resolve_distribution_spec
|
||||||
|
|
||||||
os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
|
os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
|
||||||
script = pkg_resources.resource_filename(
|
script = pkg_resources.resource_filename(
|
||||||
|
@ -59,25 +65,36 @@ class DistributionInstall(Subcommand):
|
||||||
"distribution/install_distribution.sh",
|
"distribution/install_distribution.sh",
|
||||||
)
|
)
|
||||||
|
|
||||||
dist = resolve_distribution(args.name)
|
dist = resolve_distribution_spec(args.spec)
|
||||||
if dist is None:
|
if dist is None:
|
||||||
self.parser.error(f"Could not find distribution {args.name}")
|
self.parser.error(f"Could not find distribution {args.spec}")
|
||||||
return
|
return
|
||||||
|
|
||||||
os.makedirs(DISTRIBS_BASE_DIR / dist.name, exist_ok=True)
|
distrib_dir = DISTRIBS_BASE_DIR / args.name
|
||||||
|
os.makedirs(distrib_dir, exist_ok=True)
|
||||||
|
|
||||||
deps = distribution_dependencies(dist)
|
deps = distribution_dependencies(dist)
|
||||||
return_code = run_with_pty([script, args.conda_env, " ".join(deps)])
|
if not args.conda_env:
|
||||||
|
print(f"Using {args.name} as the Conda environment for this distribution")
|
||||||
|
|
||||||
|
conda_env = args.conda_env or args.name
|
||||||
|
return_code = run_with_pty([script, conda_env, " ".join(deps)])
|
||||||
|
|
||||||
assert return_code == 0, cprint(
|
assert return_code == 0, cprint(
|
||||||
f"Failed to install distribution {dist.name}", color="red"
|
f"Failed to install distribution {dist.spec_id}", color="red"
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(DISTRIBS_BASE_DIR / dist.name / "conda.env", "w") as f:
|
config_file = distrib_dir / "config.yaml"
|
||||||
f.write(f"{args.conda_env}\n")
|
with open(config_file, "w") as f:
|
||||||
|
c = {
|
||||||
|
"conda_env": conda_env,
|
||||||
|
"spec": dist.spec_id,
|
||||||
|
"name": args.name,
|
||||||
|
}
|
||||||
|
f.write(yaml.dump(c))
|
||||||
|
|
||||||
cprint(
|
cprint(
|
||||||
f"Distribution `{dist.name}` has been installed successfully!",
|
f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
|
||||||
color="green",
|
color="green",
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
|
@ -85,8 +102,7 @@ class DistributionInstall(Subcommand):
|
||||||
f"""
|
f"""
|
||||||
Update your conda environment and configure this distribution by running:
|
Update your conda environment and configure this distribution by running:
|
||||||
|
|
||||||
conda deactivate && conda activate {args.conda_env}
|
conda deactivate && conda activate {conda_env}
|
||||||
llama distribution configure --name {dist.name}
|
llama distribution configure --name {args.name}
|
||||||
"""
|
"""
|
||||||
)
|
))
|
||||||
)
|
|
||||||
|
|
|
@ -28,23 +28,23 @@ class DistributionList(Subcommand):
|
||||||
|
|
||||||
def _run_distribution_list_cmd(self, args: argparse.Namespace) -> None:
|
def _run_distribution_list_cmd(self, args: argparse.Namespace) -> None:
|
||||||
from llama_toolchain.cli.table import print_table
|
from llama_toolchain.cli.table import print_table
|
||||||
from llama_toolchain.distribution.registry import available_distributions
|
from llama_toolchain.distribution.registry import available_distribution_specs
|
||||||
|
|
||||||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||||
headers = [
|
headers = [
|
||||||
"Name",
|
"Spec ID",
|
||||||
"ProviderSpecs",
|
"ProviderSpecs",
|
||||||
"Description",
|
"Description",
|
||||||
]
|
]
|
||||||
|
|
||||||
rows = []
|
rows = []
|
||||||
for dist in available_distributions():
|
for spec in available_distribution_specs():
|
||||||
providers = {k.value: v.provider_id for k, v in dist.provider_specs.items()}
|
providers = {k.value: v.provider_id for k, v in spec.provider_specs.items()}
|
||||||
rows.append(
|
rows.append(
|
||||||
[
|
[
|
||||||
dist.name,
|
spec.spec_id,
|
||||||
json.dumps(providers, indent=2),
|
json.dumps(providers, indent=2),
|
||||||
dist.description,
|
spec.description,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
print_table(
|
print_table(
|
||||||
|
|
|
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import shlex
|
import shlex
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
@ -49,22 +48,23 @@ class DistributionStart(Subcommand):
|
||||||
|
|
||||||
def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
|
def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
|
||||||
from llama_toolchain.common.exec import run_command
|
from llama_toolchain.common.exec import run_command
|
||||||
from llama_toolchain.distribution.registry import resolve_distribution
|
from llama_toolchain.distribution.registry import resolve_distribution_spec
|
||||||
from llama_toolchain.distribution.server import main as distribution_server_init
|
from llama_toolchain.distribution.server import main as distribution_server_init
|
||||||
|
|
||||||
dist = resolve_distribution(args.name)
|
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
|
||||||
if dist is None:
|
if not config_file.exists():
|
||||||
self.parser.error(f"Distribution with name {args.name} not found")
|
self.parser.error(
|
||||||
|
f"Could not find {config_file}. Please run `llama distribution install` first"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
config_yaml = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
|
# we need to find the spec from the name
|
||||||
if not config_yaml.exists():
|
with open(config_file, "r") as f:
|
||||||
raise ValueError(
|
config = yaml.safe_load(f)
|
||||||
f"Configuration {config_yaml} does not exist. Please run `llama distribution install` or `llama distribution configure` first"
|
|
||||||
)
|
|
||||||
|
|
||||||
with open(config_yaml, "r") as fp:
|
dist = resolve_distribution_spec(config["spec"])
|
||||||
config = yaml.safe_load(fp)
|
if dist is None:
|
||||||
|
raise ValueError(f"Could not find any registered spec `{config['spec']}`")
|
||||||
|
|
||||||
conda_env = config["conda_env"]
|
conda_env = config["conda_env"]
|
||||||
|
|
||||||
|
@ -76,8 +76,7 @@ class DistributionStart(Subcommand):
|
||||||
)
|
)
|
||||||
|
|
||||||
distribution_server_init(
|
distribution_server_init(
|
||||||
dist.name,
|
config_file,
|
||||||
config_yaml,
|
|
||||||
args.port,
|
args.port,
|
||||||
disable_ipv6=args.disable_ipv6,
|
disable_ipv6=args.disable_ipv6,
|
||||||
)
|
)
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from strong_typing.schema import json_schema_type
|
from strong_typing.schema import json_schema_type
|
||||||
|
@ -29,6 +29,10 @@ class ApiEndpoint(BaseModel):
|
||||||
class ProviderSpec(BaseModel):
|
class ProviderSpec(BaseModel):
|
||||||
api: Api
|
api: Api
|
||||||
provider_id: str
|
provider_id: str
|
||||||
|
config_class: str = Field(
|
||||||
|
...,
|
||||||
|
description="Fully-qualified classname of the config for this provider",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -45,23 +49,21 @@ Fully-qualified name of the module to import. The module is expected to have:
|
||||||
- `get_provider_impl(config, deps)`: returns the local implementation
|
- `get_provider_impl(config, deps)`: returns the local implementation
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
config_class: str = Field(
|
|
||||||
...,
|
|
||||||
description="Fully-qualified classname of the config for this provider",
|
|
||||||
)
|
|
||||||
api_dependencies: List[Api] = Field(
|
api_dependencies: List[Api] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteProviderConfig(BaseModel):
|
||||||
|
base_url: str = Field(..., description="The base URL for the llama stack provider")
|
||||||
|
api_key: Optional[str] = Field(
|
||||||
|
..., description="API key, if needed, for the provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RemoteProviderSpec(ProviderSpec):
|
class RemoteProviderSpec(ProviderSpec):
|
||||||
base_url: str = Field(..., description="The base URL for the llama stack provider")
|
|
||||||
headers: Dict[str, str] = Field(
|
|
||||||
default_factory=dict,
|
|
||||||
description="Headers (e.g., authorization) to send with the request",
|
|
||||||
)
|
|
||||||
module: str = Field(
|
module: str = Field(
|
||||||
...,
|
...,
|
||||||
description="""
|
description="""
|
||||||
|
@ -69,10 +71,12 @@ Fully-qualified name of the module to import. The module is expected to have:
|
||||||
- `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
|
- `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
|
||||||
|
|
||||||
|
|
||||||
class Distribution(BaseModel):
|
@json_schema_type
|
||||||
name: str
|
class DistributionSpec(BaseModel):
|
||||||
|
spec_id: str
|
||||||
description: str
|
description: str
|
||||||
|
|
||||||
provider_specs: Dict[Api, ProviderSpec] = Field(
|
provider_specs: Dict[Api, ProviderSpec] = Field(
|
||||||
|
@ -84,3 +88,12 @@ class Distribution(BaseModel):
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="Additional pip packages beyond those required by the providers",
|
description="Additional pip packages beyond those required by the providers",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class InstalledDistribution(BaseModel):
|
||||||
|
"""References to a installed / configured DistributionSpec"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
spec_id: str
|
||||||
|
# This is the class which represents the configs written by `configure`
|
||||||
|
|
|
@ -8,13 +8,22 @@ import inspect
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
|
from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
|
||||||
|
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
|
||||||
from llama_toolchain.inference.api.endpoints import Inference
|
from llama_toolchain.inference.api.endpoints import Inference
|
||||||
|
from llama_toolchain.inference.providers import available_inference_providers
|
||||||
from llama_toolchain.safety.api.endpoints import Safety
|
from llama_toolchain.safety.api.endpoints import Safety
|
||||||
|
from llama_toolchain.safety.providers import available_safety_providers
|
||||||
|
|
||||||
from .datatypes import Api, ApiEndpoint, Distribution, InlineProviderSpec
|
from .datatypes import (
|
||||||
|
Api,
|
||||||
|
ApiEndpoint,
|
||||||
|
DistributionSpec,
|
||||||
|
InlineProviderSpec,
|
||||||
|
ProviderSpec,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def distribution_dependencies(distribution: Distribution) -> List[str]:
|
def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
|
||||||
# only consider InlineProviderSpecs when calculating dependencies
|
# only consider InlineProviderSpecs when calculating dependencies
|
||||||
return [
|
return [
|
||||||
dep
|
dep
|
||||||
|
@ -51,3 +60,19 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||||
apis[api] = endpoints
|
apis[api] = endpoints
|
||||||
|
|
||||||
return apis
|
return apis
|
||||||
|
|
||||||
|
|
||||||
|
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||||
|
inference_providers_by_id = {
|
||||||
|
a.provider_id: a for a in available_inference_providers()
|
||||||
|
}
|
||||||
|
safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
|
||||||
|
agentic_system_providers_by_id = {
|
||||||
|
a.provider_id: a for a in available_agentic_system_providers()
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
Api.inference: inference_providers_by_id,
|
||||||
|
Api.safety: safety_providers_by_id,
|
||||||
|
Api.agentic_system: agentic_system_providers_by_id,
|
||||||
|
}
|
||||||
|
|
|
@ -7,12 +7,8 @@
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
|
from .datatypes import Api, DistributionSpec, RemoteProviderSpec
|
||||||
|
from .distribution import api_providers
|
||||||
from llama_toolchain.inference.providers import available_inference_providers
|
|
||||||
from llama_toolchain.safety.providers import available_safety_providers
|
|
||||||
|
|
||||||
from .datatypes import Api, Distribution, RemoteProviderSpec
|
|
||||||
|
|
||||||
# This is currently duplicated from `requirements.txt` with a few minor changes
|
# This is currently duplicated from `requirements.txt` with a few minor changes
|
||||||
# dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
|
# dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
|
||||||
|
@ -49,39 +45,30 @@ def client_module(api: Api) -> str:
|
||||||
return f"llama_toolchain.{api.value}.client"
|
return f"llama_toolchain.{api.value}.client"
|
||||||
|
|
||||||
|
|
||||||
def remote(api: Api, port: int) -> RemoteProviderSpec:
|
def remote_spec(api: Api) -> RemoteProviderSpec:
|
||||||
return RemoteProviderSpec(
|
return RemoteProviderSpec(
|
||||||
api=api,
|
api=api,
|
||||||
provider_id=f"{api.value}-remote",
|
provider_id=f"{api.value}-remote",
|
||||||
base_url=f"http://localhost:{port}",
|
|
||||||
module=client_module(api),
|
module=client_module(api),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def available_distributions() -> List[Distribution]:
|
def available_distribution_specs() -> List[DistributionSpec]:
|
||||||
inference_providers_by_id = {
|
providers = api_providers()
|
||||||
a.provider_id: a for a in available_inference_providers()
|
|
||||||
}
|
|
||||||
safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
|
|
||||||
agentic_system_providers_by_id = {
|
|
||||||
a.provider_id: a for a in available_agentic_system_providers()
|
|
||||||
}
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
Distribution(
|
DistributionSpec(
|
||||||
name="local-inline",
|
spec_id="inline",
|
||||||
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
|
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
|
||||||
additional_pip_packages=COMMON_DEPENDENCIES,
|
additional_pip_packages=COMMON_DEPENDENCIES,
|
||||||
provider_specs={
|
provider_specs={
|
||||||
Api.inference: inference_providers_by_id["meta-reference"],
|
Api.inference: providers[Api.inference]["meta-reference"],
|
||||||
Api.safety: safety_providers_by_id["meta-reference"],
|
Api.safety: providers[Api.safety]["meta-reference"],
|
||||||
Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
|
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
# NOTE: this hardcodes the ports to which things point to
|
DistributionSpec(
|
||||||
Distribution(
|
spec_id="remote",
|
||||||
name="full-remote",
|
|
||||||
description="Point to remote services for all llama stack APIs",
|
description="Point to remote services for all llama stack APIs",
|
||||||
additional_pip_packages=[
|
additional_pip_packages=[
|
||||||
"python-dotenv",
|
"python-dotenv",
|
||||||
|
@ -97,28 +84,24 @@ def available_distributions() -> List[Distribution]:
|
||||||
"pydantic_core==2.18.2",
|
"pydantic_core==2.18.2",
|
||||||
"uvicorn",
|
"uvicorn",
|
||||||
],
|
],
|
||||||
provider_specs={
|
provider_specs={x: remote_spec(x) for x in providers},
|
||||||
Api.inference: remote(Api.inference, 5001),
|
|
||||||
Api.safety: remote(Api.safety, 5001),
|
|
||||||
Api.agentic_system: remote(Api.agentic_system, 5001),
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
Distribution(
|
DistributionSpec(
|
||||||
name="local-ollama",
|
spec_id="ollama-inline",
|
||||||
description="Like local-source, but use ollama for running LLM inference",
|
description="Like local-source, but use ollama for running LLM inference",
|
||||||
additional_pip_packages=COMMON_DEPENDENCIES,
|
additional_pip_packages=COMMON_DEPENDENCIES,
|
||||||
provider_specs={
|
provider_specs={
|
||||||
Api.inference: inference_providers_by_id["meta-ollama"],
|
Api.inference: providers[Api.inference]["meta-ollama"],
|
||||||
Api.safety: safety_providers_by_id["meta-reference"],
|
Api.safety: providers[Api.safety]["meta-reference"],
|
||||||
Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
|
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def resolve_distribution(name: str) -> Optional[Distribution]:
|
def resolve_distribution_spec(spec_id: str) -> Optional[DistributionSpec]:
|
||||||
for dist in available_distributions():
|
for spec in available_distribution_specs():
|
||||||
if dist.name == name:
|
if spec.spec_id == spec_id:
|
||||||
return dist
|
return spec
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -36,11 +36,11 @@ from fastapi.routing import APIRoute
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from .datatypes import Api, Distribution, ProviderSpec, RemoteProviderSpec
|
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
|
||||||
from .distribution import api_endpoints
|
from .distribution import api_endpoints
|
||||||
from .dynamic import instantiate_client, instantiate_provider
|
from .dynamic import instantiate_client, instantiate_provider
|
||||||
|
|
||||||
from .registry import resolve_distribution
|
from .registry import resolve_distribution_spec
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
@ -250,7 +250,7 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
||||||
return [by_id[x] for x in stack]
|
return [by_id[x] for x in stack]
|
||||||
|
|
||||||
|
|
||||||
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
|
def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
|
||||||
provider_configs = config["providers"]
|
provider_configs = config["providers"]
|
||||||
provider_specs = topological_sort(dist.provider_specs.values())
|
provider_specs = topological_sort(dist.provider_specs.values())
|
||||||
|
|
||||||
|
@ -265,7 +265,7 @@ def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
|
||||||
provider_config = provider_configs[api.value]
|
provider_config = provider_configs[api.value]
|
||||||
if isinstance(provider_spec, RemoteProviderSpec):
|
if isinstance(provider_spec, RemoteProviderSpec):
|
||||||
impls[api] = instantiate_client(
|
impls[api] = instantiate_client(
|
||||||
provider_spec, provider_spec.base_url.rstrip("/")
|
provider_spec, provider_config.base_url.rstrip("/")
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
||||||
|
@ -275,16 +275,15 @@ def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
|
||||||
return impls
|
return impls
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||||
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
|
|
||||||
):
|
|
||||||
dist = resolve_distribution(dist_name)
|
|
||||||
if dist is None:
|
|
||||||
raise ValueError(f"Could not find distribution {dist_name}")
|
|
||||||
|
|
||||||
with open(yaml_config, "r") as fp:
|
with open(yaml_config, "r") as fp:
|
||||||
config = yaml.safe_load(fp)
|
config = yaml.safe_load(fp)
|
||||||
|
|
||||||
|
spec = config["spec"]
|
||||||
|
dist = resolve_distribution_spec(spec)
|
||||||
|
if dist is None:
|
||||||
|
raise ValueError(f"Could not find distribution specification `{spec}`")
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
all_endpoints = api_endpoints()
|
all_endpoints = api_endpoints()
|
||||||
|
@ -293,14 +292,15 @@ def main(
|
||||||
for provider_spec in dist.provider_specs.values():
|
for provider_spec in dist.provider_specs.values():
|
||||||
api = provider_spec.api
|
api = provider_spec.api
|
||||||
endpoints = all_endpoints[api]
|
endpoints = all_endpoints[api]
|
||||||
|
impl = impls[api]
|
||||||
|
|
||||||
if isinstance(provider_spec, RemoteProviderSpec):
|
if isinstance(provider_spec, RemoteProviderSpec):
|
||||||
for endpoint in endpoints:
|
for endpoint in endpoints:
|
||||||
url = provider_spec.base_url.rstrip("/") + endpoint.route
|
url = impl.base_url + endpoint.route
|
||||||
getattr(app, endpoint.method)(endpoint.route)(
|
getattr(app, endpoint.method)(endpoint.route)(
|
||||||
create_dynamic_passthrough(url)
|
create_dynamic_passthrough(url)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
impl = impls[api]
|
|
||||||
for endpoint in endpoints:
|
for endpoint in endpoints:
|
||||||
if not hasattr(impl, endpoint.name):
|
if not hasattr(impl, endpoint.name):
|
||||||
# ideally this should be a typing violation already
|
# ideally this should be a typing violation already
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue