diff --git a/llama_toolchain/agentic_system/providers.py b/llama_toolchain/agentic_system/providers.py
index e2d1e6424..79e66d15e 100644
--- a/llama_toolchain/agentic_system/providers.py
+++ b/llama_toolchain/agentic_system/providers.py
@@ -13,7 +13,7 @@ def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.agentic_system,
-            provider_type="meta-reference",
+            provider_id="meta-reference",
             pip_packages=[
                 "codeshield",
                 "matplotlib",
diff --git a/llama_toolchain/cli/stack/configure.py b/llama_toolchain/cli/stack/configure.py
index 5fa105048..952f9b50e 100644
--- a/llama_toolchain/cli/stack/configure.py
+++ b/llama_toolchain/cli/stack/configure.py
@@ -11,16 +11,15 @@ from pathlib import Path
 import pkg_resources
 import yaml
 
+from termcolor import cprint
+
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
 from llama_toolchain.common.exec import run_with_pty
-from termcolor import cprint
 
 from llama_toolchain.core.datatypes import *  # noqa: F403
 import os
 
-from termcolor import cprint
-
 
 class StackConfigure(Subcommand):
     """Llama cli for configuring llama toolchain configs"""
@@ -109,7 +108,7 @@ class StackConfigure(Subcommand):
         api2providers = build_config.distribution_spec.providers
 
         stub_config = {
-            api_str: {"provider_type": provider}
+            api_str: {"provider_id": provider}
             for api_str, provider in api2providers.items()
         }
 
diff --git a/llama_toolchain/cli/stack/list_providers.py b/llama_toolchain/cli/stack/list_providers.py
index fdf4ab054..a5640677d 100644
--- a/llama_toolchain/cli/stack/list_providers.py
+++ b/llama_toolchain/cli/stack/list_providers.py
@@ -49,7 +49,7 @@ class StackListProviders(Subcommand):
         for spec in providers_for_api.values():
             rows.append(
                 [
-                    spec.provider_type,
+                    spec.provider_id,
                     ",".join(spec.pip_packages),
                 ]
             )
diff --git a/llama_toolchain/core/configure.py b/llama_toolchain/core/configure.py
index 252358a52..7f9aa0140 100644
--- a/llama_toolchain/core/configure.py
+++ b/llama_toolchain/core/configure.py
@@ -21,14 +21,14 @@ def configure_api_providers(existing_configs: Dict[str, Any]) -> None:
     for api_str, stub_config in existing_configs.items():
         api = Api(api_str)
         providers = all_providers[api]
-        provider_type = stub_config["provider_type"]
-        if provider_type not in providers:
+        provider_id = stub_config["provider_id"]
+        if provider_id not in providers:
             raise ValueError(
-                f"Unknown provider `{provider_type}` is not available for API `{api_str}`"
+                f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
             )
-        provider_spec = providers[provider_type]
-        cprint(f"Configuring API: {api_str} ({provider_type})", "white", attrs=["bold"])
+        provider_spec = providers[provider_id]
+        cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"])
         config_type = instantiate_class_type(provider_spec.config_class)
 
         try:
@@ -43,7 +43,7 @@ def configure_api_providers(existing_configs: Dict[str, Any]) -> None:
         print("")
 
         provider_configs[api_str] = {
-            "provider_type": provider_type,
+            "provider_id": provider_id,
             **provider_config.dict(),
         }
 
diff --git a/llama_toolchain/core/datatypes.py b/llama_toolchain/core/datatypes.py
index 4549e1819..1366eeb0d 100644
--- a/llama_toolchain/core/datatypes.py
+++ b/llama_toolchain/core/datatypes.py
@@ -32,7 +32,7 @@ class ApiEndpoint(BaseModel):
 @json_schema_type
 class ProviderSpec(BaseModel):
     api: Api
-    provider_type: str
+    provider_id: str
     config_class: str = Field(
         ...,
         description="Fully-qualified classname of the config for this provider",
@@ -101,7 +101,7 @@ class RemoteProviderConfig(BaseModel):
         return url.rstrip("/")
 
 
-def remote_provider_type(adapter_id: str) -> str:
+def remote_provider_id(adapter_id: str) -> str:
     return f"remote::{adapter_id}"
 
 
@@ -142,10 +142,10 @@ def remote_provider_spec(
         if adapter and adapter.config_class
         else "llama_toolchain.core.datatypes.RemoteProviderConfig"
     )
-    provider_type = remote_provider_type(adapter.adapter_id) if adapter else "remote"
+    provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
 
     return RemoteProviderSpec(
-        api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
+        api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
     )
 
 
diff --git a/llama_toolchain/core/distribution.py b/llama_toolchain/core/distribution.py
index 0a968c422..dc81b53f1 100644
--- a/llama_toolchain/core/distribution.py
+++ b/llama_toolchain/core/distribution.py
@@ -14,14 +14,7 @@ from llama_toolchain.memory.api import Memory
 from llama_toolchain.safety.api import Safety
 from llama_toolchain.telemetry.api import Telemetry
 
-from .datatypes import (
-    Api,
-    ApiEndpoint,
-    DistributionSpec,
-    InlineProviderSpec,
-    ProviderSpec,
-    remote_provider_spec,
-)
+from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
 
 # These are the dependencies needed by the distribution server.
 # `llama-toolchain` is automatically installed by the installation script.
@@ -77,7 +70,7 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
         module = importlib.import_module(f"llama_toolchain.{name}.providers")
         ret[api] = {
             "remote": remote_provider_spec(api),
-            **{a.provider_type: a for a in module.available_providers()},
+            **{a.provider_id: a for a in module.available_providers()},
         }
 
     return ret
diff --git a/llama_toolchain/core/server.py b/llama_toolchain/core/server.py
index b0ec75fe5..7082ec765 100644
--- a/llama_toolchain/core/server.py
+++ b/llama_toolchain/core/server.py
@@ -309,13 +309,13 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
     for api_str, provider_config in config["providers"].items():
         api = Api(api_str)
         providers = all_providers[api]
-        provider_type = provider_config["provider_type"]
-        if provider_type not in providers:
+        provider_id = provider_config["provider_id"]
+        if provider_id not in providers:
             raise ValueError(
-                f"Unknown provider `{provider_type}` is not available for API `{api}`"
+                f"Unknown provider `{provider_id}` is not available for API `{api}`"
             )
-        provider_specs[api] = providers[provider_type]
+        provider_specs[api] = providers[provider_id]
 
     impls = resolve_impls(provider_specs, config)
     if Api.telemetry in impls:
diff --git a/llama_toolchain/inference/providers.py b/llama_toolchain/inference/providers.py
index 31c6b8d7b..928c6ef57 100644
--- a/llama_toolchain/inference/providers.py
+++ b/llama_toolchain/inference/providers.py
@@ -13,7 +13,7 @@ def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.inference,
-            provider_type="meta-reference",
+            provider_id="meta-reference",
             pip_packages=[
                 "accelerate",
                 "blobfile",
diff --git a/llama_toolchain/memory/providers.py b/llama_toolchain/memory/providers.py
index 40a11235b..d3336278a 100644
--- a/llama_toolchain/memory/providers.py
+++ b/llama_toolchain/memory/providers.py
@@ -20,7 +20,7 @@ def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.memory,
-            provider_type="meta-reference-faiss",
+            provider_id="meta-reference-faiss",
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
             module="llama_toolchain.memory.meta_reference.faiss",
             config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
diff --git a/llama_toolchain/safety/providers.py b/llama_toolchain/safety/providers.py
index 0db454ef3..c523e628e 100644
--- a/llama_toolchain/safety/providers.py
+++ b/llama_toolchain/safety/providers.py
@@ -13,7 +13,7 @@ def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.safety,
-            provider_type="meta-reference",
+            provider_id="meta-reference",
             pip_packages=[
                 "accelerate",
                 "codeshield",
diff --git a/llama_toolchain/telemetry/providers.py b/llama_toolchain/telemetry/providers.py
index 7b04145b3..00038e569 100644
--- a/llama_toolchain/telemetry/providers.py
+++ b/llama_toolchain/telemetry/providers.py
@@ -13,7 +13,7 @@ def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.telemetry,
-            provider_type="console",
+            provider_id="console",
             pip_packages=[],
             module="llama_toolchain.telemetry.console",
             config_class="llama_toolchain.telemetry.console.ConsoleConfig",