Introduce a "Router" layer for providers

Some providers need to be factorized and considered as thin routing
layers on top of other providers. Consider two examples:

- The inference API should be a routing layer over inference providers,
  routed using the "model" key
- The memory banks API is another instance where various memory bank
  types will be provided by independent providers (e.g., a vector store
  is served by Chroma while a keyvalue memory can be served by Redis or
  PGVector)

This commit introduces a generalized routing layer for this purpose.
This commit is contained in:
Ashwin Bharambe 2024-09-16 10:38:11 -07:00
parent 5c1f2616b5
commit b6a3ef51da
12 changed files with 384 additions and 118 deletions

View file

@ -46,7 +46,7 @@ class StackBuild(Subcommand):
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.core.package import ApiInput, build_package, ImageType
from llama_toolchain.core.package import ApiInput, build_image, ImageType
from termcolor import cprint
# save build.yaml spec for building same distribution again
@ -66,7 +66,7 @@ class StackBuild(Subcommand):
to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder))
f.write(yaml.dump(to_write, sort_keys=False))
build_package(build_config, build_file_path)
build_image(build_config, build_file_path)
cprint(
f"Build spec configuration saved at {str(build_file_path)}",

View file

@ -105,13 +105,6 @@ class StackConfigure(Subcommand):
image_name = build_config.name.replace("::", "-")
run_config_file = builds_dir / f"{image_name}-run.yaml"
api2providers = build_config.distribution_spec.providers
stub_config = {
api_str: {"provider_id": provider}
for api_str, provider in api2providers.items()
}
if run_config_file.exists():
cprint(
f"Configuration already exists for {build_config.name}. Will overwrite...",
@ -123,10 +116,12 @@ class StackConfigure(Subcommand):
config = StackRunConfig(
built_at=datetime.now(),
image_name=image_name,
providers=stub_config,
apis_to_serve=[],
provider_map={},
)
config.providers = configure_api_providers(config.providers)
config = configure_api_providers(config, build_config.distribution_spec)
config.docker_image = (
image_name if build_config.image_type == "docker" else None
)

View file

@ -27,6 +27,12 @@ def is_list_of_primitives(field_type):
return False
def is_basemodel_without_fields(typ):
return (
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
)
def can_recurse(typ):
return (
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
@ -151,6 +157,11 @@ def prompt_for_config(
if get_origin(field_type) is Literal:
continue
# Skip fields with no type annotations
if is_basemodel_without_fields(field_type):
config_data[field_name] = field_type()
continue
if inspect.isclass(field_type) and issubclass(field_type, Enum):
prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):"
while True:
@ -254,6 +265,20 @@ def prompt_for_config(
print(f"{str(e)}")
continue
elif get_origin(field_type) is dict:
try:
value = json.loads(user_input)
if not isinstance(value, dict):
raise ValueError(
"Input must be a JSON-encoded dictionary"
)
except json.JSONDecodeError:
print(
"Invalid JSON. Please enter a valid JSON-encoded dict."
)
continue
# Convert the input to the correct type
elif inspect.isclass(field_type) and issubclass(
field_type, BaseModel

View file

@ -4,47 +4,87 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
from llama_toolchain.core.datatypes import * # noqa: F403
from termcolor import cprint
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.core.distribution import api_providers
from llama_toolchain.core.distribution import api_providers, stack_apis
from llama_toolchain.core.dynamic import instantiate_class_type
def configure_api_providers(existing_configs: Dict[str, Any]) -> None:
# These are hacks so we can re-use the `prompt_for_config` utility
# This needs a bunch of work to be made very user friendly.
class ReqApis(BaseModel):
apis_to_serve: List[str]
def make_routing_entry_type(config_class: Any):
    """Build a pydantic model pairing a routing key with a provider config.

    Used so `prompt_for_config` can prompt for a (routing_key, config)
    tuple in one pass when configuring a routed API.

    Args:
        config_class: the provider's config model; becomes the type of
            the `config` field on the returned model.

    Returns:
        A new BaseModel subclass with `routing_key: str` and
        `config: config_class` fields.
    """
    class BaseModelWithConfig(BaseModel):
        routing_key: str
        config: config_class

    return BaseModelWithConfig
# TODO: make sure we can deal with existing configuration values correctly
# instead of just overwriting them
def configure_api_providers(
config: StackRunConfig, spec: DistributionSpec
) -> StackRunConfig:
cprint("Configuring APIs to serve...", "white", attrs=["bold"])
print("Enter comma-separated list of APIs to serve:")
apis = config.apis_to_serve or list(spec.providers.keys())
apis = [a for a in apis if a != "telemetry"]
req_apis = ReqApis(
apis_to_serve=apis,
)
req_apis = prompt_for_config(ReqApis, req_apis)
print("")
apis = [v.value for v in stack_apis()]
all_providers = api_providers()
provider_configs = {}
for api_str, stub_config in existing_configs.items():
apis_to_serve = req_apis.apis_to_serve + ["telemetry"]
for api_str in apis_to_serve:
if api_str not in apis:
raise ValueError(f"Unknown API `{api_str}`")
cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"])
api = Api(api_str)
providers = all_providers[api]
provider_id = stub_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
if isinstance(spec.providers[api_str], list):
print(
"You have specified multiple providers for this API. We will configure a routing table now. For each provider, provide a routing key followed by provider configuration.\n"
)
routing_entries = []
for p in spec.providers[api_str]:
print(f"Configuring provider `{p}`...")
provider_spec = all_providers[api][p]
config_type = instantiate_class_type(provider_spec.config_class)
wrapper_type = make_routing_entry_type(config_type)
rt_entry = prompt_for_config(wrapper_type, None)
# TODO: we need to validate the routing keys
routing_entries.append(
ProviderRoutingEntry(
provider_id=p,
routing_key=rt_entry.routing_key,
config=rt_entry.config.dict(),
)
)
config.provider_map[api_str] = routing_entries
else:
provider_spec = all_providers[api][spec.providers[api_str]]
config_type = instantiate_class_type(provider_spec.config_class)
cfg = prompt_for_config(config_type, None)
config.provider_map[api_str] = GenericProviderConfig(
provider_id=spec.providers[api_str],
config=cfg.dict(),
)
provider_spec = providers[provider_id]
cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"])
config_type = instantiate_class_type(provider_spec.config_class)
try:
existing_provider_config = config_type(**stub_config)
except Exception:
existing_provider_config = None
provider_config = prompt_for_config(
config_type,
existing_provider_config,
)
print("")
provider_configs[api_str] = {
"provider_id": provider_id,
**provider_config.dict(),
}
return provider_configs
return config

View file

@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from llama_models.schema_utils import json_schema_type
@ -43,6 +43,33 @@ class ProviderSpec(BaseModel):
)
@json_schema_type
class RouterProviderSpec(ProviderSpec):
provider_id: str = "router"
config_class: str = ""
docker_image: Optional[str] = None
inner_specs: List[ProviderSpec]
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
)
@property
def pip_packages(self) -> List[str]:
raise AssertionError("Should not be called on RouterProviderSpec")
class GenericProviderConfig(BaseModel):
provider_id: str
config: Dict[str, Any]
@json_schema_type
class AdapterSpec(BaseModel):
adapter_id: str = Field(
@ -156,12 +183,23 @@ class DistributionSpec(BaseModel):
description="Description of the distribution",
)
docker_image: Optional[str] = None
providers: Dict[str, str] = Field(
providers: Dict[str, Union[str, List[str]]] = Field(
default_factory=dict,
description="Provider Types for each of the APIs provided by this distribution",
description="""
Provider Types for each of the APIs provided by this distribution. If you
select multiple providers, you should provide an appropriate 'routing_map'
in the runtime configuration to help route to the correct provider.""",
)
@json_schema_type
class ProviderRoutingEntry(GenericProviderConfig):
routing_key: str
ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
@json_schema_type
class StackRunConfig(BaseModel):
built_at: datetime
@ -181,12 +219,22 @@ this could be just a hash
default=None,
description="Reference to the conda environment if this package refers to a conda environment",
)
providers: Dict[str, Any] = Field(
default_factory=dict,
apis_to_serve: List[str] = Field(
description="""
Provider configurations for each of the APIs provided by this package. This includes configurations for
the dependencies of these providers as well.
""",
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
)
provider_map: Dict[str, ProviderMapEntry] = Field(
description="""
Provider configurations for each of the APIs provided by this package.
Given an API, you can specify a single provider or a "routing table". Each entry in the routing
table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific.
As examples:
- the "inference" API interprets the routing_key as a "model"
- the "memory" API interprets the routing_key as the type of a "memory bank"
The key may support wild-cards to help route to the correct provider.""",
)

View file

@ -4,11 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import importlib
from typing import Any, Dict
from .datatypes import ProviderSpec, RemoteProviderSpec
from llama_toolchain.core.datatypes import * # noqa: F403
def instantiate_class_type(fully_qualified_name):
@ -18,25 +17,50 @@ def instantiate_class_type(fully_qualified_name):
# returns a class implementing the protocol corresponding to the Api
def instantiate_provider(
async def instantiate_provider(
provider_spec: ProviderSpec,
provider_config: Dict[str, Any],
deps: Dict[str, ProviderSpec],
deps: Dict[str, Any],
provider_config: ProviderMapEntry,
):
module = importlib.import_module(provider_spec.module)
config_type = instantiate_class_type(provider_spec.config_class)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
if provider_spec.adapter:
method = "get_adapter_impl"
else:
method = "get_client_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
elif isinstance(provider_spec, RouterProviderSpec):
method = "get_router_impl"
assert isinstance(provider_config, list)
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in provider_config:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_id],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [inner_impls, deps]
else:
method = "get_provider_impl"
config = config_type(**provider_config)
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
fn = getattr(module, method)
impl = asyncio.run(fn(config, deps))
impl = await fn(*args)
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
return impl

View file

@ -4,22 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
from datetime import datetime
from enum import Enum
from typing import List, Optional
import pkg_resources
import yaml
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.common.serialize import EnumEncoder
from pydantic import BaseModel
from termcolor import cprint
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.core.datatypes import * # noqa: F403
from pathlib import Path
@ -41,7 +35,7 @@ class ApiInput(BaseModel):
provider: str
def build_package(build_config: BuildConfig, build_file_path: Path):
def build_image(build_config: BuildConfig, build_file_path: Path):
package_deps = Dependencies(
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
pip_packages=SERVER_DEPENDENCIES,
@ -49,17 +43,28 @@ def build_package(build_config: BuildConfig, build_file_path: Path):
# extend package dependencies based on providers spec
all_providers = api_providers()
for api_str, provider in build_config.distribution_spec.providers.items():
for (
api_str,
provider_or_providers,
) in build_config.distribution_spec.providers.items():
providers_for_api = all_providers[Api(api_str)]
if provider not in providers_for_api:
raise ValueError(
f"Provider `{provider}` is not available for API `{api_str}`"
)
provider_spec = providers_for_api[provider]
package_deps.pip_packages.extend(provider_spec.pip_packages)
if provider_spec.docker_image:
raise ValueError("A stack's dependencies cannot have a docker image")
providers = (
provider_or_providers
if isinstance(provider_or_providers, list)
else [provider_or_providers]
)
for provider in providers:
if provider not in providers_for_api:
raise ValueError(
f"Provider `{provider}` is not available for API `{api_str}`"
)
provider_spec = providers_for_api[provider]
package_deps.pip_packages.extend(provider_spec.pip_packages)
if provider_spec.docker_image:
raise ValueError("A stack's dependencies cannot have a docker image")
if build_config.image_type == ImageType.docker.value:
script = pkg_resources.resource_filename(

View file

@ -9,6 +9,7 @@ import inspect
import json
import signal
import traceback
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
@ -44,8 +45,8 @@ from llama_toolchain.telemetry.tracing import (
SpanStatus,
start_trace,
)
from llama_toolchain.core.datatypes import * # noqa: F403
from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints, api_providers
from .dynamic import instantiate_provider
@ -271,61 +272,80 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
return [by_id[x] for x in stack]
def resolve_impls(
provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
) -> Dict[Api, Any]:
provider_configs = config["providers"]
provider_specs = topological_sort(provider_specs.values())
def snake_to_camel(snake_str):
return "".join(word.capitalize() for word in snake_str.split("_"))
impls = {}
for provider_spec in provider_specs:
api = provider_spec.api
if api.value not in provider_configs:
raise ValueError(
f"Could not find provider_spec config for {api}. Please add it to the config"
async def resolve_impls(
provider_map: Dict[str, ProviderMapEntry],
) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
all_providers = api_providers()
specs = {}
for api_str, item in provider_map.items():
api = Api(api_str)
providers = all_providers[api]
if isinstance(item, GenericProviderConfig):
if item.provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
specs[api] = providers[item.provider_id]
else:
assert isinstance(item, list)
inner_specs = []
for rt_entry in item:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_id])
specs[api] = RouterProviderSpec(
api=api,
module=f"llama_toolchain.{api.value.lower()}.router",
api_dependencies=[],
inner_specs=inner_specs,
)
if isinstance(provider_spec, InlineProviderSpec):
deps = {api: impls[api] for api in provider_spec.api_dependencies}
else:
deps = {}
provider_config = provider_configs[api.value]
impl = instantiate_provider(provider_spec, provider_config, deps)
sorted_specs = topological_sort(specs.values())
impls = {}
for spec in sorted_specs:
api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies}
impl = await instantiate_provider(spec, deps, provider_map[api.value])
impls[api] = impl
return impls
return impls, specs
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp:
config = yaml.safe_load(fp)
config = StackRunConfig(**yaml.safe_load(fp))
app = FastAPI()
all_endpoints = api_endpoints()
all_providers = api_providers()
provider_specs = {}
for api_str, provider_config in config["providers"].items():
api = Api(api_str)
providers = all_providers[api]
provider_id = provider_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
provider_specs[api] = providers[provider_id]
impls = resolve_impls(provider_specs, config)
impls, specs = asyncio.run(resolve_impls(config.provider_map))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
for provider_spec in provider_specs.values():
api = provider_spec.api
all_endpoints = api_endpoints()
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
for api_str in apis_to_serve:
api = Api(api_str)
endpoints = all_endpoints[api]
impl = impls[api]
provider_spec = specs[api]
if (
isinstance(provider_spec, RemoteProviderSpec)
and provider_spec.adapter is None

View file

@ -6,6 +6,7 @@
import asyncio
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

View file

@ -26,16 +26,16 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
),
remote_provider_spec(
api=Api.memory,
adapter=AdapterSpec(
Api.memory,
AdapterSpec(
adapter_id="chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
module="llama_toolchain.memory.adapters.chroma",
),
),
remote_provider_spec(
api=Api.memory,
adapter=AdapterSpec(
Api.memory,
AdapterSpec(
adapter_id="pgvector",
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
module="llama_toolchain.memory.adapters.pgvector",

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Tuple
from llama_toolchain.core.datatypes import Api
async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
    """Build and initialize the memory router over the given inner providers.

    Args:
        inner_impls: (routing_key, provider implementation) pairs to route over.
        deps: API dependencies handed down by the core resolver.

    Returns:
        An initialized `MemoryRouterImpl`.
    """
    # Deferred import keeps module load cheap until the router is requested.
    from .router import MemoryRouterImpl

    router = MemoryRouterImpl(inner_impls, deps)
    await router.initialize()
    return router

View file

@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Tuple
from llama_toolchain.core.datatypes import Api
from llama_toolchain.memory.api import * # noqa: F403
class MemoryRouterImpl(Memory):
    """Routes to a provider based on the memory bank type.

    Each inner provider is registered under a routing key that must be a
    valid `MemoryBankType` value. Banks created through this router are
    remembered so later operations on a bank reach the provider that
    created it.
    """

    def __init__(
        self,
        inner_impls: List[Tuple[str, Any]],
        deps: List[Api],
    ) -> None:
        """
        Args:
            inner_impls: (routing_key, provider implementation) pairs; each
                routing key must be a `MemoryBankType` value.
            deps: API dependencies handed down by the core resolver.

        Raises:
            ValueError: if a routing key is not a known memory bank type.
        """
        self.deps = deps

        bank_types = [v.value for v in MemoryBankType]

        self.providers = {}
        for routing_key, provider_impl in inner_impls:
            if routing_key not in bank_types:
                raise ValueError(
                    f"Unknown routing key `{routing_key}` for memory bank type"
                )
            self.providers[routing_key] = provider_impl

        # bank_id -> bank type, so later calls can be routed to the provider
        # that created the bank. NOTE(review): in-memory only — routing state
        # is lost on restart; confirm this is acceptable.
        self.bank_id_to_type = {}

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        for p in self.providers.values():
            await p.shutdown()

    def get_provider(self, bank_type):
        """Return the provider registered for `bank_type`.

        Raises:
            ValueError: if no provider is registered for the type.
        """
        if bank_type not in self.providers:
            raise ValueError(f"Memory bank type {bank_type} not supported")

        return self.providers[bank_type]

    def _provider_for_bank(self, bank_id: str):
        """Return the provider that owns `bank_id` (DRY helper for the
        per-bank operations below).

        Raises:
            ValueError: if the bank was never created through this router,
                or its type has no registered provider.
        """
        bank_type = self.bank_id_to_type.get(bank_id)
        if not bank_type:
            raise ValueError(f"Could not find bank type for {bank_id}")

        return self.get_provider(bank_type)

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        provider = self.get_provider(config.type)
        bank = await provider.create_memory_bank(name, config, url)
        self.bank_id_to_type[bank.bank_id] = config.type
        return bank

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        # NOTE(review): raises for an unknown bank_id even though the
        # signature advertises Optional — confirm intended.
        provider = self._provider_for_bank(bank_id)
        return await provider.get_memory_bank(bank_id)

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        provider = self._provider_for_bank(bank_id)
        return await provider.insert_documents(bank_id, documents, ttl_seconds)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        provider = self._provider_for_bank(bank_id)
        return await provider.query_documents(bank_id, query, params)