Adapter -> Provider

This commit is contained in:
Ashwin Bharambe 2024-08-05 13:26:29 -07:00
parent db3e6dda07
commit 65a9e40174
15 changed files with 119 additions and 110 deletions

View file

@ -7,7 +7,7 @@
from llama_toolchain.agentic_system.api import AgenticSystem
from llama_toolchain.distribution.datatypes import Adapter, Api
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import Inference
from llama_toolchain.safety.api import Safety
@ -44,7 +44,7 @@ logger = logging.getLogger()
logger.setLevel(logging.INFO)
async def get_adapter_impl(config: AgenticSystemConfig, deps: Dict[Api, Adapter]):
async def get_provider_impl(config: AgenticSystemConfig, deps: Dict[Api, ProviderSpec]):
assert isinstance(
config, AgenticSystemConfig
), f"Unexpected config type: {type(config)}"

View file

@ -6,14 +6,14 @@
from typing import List
from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_agentic_system_adapters() -> List[Adapter]:
def available_agentic_system_providers() -> List[ProviderSpec]:
return [
SourceAdapter(
InlineProviderSpec(
api=Api.agentic_system,
adapter_id="meta-reference",
provider_id="meta-reference",
pip_packages=[
"codeshield",
"torch",
@ -21,7 +21,7 @@ def available_agentic_system_adapters() -> List[Adapter]:
],
module="llama_toolchain.agentic_system.agentic_system",
config_class="llama_toolchain.agentic_system.config.AgenticSystemConfig",
adapter_dependencies=[
api_dependencies=[
Api.inference,
Api.safety,
],

View file

@ -63,7 +63,7 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
from llama_toolchain.common.exec import run_command
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.distribution.datatypes import PassthroughApiAdapter
from llama_toolchain.distribution.datatypes import RemoteProviderSpec
from llama_toolchain.distribution.dynamic import instantiate_class_type
python_exe = run_command(shlex.split("which python"))
@ -84,28 +84,28 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
with open(config_path, "r") as fp:
existing_config = yaml.safe_load(fp)
adapter_configs = {}
for api, adapter in dist.adapters.items():
if isinstance(adapter, PassthroughApiAdapter):
adapter_configs[api.value] = adapter.dict()
provider_configs = {}
for api, provider_spec in dist.provider_specs.items():
if isinstance(provider_spec, RemoteProviderSpec):
provider_configs[api.value] = provider_spec.dict()
else:
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
config_type = instantiate_class_type(adapter.config_class)
config_type = instantiate_class_type(provider_spec.config_class)
config = prompt_for_config(
config_type,
(
config_type(**existing_config["adapters"][api.value])
if existing_config and api.value in existing_config["adapters"]
config_type(**existing_config["providers"][api.value])
if existing_config and api.value in existing_config["providers"]
else None
),
)
adapter_configs[api.value] = {
"adapter_id": adapter.adapter_id,
provider_configs[api.value] = {
"provider_id": provider_spec.provider_id,
**config.dict(),
}
dist_config = {
"adapters": adapter_configs,
"providers": provider_configs,
"conda_env": conda_env,
}

View file

@ -30,7 +30,7 @@ class DistributionCreate(Subcommand):
required=True,
)
# for each Api the user wants to support, we should
# get the list of available adapters, ask which one the user
# get the list of available providers, ask which one the user
# wants to pick and then ask for their configuration.
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:

View file

@ -33,17 +33,17 @@ class DistributionList(Subcommand):
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
"Name",
"Adapters",
"ProviderSpecs",
"Description",
]
rows = []
for dist in available_distributions():
adapters = {k.value: v.adapter_id for k, v in dist.adapters.items()}
providers = {k.value: v.provider_id for k, v in dist.provider_specs.items()}
rows.append(
[
dist.name,
json.dumps(adapters, indent=2),
json.dumps(providers, indent=2),
dist.description,
]
)

View file

@ -26,13 +26,13 @@ class ApiEndpoint(BaseModel):
# Base record describing the provider of a single API surface.
# (Diff view: the class is renamed from `Adapter` to `ProviderSpec`.)
@json_schema_type
class Adapter(BaseModel):
class ProviderSpec(BaseModel):
# The API surface this spec provides.
api: Api
# Identifier of the provider; the distribution registry indexes the available
# specs by this value (e.g. "meta-reference").
adapter_id: str
provider_id: str
@json_schema_type
class SourceAdapter(Adapter):
class InlineProviderSpec(ProviderSpec):
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
@ -42,21 +42,21 @@ class SourceAdapter(Adapter):
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the local implementation
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this adapter",
description="Fully-qualified classname of the config for this provider",
)
adapter_dependencies: List[Api] = Field(
api_dependencies: List[Api] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other adapters to provide their functionality",
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
@json_schema_type
class PassthroughApiAdapter(Adapter):
class RemoteProviderSpec(ProviderSpec):
base_url: str = Field(..., description="The base URL for the llama stack provider")
headers: Dict[str, str] = Field(
default_factory=dict,
@ -75,12 +75,12 @@ class Distribution(BaseModel):
name: str
description: str
adapters: Dict[Api, Adapter] = Field(
provider_specs: Dict[Api, ProviderSpec] = Field(
default_factory=dict,
description="The API surfaces provided by this distribution",
description="Provider specifications for each of the APIs provided by this distribution",
)
additional_pip_packages: List[str] = Field(
default_factory=list,
description="Additional pip packages beyond those required by the adapters",
description="Additional pip packages beyond those required by the providers",
)

View file

@ -11,16 +11,16 @@ from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.safety.api.endpoints import Safety
from .datatypes import Api, ApiEndpoint, Distribution, SourceAdapter
from .datatypes import Api, ApiEndpoint, Distribution, InlineProviderSpec
# Return the pip packages needed by `distribution`: the `pip_packages` of each
# inline provider spec, followed by the distribution's `additional_pip_packages`.
# Duplicates are not removed.
def distribution_dependencies(distribution: Distribution) -> List[str]:
# only consider SourceAdapters when calculating dependencies
# only consider InlineProviderSpecs when calculating dependencies
# (specs of any other type are skipped by the isinstance filter below)
return [
dep
for adapter in distribution.adapters.values()
if isinstance(adapter, SourceAdapter)
for dep in adapter.pip_packages
for provider_spec in distribution.provider_specs.values()
if isinstance(provider_spec, InlineProviderSpec)
for dep in provider_spec.pip_packages
] + distribution.additional_pip_packages

View file

@ -8,7 +8,7 @@ import asyncio
import importlib
from typing import Any, Dict
from .datatypes import Adapter, PassthroughApiAdapter, SourceAdapter
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderSpec
def instantiate_class_type(fully_qualified_name):
@ -18,17 +18,19 @@ def instantiate_class_type(fully_qualified_name):
# returns a class implementing the protocol corresponding to the Api
#
# Steps: import the module named by the spec, build its config object from the
# raw config dict, then run the module's async `get_provider_impl(config, deps)`
# to completion and return whatever it produces.
#
# NOTE(review): `deps` is annotated as Dict[str, ProviderSpec], but the caller in
# the server module builds it as `{api: impls[api] ...}`, i.e. keyed by `Api` with
# already-instantiated implementations as values -- confirm and fix the annotation.
def instantiate_adapter(
adapter: SourceAdapter, adapter_config: Dict[str, Any], deps: Dict[str, Adapter]
def instantiate_provider(
provider_spec: InlineProviderSpec,
provider_config: Dict[str, Any],
deps: Dict[str, ProviderSpec],
):
module = importlib.import_module(adapter.module)
module = importlib.import_module(provider_spec.module)
# `config_class` is a fully-qualified class name; resolve it to the class itself
config_type = instantiate_class_type(adapter.config_class)
config = config_type(**adapter_config)
return asyncio.run(module.get_adapter_impl(config, deps))
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config)
# `asyncio.run` requires a coroutine, so the module's factory must be `async def`
# and accept `(config, deps)`.
# NOTE(review): the ollama `get_provider_impl(config)` in this change set is a
# plain function taking only `config`; it looks incompatible with this call -- verify.
return asyncio.run(module.get_provider_impl(config, deps))
# Build a client for a remote provider: import the module named by the spec and
# run its async `get_client_impl(base_url)` to completion, returning the result.
# The visible caller passes `base_url` with any trailing "/" already stripped.
def instantiate_client(adapter: PassthroughApiAdapter, base_url: str):
module = importlib.import_module(adapter.module)
def instantiate_client(provider_spec: RemoteProviderSpec, base_url: str):
module = importlib.import_module(provider_spec.module)
return asyncio.run(module.get_client_impl(base_url))

View file

@ -7,12 +7,12 @@
from functools import lru_cache
from typing import List, Optional
from llama_toolchain.agentic_system.adapters import available_agentic_system_adapters
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
from llama_toolchain.inference.adapters import available_inference_adapters
from llama_toolchain.safety.adapters import available_safety_adapters
from llama_toolchain.inference.providers import available_inference_providers
from llama_toolchain.safety.providers import available_safety_providers
from .datatypes import Api, Distribution, PassthroughApiAdapter
from .datatypes import Api, Distribution, RemoteProviderSpec
# This is currently duplicated from `requirements.txt` with a few minor changes
# dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
@ -49,10 +49,10 @@ def client_module(api: Api) -> str:
return f"llama_toolchain.{api.value}.client"
# Build the spec for an API that is served by another process on this machine.
# (Diff view: `passthrough` is renamed to `remote`.)
#   api:  the API surface being proxied
#   port: local port of the serving process; the base URL is http://localhost:<port>
# The provider id is "<api.value>-remote" and the module is the API's client
# module as returned by `client_module(api)`.
def passthrough(api: Api, port: int) -> PassthroughApiAdapter:
return PassthroughApiAdapter(
def remote(api: Api, port: int) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api,
adapter_id=f"{api.value}-passthrough",
provider_id=f"{api.value}-remote",
base_url=f"http://localhost:{port}",
module=client_module(api),
)
@ -60,25 +60,28 @@ def passthrough(api: Api, port: int) -> PassthroughApiAdapter:
@lru_cache()
def available_distributions() -> List[Distribution]:
inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()}
safety_adapters_by_id = {a.adapter_id: a for a in available_safety_adapters()}
agentic_system_adapters_by_id = {
a.adapter_id: a for a in available_agentic_system_adapters()
inference_providers_by_id = {
a.provider_id: a for a in available_inference_providers()
}
safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
agentic_system_providers_by_id = {
a.provider_id: a for a in available_agentic_system_providers()
}
return [
Distribution(
name="local-source",
name="local-inline",
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
additional_pip_packages=COMMON_DEPENDENCIES,
adapters={
Api.inference: inference_adapters_by_id["meta-reference"],
Api.safety: safety_adapters_by_id["meta-reference"],
Api.agentic_system: agentic_system_adapters_by_id["meta-reference"],
provider_specs={
Api.inference: inference_providers_by_id["meta-reference"],
Api.safety: safety_providers_by_id["meta-reference"],
Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
},
),
# NOTE: this hardcodes the ports that the remote providers point to
Distribution(
name="full-passthrough",
name="full-remote",
description="Point to remote services for all llama stack APIs",
additional_pip_packages=[
"python-dotenv",
@ -94,20 +97,20 @@ def available_distributions() -> List[Distribution]:
"pydantic_core==2.18.2",
"uvicorn",
],
adapters={
Api.inference: passthrough(Api.inference, 5001),
Api.safety: passthrough(Api.safety, 5001),
Api.agentic_system: passthrough(Api.agentic_system, 5001),
provider_specs={
Api.inference: remote(Api.inference, 5001),
Api.safety: remote(Api.safety, 5001),
Api.agentic_system: remote(Api.agentic_system, 5001),
},
),
Distribution(
name="local-ollama",
description="Like local-inline, but use ollama for running LLM inference",
additional_pip_packages=COMMON_DEPENDENCIES,
adapters={
Api.inference: inference_adapters_by_id["meta-ollama"],
Api.safety: safety_adapters_by_id["meta-reference"],
Api.agentic_system: agentic_system_adapters_by_id["meta-reference"],
provider_specs={
Api.inference: inference_providers_by_id["meta-ollama"],
Api.safety: safety_providers_by_id["meta-reference"],
Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
},
),
]

View file

@ -36,9 +36,9 @@ from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from .datatypes import Adapter, Api, Distribution, PassthroughApiAdapter
from .datatypes import Api, Distribution, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints
from .dynamic import instantiate_adapter, instantiate_client
from .dynamic import instantiate_client, instantiate_provider
from .registry import resolve_distribution
@ -226,15 +226,15 @@ def create_dynamic_typed_route(func: Any):
return endpoint
def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
by_id = {x.api: x for x in adapters}
by_id = {x.api: x for x in providers}
def dfs(a: Adapter, visited: Set[Api], stack: List[Api]):
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api)
if not isinstance(a, PassthroughApiAdapter):
for api in a.adapter_dependencies:
if not isinstance(a, RemoteProviderSpec):
for api in a.api_dependencies:
if api not in visited:
dfs(by_id[api], visited, stack)
@ -243,7 +243,7 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
visited = set()
stack = []
for a in adapters:
for a in providers:
if a.api not in visited:
dfs(a, visited, stack)
@ -251,23 +251,25 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
# Instantiate one implementation per API of the distribution and return them
# keyed by `Api`.
#   dist:   the distribution whose provider specs should be instantiated
#   config: parsed run configuration; its "providers" entry maps `api.value`
#           to that provider's config dict
# Raises ValueError when an API of the distribution has no entry in the config.
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
adapter_configs = config["adapters"]
adapters = topological_sort(dist.adapters.values())
provider_configs = config["providers"]
# presumably orders specs so that dependencies come before their dependents;
# the `impls[api]` lookup for deps below relies on that -- TODO confirm against
# `topological_sort`, whose return statement is not visible in this view
provider_specs = topological_sort(dist.provider_specs.values())
impls = {}
for adapter in adapters:
api = adapter.api
if api.value not in adapter_configs:
for provider_spec in provider_specs:
api = provider_spec.api
if api.value not in provider_configs:
# NOTE(review): "provider_spec config" reads like a mechanical rename;
# "provider config" is probably the intended wording
raise ValueError(
f"Could not find adapter config for {api}. Please add it to the config"
f"Could not find provider_spec config for {api}. Please add it to the config"
)
adapter_config = adapter_configs[api.value]
if isinstance(adapter, PassthroughApiAdapter):
impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
provider_config = provider_configs[api.value]
# remote specs become clients for their base URL; the per-provider config is
# not passed along in this branch
if isinstance(provider_spec, RemoteProviderSpec):
impls[api] = instantiate_client(
provider_spec, provider_spec.base_url.rstrip("/")
)
else:
# hand the already-built implementations of the APIs this provider depends on
deps = {api: impls[api] for api in adapter.adapter_dependencies}
impl = instantiate_adapter(adapter, adapter_config, deps)
deps = {api: impls[api] for api in provider_spec.api_dependencies}
impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl
return impls
@ -288,12 +290,12 @@ def main(
all_endpoints = api_endpoints()
impls = resolve_impls(dist, config)
for adapter in dist.adapters.values():
api = adapter.api
for provider_spec in dist.provider_specs.values():
api = provider_spec.api
endpoints = all_endpoints[api]
if isinstance(adapter, PassthroughApiAdapter):
if isinstance(provider_spec, RemoteProviderSpec):
for endpoint in endpoints:
url = adapter.base_url.rstrip("/") + endpoint.route
url = provider_spec.base_url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)

View file

@ -11,7 +11,7 @@ from typing import AsyncIterator, Dict, Union
from llama_models.llama3_1.api.datatypes import StopReason
from llama_models.sku_list import resolve_model
from llama_toolchain.distribution.datatypes import Adapter, Api
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from .api.config import MetaReferenceImplConfig
from .api.datatypes import (
@ -29,7 +29,9 @@ from .api.endpoints import (
from .model_parallel import LlamaModelParallelGenerator
async def get_adapter_impl(config: MetaReferenceImplConfig, _deps: Dict[Api, Adapter]):
async def get_provider_impl(
config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
):
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"

View file

@ -37,7 +37,7 @@ from .api.endpoints import (
)
def get_adapter_impl(config: OllamaImplConfig) -> Inference:
def get_provider_impl(config: OllamaImplConfig) -> Inference:
assert isinstance(
config, OllamaImplConfig
), f"Unexpected config type: {type(config)}"

View file

@ -6,14 +6,14 @@
from typing import List
from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_inference_adapters() -> List[Adapter]:
def available_inference_providers() -> List[ProviderSpec]:
return [
SourceAdapter(
InlineProviderSpec(
api=Api.inference,
adapter_id="meta-reference",
provider_id="meta-reference",
pip_packages=[
"torch",
"zmq",
@ -21,9 +21,9 @@ def available_inference_adapters() -> List[Adapter]:
module="llama_toolchain.inference.inference",
config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig",
),
SourceAdapter(
InlineProviderSpec(
api=Api.inference,
adapter_id="meta-ollama",
provider_id="meta-ollama",
pip_packages=[
"ollama",
],

View file

@ -6,14 +6,14 @@
from typing import List
from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_safety_adapters() -> List[Adapter]:
def available_safety_providers() -> List[ProviderSpec]:
return [
SourceAdapter(
InlineProviderSpec(
api=Api.safety,
adapter_id="meta-reference",
provider_id="meta-reference",
pip_packages=[
"codeshield",
"torch",

View file

@ -8,7 +8,7 @@ import asyncio
from typing import Dict
from llama_toolchain.distribution.datatypes import Adapter, Api
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from .config import SafetyConfig
from .api.endpoints import * # noqa
@ -23,7 +23,7 @@ from .shields import (
)
async def get_adapter_impl(config: SafetyConfig, _deps: Dict[Api, Adapter]):
async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]):
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config)