From 65a9e40174aaa7e572d180ce187ea816f6c3ab8a Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 5 Aug 2024 13:26:29 -0700 Subject: [PATCH] Adapter -> Provider --- .../agentic_system/agentic_system.py | 4 +- .../{adapters.py => providers.py} | 10 ++-- llama_toolchain/cli/distribution/configure.py | 22 ++++---- llama_toolchain/cli/distribution/create.py | 2 +- llama_toolchain/cli/distribution/list.py | 6 +-- llama_toolchain/distribution/datatypes.py | 22 ++++---- llama_toolchain/distribution/distribution.py | 10 ++-- llama_toolchain/distribution/dynamic.py | 20 +++---- llama_toolchain/distribution/registry.py | 53 ++++++++++--------- llama_toolchain/distribution/server.py | 48 +++++++++-------- llama_toolchain/inference/inference.py | 6 ++- llama_toolchain/inference/ollama.py | 2 +- .../inference/{adapters.py => providers.py} | 12 ++--- .../safety/{adapters.py => providers.py} | 8 +-- llama_toolchain/safety/safety.py | 4 +- 15 files changed, 119 insertions(+), 110 deletions(-) rename llama_toolchain/agentic_system/{adapters.py => providers.py} (71%) rename llama_toolchain/inference/{adapters.py => providers.py} (72%) rename llama_toolchain/safety/{adapters.py => providers.py} (71%) diff --git a/llama_toolchain/agentic_system/agentic_system.py b/llama_toolchain/agentic_system/agentic_system.py index 8bf74e44f..81a4a3337 100644 --- a/llama_toolchain/agentic_system/agentic_system.py +++ b/llama_toolchain/agentic_system/agentic_system.py @@ -7,7 +7,7 @@ from llama_toolchain.agentic_system.api import AgenticSystem -from llama_toolchain.distribution.datatypes import Adapter, Api +from llama_toolchain.distribution.datatypes import Api, ProviderSpec from llama_toolchain.inference.api import Inference from llama_toolchain.safety.api import Safety @@ -44,7 +44,7 @@ logger = logging.getLogger() logger.setLevel(logging.INFO) -async def get_adapter_impl(config: AgenticSystemConfig, deps: Dict[Api, Adapter]): +async def get_provider_impl(config: AgenticSystemConfig, deps: Dict[Api, ProviderSpec]): assert isinstance( config, AgenticSystemConfig ), f"Unexpected config type: {type(config)}" diff --git a/llama_toolchain/agentic_system/adapters.py b/llama_toolchain/agentic_system/providers.py similarity index 71% rename from llama_toolchain/agentic_system/adapters.py rename to llama_toolchain/agentic_system/providers.py index 82d1a0ccb..a1521fa46 100644 --- a/llama_toolchain/agentic_system/adapters.py +++ b/llama_toolchain/agentic_system/providers.py @@ -6,14 +6,14 @@ from typing import List -from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter +from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec -def available_agentic_system_adapters() -> List[Adapter]: +def available_agentic_system_providers() -> List[ProviderSpec]: return [ - SourceAdapter( + InlineProviderSpec( api=Api.agentic_system, - adapter_id="meta-reference", + provider_id="meta-reference", pip_packages=[ "codeshield", "torch", @@ -21,7 +21,7 @@ def available_agentic_system_adapters() -> List[Adapter]: ], module="llama_toolchain.agentic_system.agentic_system", config_class="llama_toolchain.agentic_system.config.AgenticSystemConfig", - adapter_dependencies=[ + api_dependencies=[ Api.inference, Api.safety, ], diff --git a/llama_toolchain/cli/distribution/configure.py b/llama_toolchain/cli/distribution/configure.py index 2d0341754..c495eb9f9 100644 --- a/llama_toolchain/cli/distribution/configure.py +++ b/llama_toolchain/cli/distribution/configure.py @@ -63,7 +63,7 @@ def 
configure_llama_distribution(dist: "Distribution", conda_env: str): from llama_toolchain.common.exec import run_command from llama_toolchain.common.prompt_for_config import prompt_for_config from llama_toolchain.common.serialize import EnumEncoder - from llama_toolchain.distribution.datatypes import PassthroughApiAdapter + from llama_toolchain.distribution.datatypes import RemoteProviderSpec from llama_toolchain.distribution.dynamic import instantiate_class_type python_exe = run_command(shlex.split("which python")) @@ -84,28 +84,28 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str): with open(config_path, "r") as fp: existing_config = yaml.safe_load(fp) - adapter_configs = {} - for api, adapter in dist.adapters.items(): - if isinstance(adapter, PassthroughApiAdapter): - adapter_configs[api.value] = adapter.dict() + provider_configs = {} + for api, provider_spec in dist.provider_specs.items(): + if isinstance(provider_spec, RemoteProviderSpec): + provider_configs[api.value] = provider_spec.dict() else: cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"]) - config_type = instantiate_class_type(adapter.config_class) + config_type = instantiate_class_type(provider_spec.config_class) config = prompt_for_config( config_type, ( - config_type(**existing_config["adapters"][api.value]) - if existing_config and api.value in existing_config["adapters"] + config_type(**existing_config["providers"][api.value]) + if existing_config and api.value in existing_config["providers"] else None ), ) - adapter_configs[api.value] = { - "adapter_id": adapter.adapter_id, + provider_configs[api.value] = { + "provider_id": provider_spec.provider_id, **config.dict(), } dist_config = { - "adapters": adapter_configs, + "providers": provider_configs, "conda_env": conda_env, } diff --git a/llama_toolchain/cli/distribution/create.py b/llama_toolchain/cli/distribution/create.py index e5a835c91..bb0d9b42e 100644 --- a/llama_toolchain/cli/distribution/create.py +++ b/llama_toolchain/cli/distribution/create.py @@ -30,7 +30,7 @@ class DistributionCreate(Subcommand): required=True, ) # for each Api the user wants to support, we should - # get the list of available adapters, ask which one the user + # get the list of available providers, ask which one the user # wants to pick and then ask for their configuration. 
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None: diff --git a/llama_toolchain/cli/distribution/list.py b/llama_toolchain/cli/distribution/list.py index 9f009f4c2..3d6b69186 100644 --- a/llama_toolchain/cli/distribution/list.py +++ b/llama_toolchain/cli/distribution/list.py @@ -33,17 +33,17 @@ class DistributionList(Subcommand): # eventually, this should query a registry at llama.meta.com/llamastack/distributions headers = [ "Name", - "Adapters", + "ProviderSpecs", "Description", ] rows = [] for dist in available_distributions(): - adapters = {k.value: v.adapter_id for k, v in dist.adapters.items()} + providers = {k.value: v.provider_id for k, v in dist.provider_specs.items()} rows.append( [ dist.name, - json.dumps(adapters, indent=2), + json.dumps(providers, indent=2), dist.description, ] ) diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py index 7dd197a80..85dcdae81 100644 --- a/llama_toolchain/distribution/datatypes.py +++ b/llama_toolchain/distribution/datatypes.py @@ -26,13 +26,13 @@ class ApiEndpoint(BaseModel): @json_schema_type -class Adapter(BaseModel): +class ProviderSpec(BaseModel): api: Api - adapter_id: str + provider_id: str @json_schema_type -class SourceAdapter(Adapter): +class InlineProviderSpec(ProviderSpec): pip_packages: List[str] = Field( default_factory=list, description="The pip dependencies needed for this implementation", @@ -42,21 +42,21 @@ class SourceAdapter(Adapter): description=""" Fully-qualified name of the module to import. The module is expected to have: - - `get_adapter_impl(config, deps)`: returns the local implementation + - `get_provider_impl(config, deps)`: returns the local implementation """, ) config_class: str = Field( ..., - description="Fully-qualified classname of the config for this adapter", + description="Fully-qualified classname of the config for this provider", ) - adapter_dependencies: List[Api] = Field( + api_dependencies: List[Api] = Field( default_factory=list, - description="Higher-level API surfaces may depend on other adapters to provide their functionality", + description="Higher-level API surfaces may depend on other providers to provide their functionality", ) @json_schema_type -class PassthroughApiAdapter(Adapter): +class RemoteProviderSpec(ProviderSpec): base_url: str = Field(..., description="The base URL for the llama stack provider") headers: Dict[str, str] = Field( default_factory=dict, @@ -75,12 +75,12 @@ class Distribution(BaseModel): name: str description: str - adapters: Dict[Api, Adapter] = Field( + provider_specs: Dict[Api, ProviderSpec] = Field( default_factory=dict, - description="The API surfaces provided by this distribution", + description="Provider specifications for each of the APIs provided by this distribution", ) additional_pip_packages: List[str] = Field( default_factory=list, - description="Additional pip packages beyond those required by the adapters", + description="Additional pip packages beyond those required by the providers", ) diff --git a/llama_toolchain/distribution/distribution.py b/llama_toolchain/distribution/distribution.py index 27a7e4a5d..294f9bd4e 100644 --- a/llama_toolchain/distribution/distribution.py +++ b/llama_toolchain/distribution/distribution.py @@ -11,16 +11,16 @@ from llama_toolchain.agentic_system.api.endpoints import AgenticSystem from llama_toolchain.inference.api.endpoints import Inference from llama_toolchain.safety.api.endpoints import Safety -from .datatypes import Api, ApiEndpoint, Distribution, 
SourceAdapter +from .datatypes import Api, ApiEndpoint, Distribution, InlineProviderSpec def distribution_dependencies(distribution: Distribution) -> List[str]: - # only consider SourceAdapters when calculating dependencies + # only consider InlineProviderSpecs when calculating dependencies return [ dep - for adapter in distribution.adapters.values() - if isinstance(adapter, SourceAdapter) - for dep in adapter.pip_packages + for provider_spec in distribution.provider_specs.values() + if isinstance(provider_spec, InlineProviderSpec) + for dep in provider_spec.pip_packages ] + distribution.additional_pip_packages diff --git a/llama_toolchain/distribution/dynamic.py b/llama_toolchain/distribution/dynamic.py index ae7075940..20fa038bf 100644 --- a/llama_toolchain/distribution/dynamic.py +++ b/llama_toolchain/distribution/dynamic.py @@ -8,7 +8,7 @@ import asyncio import importlib from typing import Any, Dict -from .datatypes import Adapter, PassthroughApiAdapter, SourceAdapter +from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderSpec def instantiate_class_type(fully_qualified_name): @@ -18,17 +18,19 @@ def instantiate_class_type(fully_qualified_name): # returns a class implementing the protocol corresponding to the Api -def instantiate_adapter( - adapter: SourceAdapter, adapter_config: Dict[str, Any], deps: Dict[str, Adapter] +def instantiate_provider( + provider_spec: InlineProviderSpec, + provider_config: Dict[str, Any], + deps: Dict[str, ProviderSpec], ): - module = importlib.import_module(adapter.module) + module = importlib.import_module(provider_spec.module) - config_type = instantiate_class_type(adapter.config_class) - config = config_type(**adapter_config) - return asyncio.run(module.get_adapter_impl(config, deps)) + config_type = instantiate_class_type(provider_spec.config_class) + config = config_type(**provider_config) + return asyncio.run(module.get_provider_impl(config, deps)) -def instantiate_client(adapter: PassthroughApiAdapter, base_url: str): - module = importlib.import_module(adapter.module) +def instantiate_client(provider_spec: RemoteProviderSpec, base_url: str): + module = importlib.import_module(provider_spec.module) return asyncio.run(module.get_client_impl(base_url)) diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py index 17fa4bc93..ea046f5a1 100644 --- a/llama_toolchain/distribution/registry.py +++ b/llama_toolchain/distribution/registry.py @@ -7,12 +7,12 @@ from functools import lru_cache from typing import List, Optional -from llama_toolchain.agentic_system.adapters import available_agentic_system_adapters +from llama_toolchain.agentic_system.providers import available_agentic_system_providers -from llama_toolchain.inference.adapters import available_inference_adapters -from llama_toolchain.safety.adapters import available_safety_adapters +from llama_toolchain.inference.providers import available_inference_providers +from llama_toolchain.safety.providers import available_safety_providers -from .datatypes import Api, Distribution, PassthroughApiAdapter +from .datatypes import Api, Distribution, RemoteProviderSpec # This is currently duplicated from `requirements.txt` with a few minor changes # dev-dependencies like "ufmt" etc. are nuked. 
A few specialized dependencies @@ -49,10 +49,10 @@ def client_module(api: Api) -> str: return f"llama_toolchain.{api.value}.client" -def passthrough(api: Api, port: int) -> PassthroughApiAdapter: - return PassthroughApiAdapter( +def remote(api: Api, port: int) -> RemoteProviderSpec: + return RemoteProviderSpec( api=api, - adapter_id=f"{api.value}-passthrough", + provider_id=f"{api.value}-remote", base_url=f"http://localhost:{port}", module=client_module(api), ) @@ -60,25 +60,28 @@ def passthrough(api: Api, port: int) -> PassthroughApiAdapter: @lru_cache() def available_distributions() -> List[Distribution]: - inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()} - safety_adapters_by_id = {a.adapter_id: a for a in available_safety_adapters()} - agentic_system_adapters_by_id = { - a.adapter_id: a for a in available_agentic_system_adapters() + inference_providers_by_id = { + a.provider_id: a for a in available_inference_providers() + } + safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()} + agentic_system_providers_by_id = { + a.provider_id: a for a in available_agentic_system_providers() } return [ Distribution( - name="local-source", + name="local-inline", description="Use code from `llama_toolchain` itself to serve all llama stack APIs", additional_pip_packages=COMMON_DEPENDENCIES, - adapters={ - Api.inference: inference_adapters_by_id["meta-reference"], - Api.safety: safety_adapters_by_id["meta-reference"], - Api.agentic_system: agentic_system_adapters_by_id["meta-reference"], + provider_specs={ + Api.inference: inference_providers_by_id["meta-reference"], + Api.safety: safety_providers_by_id["meta-reference"], + Api.agentic_system: agentic_system_providers_by_id["meta-reference"], }, ), + # NOTE: this hardcodes the ports to which things point to Distribution( - name="full-passthrough", + name="full-remote", description="Point to remote services for all llama stack APIs", additional_pip_packages=[ "python-dotenv", @@ -94,20 +97,20 @@ def available_distributions() -> List[Distribution]: "pydantic_core==2.18.2", "uvicorn", ], - adapters={ - Api.inference: passthrough(Api.inference, 5001), - Api.safety: passthrough(Api.safety, 5001), - Api.agentic_system: passthrough(Api.agentic_system, 5001), + provider_specs={ + Api.inference: remote(Api.inference, 5001), + Api.safety: remote(Api.safety, 5001), + Api.agentic_system: remote(Api.agentic_system, 5001), }, ), Distribution( name="local-ollama", description="Like local-source, but use ollama for running LLM inference", additional_pip_packages=COMMON_DEPENDENCIES, - adapters={ - Api.inference: inference_adapters_by_id["meta-ollama"], - Api.safety: safety_adapters_by_id["meta-reference"], - Api.agentic_system: agentic_system_adapters_by_id["meta-reference"], + provider_specs={ + Api.inference: inference_providers_by_id["meta-ollama"], + Api.safety: safety_providers_by_id["meta-reference"], + Api.agentic_system: agentic_system_providers_by_id["meta-reference"], }, ), ] diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py index 857639551..4c61d1e40 100644 --- a/llama_toolchain/distribution/server.py +++ b/llama_toolchain/distribution/server.py @@ -36,9 +36,9 @@ from fastapi.routing import APIRoute from pydantic import BaseModel, ValidationError from termcolor import cprint -from .datatypes import Adapter, Api, Distribution, PassthroughApiAdapter +from .datatypes import Api, Distribution, ProviderSpec, RemoteProviderSpec from .distribution import 
api_endpoints -from .dynamic import instantiate_adapter, instantiate_client +from .dynamic import instantiate_client, instantiate_provider from .registry import resolve_distribution @@ -226,15 +226,15 @@ def create_dynamic_typed_route(func: Any): return endpoint -def topological_sort(adapters: List[Adapter]) -> List[Adapter]: +def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]: - by_id = {x.api: x for x in adapters} + by_id = {x.api: x for x in providers} - def dfs(a: Adapter, visited: Set[Api], stack: List[Api]): + def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]): visited.add(a.api) - if not isinstance(a, PassthroughApiAdapter): - for api in a.adapter_dependencies: + if not isinstance(a, RemoteProviderSpec): + for api in a.api_dependencies: if api not in visited: dfs(by_id[api], visited, stack) @@ -243,7 +243,7 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]: visited = set() stack = [] - for a in adapters: + for a in providers: if a.api not in visited: dfs(a, visited, stack) @@ -251,23 +251,25 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]: def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]: - adapter_configs = config["adapters"] - adapters = topological_sort(dist.adapters.values()) + provider_configs = config["providers"] + provider_specs = topological_sort(dist.provider_specs.values()) impls = {} - for adapter in adapters: - api = adapter.api - if api.value not in adapter_configs: + for provider_spec in provider_specs: + api = provider_spec.api + if api.value not in provider_configs: raise ValueError( - f"Could not find adapter config for {api}. Please add it to the config" + f"Could not find provider_spec config for {api}. Please add it to the config" ) - adapter_config = adapter_configs[api.value] - if isinstance(adapter, PassthroughApiAdapter): - impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/")) + provider_config = provider_configs[api.value] + if isinstance(provider_spec, RemoteProviderSpec): + impls[api] = instantiate_client( + provider_spec, provider_spec.base_url.rstrip("/") + ) else: - deps = {api: impls[api] for api in adapter.adapter_dependencies} - impl = instantiate_adapter(adapter, adapter_config, deps) + deps = {api: impls[api] for api in provider_spec.api_dependencies} + impl = instantiate_provider(provider_spec, provider_config, deps) impls[api] = impl return impls @@ -288,12 +290,12 @@ def main( all_endpoints = api_endpoints() impls = resolve_impls(dist, config) - for adapter in dist.adapters.values(): - api = adapter.api + for provider_spec in dist.provider_specs.values(): + api = provider_spec.api endpoints = all_endpoints[api] - if isinstance(adapter, PassthroughApiAdapter): + if isinstance(provider_spec, RemoteProviderSpec): for endpoint in endpoints: - url = adapter.base_url.rstrip("/") + endpoint.route + url = provider_spec.base_url.rstrip("/") + endpoint.route getattr(app, endpoint.method)(endpoint.route)( create_dynamic_passthrough(url) ) diff --git a/llama_toolchain/inference/inference.py b/llama_toolchain/inference/inference.py index 2dd15317e..194a0a882 100644 --- a/llama_toolchain/inference/inference.py +++ b/llama_toolchain/inference/inference.py @@ -11,7 +11,7 @@ from typing import AsyncIterator, Dict, Union from llama_models.llama3_1.api.datatypes import StopReason from llama_models.sku_list import resolve_model -from llama_toolchain.distribution.datatypes import Adapter, Api +from llama_toolchain.distribution.datatypes import Api, 
ProviderSpec from .api.config import MetaReferenceImplConfig from .api.datatypes import ( @@ -29,7 +29,9 @@ from .api.endpoints import ( from .model_parallel import LlamaModelParallelGenerator -async def get_adapter_impl(config: MetaReferenceImplConfig, _deps: Dict[Api, Adapter]): +async def get_provider_impl( + config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec] +): assert isinstance( config, MetaReferenceImplConfig ), f"Unexpected config type: {type(config)}" diff --git a/llama_toolchain/inference/ollama.py b/llama_toolchain/inference/ollama.py index d07d71829..560960f8b 100644 --- a/llama_toolchain/inference/ollama.py +++ b/llama_toolchain/inference/ollama.py @@ -37,7 +37,7 @@ from .api.endpoints import ( ) -def get_adapter_impl(config: OllamaImplConfig) -> Inference: +def get_provider_impl(config: OllamaImplConfig) -> Inference: assert isinstance( config, OllamaImplConfig ), f"Unexpected config type: {type(config)}" diff --git a/llama_toolchain/inference/adapters.py b/llama_toolchain/inference/providers.py similarity index 72% rename from llama_toolchain/inference/adapters.py rename to llama_toolchain/inference/providers.py index 320bad9a7..a12defafa 100644 --- a/llama_toolchain/inference/adapters.py +++ b/llama_toolchain/inference/providers.py @@ -6,14 +6,14 @@ from typing import List -from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter +from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec -def available_inference_adapters() -> List[Adapter]: +def available_inference_providers() -> List[ProviderSpec]: return [ - SourceAdapter( + InlineProviderSpec( api=Api.inference, - adapter_id="meta-reference", + provider_id="meta-reference", pip_packages=[ "torch", "zmq", @@ -21,9 +21,9 @@ def available_inference_adapters() -> List[Adapter]: module="llama_toolchain.inference.inference", config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig", ), - SourceAdapter( + InlineProviderSpec( api=Api.inference, - adapter_id="meta-ollama", + provider_id="meta-ollama", pip_packages=[ "ollama", ], diff --git a/llama_toolchain/safety/adapters.py b/llama_toolchain/safety/providers.py similarity index 71% rename from llama_toolchain/safety/adapters.py rename to llama_toolchain/safety/providers.py index 6411da757..4a88c8e28 100644 --- a/llama_toolchain/safety/adapters.py +++ b/llama_toolchain/safety/providers.py @@ -6,14 +6,14 @@ from typing import List -from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter +from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec -def available_safety_adapters() -> List[Adapter]: +def available_safety_providers() -> List[ProviderSpec]: return [ - SourceAdapter( + InlineProviderSpec( api=Api.safety, - adapter_id="meta-reference", + provider_id="meta-reference", pip_packages=[ "codeshield", "torch", diff --git a/llama_toolchain/safety/safety.py b/llama_toolchain/safety/safety.py index 5a01cc2c0..3f1c7698c 100644 --- a/llama_toolchain/safety/safety.py +++ b/llama_toolchain/safety/safety.py @@ -8,7 +8,7 @@ import asyncio from typing import Dict -from llama_toolchain.distribution.datatypes import Adapter, Api +from llama_toolchain.distribution.datatypes import Api, ProviderSpec from .config import SafetyConfig from .api.endpoints import * # noqa @@ -23,7 +23,7 @@ from .shields import ( ) -async def get_adapter_impl(config: SafetyConfig, _deps: Dict[Api, Adapter]): +async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, 
ProviderSpec]):
     assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
     impl = MetaReferenceSafetyImpl(config)
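
Editor's note: the following is an illustrative, self-contained sketch, not part of the patch, of how the renamed pieces fit together after this change. A provider registry returns InlineProviderSpec entries keyed by provider_id, and the distribution server resolves each spec by importing spec.module and calling its get_provider_impl(config, deps). The simplified dataclasses below are stand-ins for the pydantic models in llama_toolchain/distribution/datatypes.py, and the body of instantiate_class_type is assumed (the hunk above does not show it).

import asyncio
import importlib
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List


class Api(Enum):
    inference = "inference"
    safety = "safety"
    agentic_system = "agentic_system"


@dataclass
class ProviderSpec:
    api: Api
    provider_id: str


@dataclass
class InlineProviderSpec(ProviderSpec):
    module: str = ""          # module exposing get_provider_impl(config, deps)
    config_class: str = ""    # fully-qualified name of the provider's config class
    pip_packages: List[str] = field(default_factory=list)
    api_dependencies: List[Api] = field(default_factory=list)


def instantiate_class_type(fully_qualified_name: str):
    # assumed implementation: "a.b.C" -> class C from module a.b
    module_name, class_name = fully_qualified_name.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)


def instantiate_provider(spec: InlineProviderSpec, config_dict: Dict[str, Any], deps: Dict[Api, Any]):
    # Mirrors llama_toolchain/distribution/dynamic.py: build the typed config,
    # then hand it to the provider module's get_provider_impl entry point.
    module = importlib.import_module(spec.module)
    config = instantiate_class_type(spec.config_class)(**config_dict)
    return asyncio.run(module.get_provider_impl(config, deps))


# Registration as it appears in llama_toolchain/inference/providers.py after this patch:
inference_spec = InlineProviderSpec(
    api=Api.inference,
    provider_id="meta-reference",
    pip_packages=["torch", "zmq"],
    module="llama_toolchain.inference.inference",
    config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig",
)

RemoteProviderSpec entries take the other branch in resolve_impls(): they skip the import-and-instantiate path sketched above and are turned into HTTP clients via instantiate_client(provider_spec, base_url).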