Adapter -> Provider

This commit is contained in:
Ashwin Bharambe 2024-08-05 13:26:29 -07:00
parent db3e6dda07
commit 65a9e40174
15 changed files with 119 additions and 110 deletions

View file

@@ -7,7 +7,7 @@
from llama_toolchain.agentic_system.api import AgenticSystem from llama_toolchain.agentic_system.api import AgenticSystem
from llama_toolchain.distribution.datatypes import Adapter, Api from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import Inference from llama_toolchain.inference.api import Inference
from llama_toolchain.safety.api import Safety from llama_toolchain.safety.api import Safety
@@ -44,7 +44,7 @@ logger = logging.getLogger()
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
async def get_adapter_impl(config: AgenticSystemConfig, deps: Dict[Api, Adapter]): async def get_provider_impl(config: AgenticSystemConfig, deps: Dict[Api, ProviderSpec]):
assert isinstance( assert isinstance(
config, AgenticSystemConfig config, AgenticSystemConfig
), f"Unexpected config type: {type(config)}" ), f"Unexpected config type: {type(config)}"

View file

@@ -6,14 +6,14 @@
from typing import List from typing import List
from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_agentic_system_adapters() -> List[Adapter]: def available_agentic_system_providers() -> List[ProviderSpec]:
return [ return [
SourceAdapter( InlineProviderSpec(
api=Api.agentic_system, api=Api.agentic_system,
adapter_id="meta-reference", provider_id="meta-reference",
pip_packages=[ pip_packages=[
"codeshield", "codeshield",
"torch", "torch",
@@ -21,7 +21,7 @@ def available_agentic_system_adapters() -> List[Adapter]:
], ],
module="llama_toolchain.agentic_system.agentic_system", module="llama_toolchain.agentic_system.agentic_system",
config_class="llama_toolchain.agentic_system.config.AgenticSystemConfig", config_class="llama_toolchain.agentic_system.config.AgenticSystemConfig",
adapter_dependencies=[ api_dependencies=[
Api.inference, Api.inference,
Api.safety, Api.safety,
], ],

View file

@@ -63,7 +63,7 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
from llama_toolchain.common.exec import run_command from llama_toolchain.common.exec import run_command
from llama_toolchain.common.prompt_for_config import prompt_for_config from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.common.serialize import EnumEncoder from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.distribution.datatypes import PassthroughApiAdapter from llama_toolchain.distribution.datatypes import RemoteProviderSpec
from llama_toolchain.distribution.dynamic import instantiate_class_type from llama_toolchain.distribution.dynamic import instantiate_class_type
python_exe = run_command(shlex.split("which python")) python_exe = run_command(shlex.split("which python"))
@@ -84,28 +84,28 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
with open(config_path, "r") as fp: with open(config_path, "r") as fp:
existing_config = yaml.safe_load(fp) existing_config = yaml.safe_load(fp)
adapter_configs = {} provider_configs = {}
for api, adapter in dist.adapters.items(): for api, provider_spec in dist.provider_specs.items():
if isinstance(adapter, PassthroughApiAdapter): if isinstance(provider_spec, RemoteProviderSpec):
adapter_configs[api.value] = adapter.dict() provider_configs[api.value] = provider_spec.dict()
else: else:
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"]) cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
config_type = instantiate_class_type(adapter.config_class) config_type = instantiate_class_type(provider_spec.config_class)
config = prompt_for_config( config = prompt_for_config(
config_type, config_type,
( (
config_type(**existing_config["adapters"][api.value]) config_type(**existing_config["providers"][api.value])
if existing_config and api.value in existing_config["adapters"] if existing_config and api.value in existing_config["providers"]
else None else None
), ),
) )
adapter_configs[api.value] = { provider_configs[api.value] = {
"adapter_id": adapter.adapter_id, "provider_id": provider_spec.provider_id,
**config.dict(), **config.dict(),
} }
dist_config = { dist_config = {
"adapters": adapter_configs, "providers": provider_configs,
"conda_env": conda_env, "conda_env": conda_env,
} }

View file

@@ -30,7 +30,7 @@ class DistributionCreate(Subcommand):
required=True, required=True,
) )
# for each Api the user wants to support, we should # for each Api the user wants to support, we should
# get the list of available adapters, ask which one the user # get the list of available providers, ask which one the user
# wants to pick and then ask for their configuration. # wants to pick and then ask for their configuration.
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None: def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:

View file

@@ -33,17 +33,17 @@ class DistributionList(Subcommand):
# eventually, this should query a registry at llama.meta.com/llamastack/distributions # eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [ headers = [
"Name", "Name",
"Adapters", "ProviderSpecs",
"Description", "Description",
] ]
rows = [] rows = []
for dist in available_distributions(): for dist in available_distributions():
adapters = {k.value: v.adapter_id for k, v in dist.adapters.items()} providers = {k.value: v.provider_id for k, v in dist.provider_specs.items()}
rows.append( rows.append(
[ [
dist.name, dist.name,
json.dumps(adapters, indent=2), json.dumps(providers, indent=2),
dist.description, dist.description,
] ]
) )

View file

@@ -26,13 +26,13 @@ class ApiEndpoint(BaseModel):
@json_schema_type @json_schema_type
class Adapter(BaseModel): class ProviderSpec(BaseModel):
api: Api api: Api
adapter_id: str provider_id: str
@json_schema_type @json_schema_type
class SourceAdapter(Adapter): class InlineProviderSpec(ProviderSpec):
pip_packages: List[str] = Field( pip_packages: List[str] = Field(
default_factory=list, default_factory=list,
description="The pip dependencies needed for this implementation", description="The pip dependencies needed for this implementation",
@@ -42,21 +42,21 @@ class SourceAdapter(Adapter):
description=""" description="""
Fully-qualified name of the module to import. The module is expected to have: Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the local implementation - `get_provider_impl(config, deps)`: returns the local implementation
""", """,
) )
config_class: str = Field( config_class: str = Field(
..., ...,
description="Fully-qualified classname of the config for this adapter", description="Fully-qualified classname of the config for this provider",
) )
adapter_dependencies: List[Api] = Field( api_dependencies: List[Api] = Field(
default_factory=list, default_factory=list,
description="Higher-level API surfaces may depend on other adapters to provide their functionality", description="Higher-level API surfaces may depend on other providers to provide their functionality",
) )
@json_schema_type @json_schema_type
class PassthroughApiAdapter(Adapter): class RemoteProviderSpec(ProviderSpec):
base_url: str = Field(..., description="The base URL for the llama stack provider") base_url: str = Field(..., description="The base URL for the llama stack provider")
headers: Dict[str, str] = Field( headers: Dict[str, str] = Field(
default_factory=dict, default_factory=dict,
@@ -75,12 +75,12 @@ class Distribution(BaseModel):
name: str name: str
description: str description: str
adapters: Dict[Api, Adapter] = Field( provider_specs: Dict[Api, ProviderSpec] = Field(
default_factory=dict, default_factory=dict,
description="The API surfaces provided by this distribution", description="Provider specifications for each of the APIs provided by this distribution",
) )
additional_pip_packages: List[str] = Field( additional_pip_packages: List[str] = Field(
default_factory=list, default_factory=list,
description="Additional pip packages beyond those required by the adapters", description="Additional pip packages beyond those required by the providers",
) )

View file

@@ -11,16 +11,16 @@ from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
from llama_toolchain.inference.api.endpoints import Inference from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.safety.api.endpoints import Safety from llama_toolchain.safety.api.endpoints import Safety
from .datatypes import Api, ApiEndpoint, Distribution, SourceAdapter from .datatypes import Api, ApiEndpoint, Distribution, InlineProviderSpec
def distribution_dependencies(distribution: Distribution) -> List[str]: def distribution_dependencies(distribution: Distribution) -> List[str]:
# only consider SourceAdapters when calculating dependencies # only consider InlineProviderSpecs when calculating dependencies
return [ return [
dep dep
for adapter in distribution.adapters.values() for provider_spec in distribution.provider_specs.values()
if isinstance(adapter, SourceAdapter) if isinstance(provider_spec, InlineProviderSpec)
for dep in adapter.pip_packages for dep in provider_spec.pip_packages
] + distribution.additional_pip_packages ] + distribution.additional_pip_packages

View file

@@ -8,7 +8,7 @@ import asyncio
import importlib import importlib
from typing import Any, Dict from typing import Any, Dict
from .datatypes import Adapter, PassthroughApiAdapter, SourceAdapter from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderSpec
def instantiate_class_type(fully_qualified_name): def instantiate_class_type(fully_qualified_name):
@@ -18,17 +18,19 @@ def instantiate_class_type(fully_qualified_name):
# returns a class implementing the protocol corresponding to the Api # returns a class implementing the protocol corresponding to the Api
def instantiate_adapter( def instantiate_provider(
adapter: SourceAdapter, adapter_config: Dict[str, Any], deps: Dict[str, Adapter] provider_spec: InlineProviderSpec,
provider_config: Dict[str, Any],
deps: Dict[str, ProviderSpec],
): ):
module = importlib.import_module(adapter.module) module = importlib.import_module(provider_spec.module)
config_type = instantiate_class_type(adapter.config_class) config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**adapter_config) config = config_type(**provider_config)
return asyncio.run(module.get_adapter_impl(config, deps)) return asyncio.run(module.get_provider_impl(config, deps))
def instantiate_client(adapter: PassthroughApiAdapter, base_url: str): def instantiate_client(provider_spec: RemoteProviderSpec, base_url: str):
module = importlib.import_module(adapter.module) module = importlib.import_module(provider_spec.module)
return asyncio.run(module.get_client_impl(base_url)) return asyncio.run(module.get_client_impl(base_url))

View file

@@ -7,12 +7,12 @@
from functools import lru_cache from functools import lru_cache
from typing import List, Optional from typing import List, Optional
from llama_toolchain.agentic_system.adapters import available_agentic_system_adapters from llama_toolchain.agentic_system.providers import available_agentic_system_providers
from llama_toolchain.inference.adapters import available_inference_adapters from llama_toolchain.inference.providers import available_inference_providers
from llama_toolchain.safety.adapters import available_safety_adapters from llama_toolchain.safety.providers import available_safety_providers
from .datatypes import Api, Distribution, PassthroughApiAdapter from .datatypes import Api, Distribution, RemoteProviderSpec
# This is currently duplicated from `requirements.txt` with a few minor changes # This is currently duplicated from `requirements.txt` with a few minor changes
# dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies # dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
@@ -49,10 +49,10 @@ def client_module(api: Api) -> str:
return f"llama_toolchain.{api.value}.client" return f"llama_toolchain.{api.value}.client"
def passthrough(api: Api, port: int) -> PassthroughApiAdapter: def remote(api: Api, port: int) -> RemoteProviderSpec:
return PassthroughApiAdapter( return RemoteProviderSpec(
api=api, api=api,
adapter_id=f"{api.value}-passthrough", provider_id=f"{api.value}-remote",
base_url=f"http://localhost:{port}", base_url=f"http://localhost:{port}",
module=client_module(api), module=client_module(api),
) )
@@ -60,25 +60,28 @@ def passthrough(api: Api, port: int) -> PassthroughApiAdapter:
@lru_cache() @lru_cache()
def available_distributions() -> List[Distribution]: def available_distributions() -> List[Distribution]:
inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()} inference_providers_by_id = {
safety_adapters_by_id = {a.adapter_id: a for a in available_safety_adapters()} a.provider_id: a for a in available_inference_providers()
agentic_system_adapters_by_id = { }
a.adapter_id: a for a in available_agentic_system_adapters() safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
agentic_system_providers_by_id = {
a.provider_id: a for a in available_agentic_system_providers()
} }
return [ return [
Distribution( Distribution(
name="local-source", name="local-inline",
description="Use code from `llama_toolchain` itself to serve all llama stack APIs", description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
additional_pip_packages=COMMON_DEPENDENCIES, additional_pip_packages=COMMON_DEPENDENCIES,
adapters={ provider_specs={
Api.inference: inference_adapters_by_id["meta-reference"], Api.inference: inference_providers_by_id["meta-reference"],
Api.safety: safety_adapters_by_id["meta-reference"], Api.safety: safety_providers_by_id["meta-reference"],
Api.agentic_system: agentic_system_adapters_by_id["meta-reference"], Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
}, },
), ),
# NOTE: this hardcodes the ports to which things point to
Distribution( Distribution(
name="full-passthrough", name="full-remote",
description="Point to remote services for all llama stack APIs", description="Point to remote services for all llama stack APIs",
additional_pip_packages=[ additional_pip_packages=[
"python-dotenv", "python-dotenv",
@@ -94,20 +97,20 @@ def available_distributions() -> List[Distribution]:
"pydantic_core==2.18.2", "pydantic_core==2.18.2",
"uvicorn", "uvicorn",
], ],
adapters={ provider_specs={
Api.inference: passthrough(Api.inference, 5001), Api.inference: remote(Api.inference, 5001),
Api.safety: passthrough(Api.safety, 5001), Api.safety: remote(Api.safety, 5001),
Api.agentic_system: passthrough(Api.agentic_system, 5001), Api.agentic_system: remote(Api.agentic_system, 5001),
}, },
), ),
Distribution( Distribution(
name="local-ollama", name="local-ollama",
description="Like local-source, but use ollama for running LLM inference", description="Like local-source, but use ollama for running LLM inference",
additional_pip_packages=COMMON_DEPENDENCIES, additional_pip_packages=COMMON_DEPENDENCIES,
adapters={ provider_specs={
Api.inference: inference_adapters_by_id["meta-ollama"], Api.inference: inference_providers_by_id["meta-ollama"],
Api.safety: safety_adapters_by_id["meta-reference"], Api.safety: safety_providers_by_id["meta-reference"],
Api.agentic_system: agentic_system_adapters_by_id["meta-reference"], Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
}, },
), ),
] ]

View file

@@ -36,9 +36,9 @@ from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from termcolor import cprint from termcolor import cprint
from .datatypes import Adapter, Api, Distribution, PassthroughApiAdapter from .datatypes import Api, Distribution, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints from .distribution import api_endpoints
from .dynamic import instantiate_adapter, instantiate_client from .dynamic import instantiate_client, instantiate_provider
from .registry import resolve_distribution from .registry import resolve_distribution
@@ -226,15 +226,15 @@ def create_dynamic_typed_route(func: Any):
return endpoint return endpoint
def topological_sort(adapters: List[Adapter]) -> List[Adapter]: def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
by_id = {x.api: x for x in adapters} by_id = {x.api: x for x in providers}
def dfs(a: Adapter, visited: Set[Api], stack: List[Api]): def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api) visited.add(a.api)
if not isinstance(a, PassthroughApiAdapter): if not isinstance(a, RemoteProviderSpec):
for api in a.adapter_dependencies: for api in a.api_dependencies:
if api not in visited: if api not in visited:
dfs(by_id[api], visited, stack) dfs(by_id[api], visited, stack)
@@ -243,7 +243,7 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
visited = set() visited = set()
stack = [] stack = []
for a in adapters: for a in providers:
if a.api not in visited: if a.api not in visited:
dfs(a, visited, stack) dfs(a, visited, stack)
@@ -251,23 +251,25 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]: def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
adapter_configs = config["adapters"] provider_configs = config["providers"]
adapters = topological_sort(dist.adapters.values()) provider_specs = topological_sort(dist.provider_specs.values())
impls = {} impls = {}
for adapter in adapters: for provider_spec in provider_specs:
api = adapter.api api = provider_spec.api
if api.value not in adapter_configs: if api.value not in provider_configs:
raise ValueError( raise ValueError(
f"Could not find adapter config for {api}. Please add it to the config" f"Could not find provider_spec config for {api}. Please add it to the config"
) )
adapter_config = adapter_configs[api.value] provider_config = provider_configs[api.value]
if isinstance(adapter, PassthroughApiAdapter): if isinstance(provider_spec, RemoteProviderSpec):
impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/")) impls[api] = instantiate_client(
provider_spec, provider_spec.base_url.rstrip("/")
)
else: else:
deps = {api: impls[api] for api in adapter.adapter_dependencies} deps = {api: impls[api] for api in provider_spec.api_dependencies}
impl = instantiate_adapter(adapter, adapter_config, deps) impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl impls[api] = impl
return impls return impls
@@ -288,12 +290,12 @@ def main(
all_endpoints = api_endpoints() all_endpoints = api_endpoints()
impls = resolve_impls(dist, config) impls = resolve_impls(dist, config)
for adapter in dist.adapters.values(): for provider_spec in dist.provider_specs.values():
api = adapter.api api = provider_spec.api
endpoints = all_endpoints[api] endpoints = all_endpoints[api]
if isinstance(adapter, PassthroughApiAdapter): if isinstance(provider_spec, RemoteProviderSpec):
for endpoint in endpoints: for endpoint in endpoints:
url = adapter.base_url.rstrip("/") + endpoint.route url = provider_spec.base_url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)( getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url) create_dynamic_passthrough(url)
) )

View file

@@ -11,7 +11,7 @@ from typing import AsyncIterator, Dict, Union
from llama_models.llama3_1.api.datatypes import StopReason from llama_models.llama3_1.api.datatypes import StopReason
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from llama_toolchain.distribution.datatypes import Adapter, Api from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from .api.config import MetaReferenceImplConfig from .api.config import MetaReferenceImplConfig
from .api.datatypes import ( from .api.datatypes import (
@@ -29,7 +29,9 @@ from .api.endpoints import (
from .model_parallel import LlamaModelParallelGenerator from .model_parallel import LlamaModelParallelGenerator
async def get_adapter_impl(config: MetaReferenceImplConfig, _deps: Dict[Api, Adapter]): async def get_provider_impl(
config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
):
assert isinstance( assert isinstance(
config, MetaReferenceImplConfig config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}" ), f"Unexpected config type: {type(config)}"

View file

@@ -37,7 +37,7 @@ from .api.endpoints import (
) )
def get_adapter_impl(config: OllamaImplConfig) -> Inference: def get_provider_impl(config: OllamaImplConfig) -> Inference:
assert isinstance( assert isinstance(
config, OllamaImplConfig config, OllamaImplConfig
), f"Unexpected config type: {type(config)}" ), f"Unexpected config type: {type(config)}"

View file

@@ -6,14 +6,14 @@
from typing import List from typing import List
from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_inference_adapters() -> List[Adapter]: def available_inference_providers() -> List[ProviderSpec]:
return [ return [
SourceAdapter( InlineProviderSpec(
api=Api.inference, api=Api.inference,
adapter_id="meta-reference", provider_id="meta-reference",
pip_packages=[ pip_packages=[
"torch", "torch",
"zmq", "zmq",
@ -21,9 +21,9 @@ def available_inference_adapters() -> List[Adapter]:
module="llama_toolchain.inference.inference", module="llama_toolchain.inference.inference",
config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig", config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig",
), ),
SourceAdapter( InlineProviderSpec(
api=Api.inference, api=Api.inference,
adapter_id="meta-ollama", provider_id="meta-ollama",
pip_packages=[ pip_packages=[
"ollama", "ollama",
], ],

View file

@@ -6,14 +6,14 @@
from typing import List from typing import List
from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_safety_adapters() -> List[Adapter]: def available_safety_providers() -> List[ProviderSpec]:
return [ return [
SourceAdapter( InlineProviderSpec(
api=Api.safety, api=Api.safety,
adapter_id="meta-reference", provider_id="meta-reference",
pip_packages=[ pip_packages=[
"codeshield", "codeshield",
"torch", "torch",

View file

@@ -8,7 +8,7 @@ import asyncio
from typing import Dict from typing import Dict
from llama_toolchain.distribution.datatypes import Adapter, Api from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from .config import SafetyConfig from .config import SafetyConfig
from .api.endpoints import * # noqa from .api.endpoints import * # noqa
@@ -23,7 +23,7 @@ from .shields import (
) )
async def get_adapter_impl(config: SafetyConfig, _deps: Dict[Api, Adapter]): async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]):
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}" assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config) impl = MetaReferenceSafetyImpl(config)