mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
Introduce a "Router" layer for providers
Some providers need to be factorized and considered as thin routing layers on top of other providers. Consider two examples: - The inference API should be a routing layer over inference providers, routed using the "model" key - The memory banks API is another instance where various memory bank types will be provided by independent providers (e.g., a vector store is served by Chroma while a keyvalue memory can be served by Redis or PGVector) This commit introduces a generalized routing layer for this purpose.
This commit is contained in:
parent
5c1f2616b5
commit
b6a3ef51da
12 changed files with 384 additions and 118 deletions
|
@ -46,7 +46,7 @@ class StackBuild(Subcommand):
|
||||||
|
|
||||||
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
|
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
|
||||||
from llama_toolchain.common.serialize import EnumEncoder
|
from llama_toolchain.common.serialize import EnumEncoder
|
||||||
from llama_toolchain.core.package import ApiInput, build_package, ImageType
|
from llama_toolchain.core.package import ApiInput, build_image, ImageType
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
# save build.yaml spec for building same distribution again
|
# save build.yaml spec for building same distribution again
|
||||||
|
@ -66,7 +66,7 @@ class StackBuild(Subcommand):
|
||||||
to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder))
|
to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder))
|
||||||
f.write(yaml.dump(to_write, sort_keys=False))
|
f.write(yaml.dump(to_write, sort_keys=False))
|
||||||
|
|
||||||
build_package(build_config, build_file_path)
|
build_image(build_config, build_file_path)
|
||||||
|
|
||||||
cprint(
|
cprint(
|
||||||
f"Build spec configuration saved at {str(build_file_path)}",
|
f"Build spec configuration saved at {str(build_file_path)}",
|
||||||
|
|
|
@ -105,13 +105,6 @@ class StackConfigure(Subcommand):
|
||||||
image_name = build_config.name.replace("::", "-")
|
image_name = build_config.name.replace("::", "-")
|
||||||
run_config_file = builds_dir / f"{image_name}-run.yaml"
|
run_config_file = builds_dir / f"{image_name}-run.yaml"
|
||||||
|
|
||||||
api2providers = build_config.distribution_spec.providers
|
|
||||||
|
|
||||||
stub_config = {
|
|
||||||
api_str: {"provider_id": provider}
|
|
||||||
for api_str, provider in api2providers.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
if run_config_file.exists():
|
if run_config_file.exists():
|
||||||
cprint(
|
cprint(
|
||||||
f"Configuration already exists for {build_config.name}. Will overwrite...",
|
f"Configuration already exists for {build_config.name}. Will overwrite...",
|
||||||
|
@ -123,10 +116,12 @@ class StackConfigure(Subcommand):
|
||||||
config = StackRunConfig(
|
config = StackRunConfig(
|
||||||
built_at=datetime.now(),
|
built_at=datetime.now(),
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
providers=stub_config,
|
apis_to_serve=[],
|
||||||
|
provider_map={},
|
||||||
)
|
)
|
||||||
|
|
||||||
config.providers = configure_api_providers(config.providers)
|
config = configure_api_providers(config, build_config.distribution_spec)
|
||||||
|
|
||||||
config.docker_image = (
|
config.docker_image = (
|
||||||
image_name if build_config.image_type == "docker" else None
|
image_name if build_config.image_type == "docker" else None
|
||||||
)
|
)
|
||||||
|
|
|
@ -27,6 +27,12 @@ def is_list_of_primitives(field_type):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_basemodel_without_fields(typ):
|
||||||
|
return (
|
||||||
|
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def can_recurse(typ):
|
def can_recurse(typ):
|
||||||
return (
|
return (
|
||||||
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
|
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
|
||||||
|
@ -151,6 +157,11 @@ def prompt_for_config(
|
||||||
if get_origin(field_type) is Literal:
|
if get_origin(field_type) is Literal:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Skip fields with no type annotations
|
||||||
|
if is_basemodel_without_fields(field_type):
|
||||||
|
config_data[field_name] = field_type()
|
||||||
|
continue
|
||||||
|
|
||||||
if inspect.isclass(field_type) and issubclass(field_type, Enum):
|
if inspect.isclass(field_type) and issubclass(field_type, Enum):
|
||||||
prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):"
|
prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):"
|
||||||
while True:
|
while True:
|
||||||
|
@ -254,6 +265,20 @@ def prompt_for_config(
|
||||||
print(f"{str(e)}")
|
print(f"{str(e)}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
elif get_origin(field_type) is dict:
|
||||||
|
try:
|
||||||
|
value = json.loads(user_input)
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
raise ValueError(
|
||||||
|
"Input must be a JSON-encoded dictionary"
|
||||||
|
)
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print(
|
||||||
|
"Invalid JSON. Please enter a valid JSON-encoded dict."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
# Convert the input to the correct type
|
# Convert the input to the correct type
|
||||||
elif inspect.isclass(field_type) and issubclass(
|
elif inspect.isclass(field_type) and issubclass(
|
||||||
field_type, BaseModel
|
field_type, BaseModel
|
||||||
|
|
|
@ -4,47 +4,87 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_toolchain.core.datatypes import * # noqa: F403
|
from llama_toolchain.core.datatypes import * # noqa: F403
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_toolchain.common.prompt_for_config import prompt_for_config
|
from llama_toolchain.common.prompt_for_config import prompt_for_config
|
||||||
from llama_toolchain.core.distribution import api_providers
|
from llama_toolchain.core.distribution import api_providers, stack_apis
|
||||||
from llama_toolchain.core.dynamic import instantiate_class_type
|
from llama_toolchain.core.dynamic import instantiate_class_type
|
||||||
|
|
||||||
|
|
||||||
def configure_api_providers(existing_configs: Dict[str, Any]) -> None:
|
# These are hacks so we can re-use the `prompt_for_config` utility
|
||||||
|
# This needs a bunch of work to be made very user friendly.
|
||||||
|
class ReqApis(BaseModel):
|
||||||
|
apis_to_serve: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
def make_routing_entry_type(config_class: Any):
|
||||||
|
class BaseModelWithConfig(BaseModel):
|
||||||
|
routing_key: str
|
||||||
|
config: config_class
|
||||||
|
|
||||||
|
return BaseModelWithConfig
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: make sure we can deal with existing configuration values correctly
|
||||||
|
# instead of just overwriting them
|
||||||
|
def configure_api_providers(
|
||||||
|
config: StackRunConfig, spec: DistributionSpec
|
||||||
|
) -> StackRunConfig:
|
||||||
|
cprint("Configuring APIs to serve...", "white", attrs=["bold"])
|
||||||
|
print("Enter comma-separated list of APIs to serve:")
|
||||||
|
|
||||||
|
apis = config.apis_to_serve or list(spec.providers.keys())
|
||||||
|
apis = [a for a in apis if a != "telemetry"]
|
||||||
|
req_apis = ReqApis(
|
||||||
|
apis_to_serve=apis,
|
||||||
|
)
|
||||||
|
req_apis = prompt_for_config(ReqApis, req_apis)
|
||||||
|
print("")
|
||||||
|
|
||||||
|
apis = [v.value for v in stack_apis()]
|
||||||
all_providers = api_providers()
|
all_providers = api_providers()
|
||||||
|
|
||||||
provider_configs = {}
|
apis_to_serve = req_apis.apis_to_serve + ["telemetry"]
|
||||||
for api_str, stub_config in existing_configs.items():
|
for api_str in apis_to_serve:
|
||||||
|
if api_str not in apis:
|
||||||
|
raise ValueError(f"Unknown API `{api_str}`")
|
||||||
|
|
||||||
|
cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"])
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
providers = all_providers[api]
|
if isinstance(spec.providers[api_str], list):
|
||||||
provider_id = stub_config["provider_id"]
|
print(
|
||||||
if provider_id not in providers:
|
"You have specified multiple providers for this API. We will configure a routing table now. For each provider, provide a routing key followed by provider configuration.\n"
|
||||||
raise ValueError(
|
)
|
||||||
f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
|
routing_entries = []
|
||||||
|
for p in spec.providers[api_str]:
|
||||||
|
print(f"Configuring provider `{p}`...")
|
||||||
|
provider_spec = all_providers[api][p]
|
||||||
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
|
|
||||||
|
wrapper_type = make_routing_entry_type(config_type)
|
||||||
|
rt_entry = prompt_for_config(wrapper_type, None)
|
||||||
|
|
||||||
|
# TODO: we need to validate the routing keys
|
||||||
|
routing_entries.append(
|
||||||
|
ProviderRoutingEntry(
|
||||||
|
provider_id=p,
|
||||||
|
routing_key=rt_entry.routing_key,
|
||||||
|
config=rt_entry.config.dict(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
config.provider_map[api_str] = routing_entries
|
||||||
|
else:
|
||||||
|
provider_spec = all_providers[api][spec.providers[api_str]]
|
||||||
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
|
cfg = prompt_for_config(config_type, None)
|
||||||
|
config.provider_map[api_str] = GenericProviderConfig(
|
||||||
|
provider_id=spec.providers[api_str],
|
||||||
|
config=cfg.dict(),
|
||||||
)
|
)
|
||||||
|
|
||||||
provider_spec = providers[provider_id]
|
return config
|
||||||
cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"])
|
|
||||||
config_type = instantiate_class_type(provider_spec.config_class)
|
|
||||||
|
|
||||||
try:
|
|
||||||
existing_provider_config = config_type(**stub_config)
|
|
||||||
except Exception:
|
|
||||||
existing_provider_config = None
|
|
||||||
|
|
||||||
provider_config = prompt_for_config(
|
|
||||||
config_type,
|
|
||||||
existing_provider_config,
|
|
||||||
)
|
|
||||||
print("")
|
|
||||||
|
|
||||||
provider_configs[api_str] = {
|
|
||||||
"provider_id": provider_id,
|
|
||||||
**provider_config.dict(),
|
|
||||||
}
|
|
||||||
|
|
||||||
return provider_configs
|
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
@ -43,6 +43,33 @@ class ProviderSpec(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class RouterProviderSpec(ProviderSpec):
|
||||||
|
provider_id: str = "router"
|
||||||
|
config_class: str = ""
|
||||||
|
|
||||||
|
docker_image: Optional[str] = None
|
||||||
|
|
||||||
|
inner_specs: List[ProviderSpec]
|
||||||
|
module: str = Field(
|
||||||
|
...,
|
||||||
|
description="""
|
||||||
|
Fully-qualified name of the module to import. The module is expected to have:
|
||||||
|
|
||||||
|
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pip_packages(self) -> List[str]:
|
||||||
|
raise AssertionError("Should not be called on RouterProviderSpec")
|
||||||
|
|
||||||
|
|
||||||
|
class GenericProviderConfig(BaseModel):
|
||||||
|
provider_id: str
|
||||||
|
config: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AdapterSpec(BaseModel):
|
class AdapterSpec(BaseModel):
|
||||||
adapter_id: str = Field(
|
adapter_id: str = Field(
|
||||||
|
@ -156,12 +183,23 @@ class DistributionSpec(BaseModel):
|
||||||
description="Description of the distribution",
|
description="Description of the distribution",
|
||||||
)
|
)
|
||||||
docker_image: Optional[str] = None
|
docker_image: Optional[str] = None
|
||||||
providers: Dict[str, str] = Field(
|
providers: Dict[str, Union[str, List[str]]] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="Provider Types for each of the APIs provided by this distribution",
|
description="""
|
||||||
|
Provider Types for each of the APIs provided by this distribution. If you
|
||||||
|
select multiple providers, you should provide an appropriate 'routing_map'
|
||||||
|
in the runtime configuration to help route to the correct provider.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ProviderRoutingEntry(GenericProviderConfig):
|
||||||
|
routing_key: str
|
||||||
|
|
||||||
|
|
||||||
|
ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class StackRunConfig(BaseModel):
|
class StackRunConfig(BaseModel):
|
||||||
built_at: datetime
|
built_at: datetime
|
||||||
|
@ -181,12 +219,22 @@ this could be just a hash
|
||||||
default=None,
|
default=None,
|
||||||
description="Reference to the conda environment if this package refers to a conda environment",
|
description="Reference to the conda environment if this package refers to a conda environment",
|
||||||
)
|
)
|
||||||
providers: Dict[str, Any] = Field(
|
apis_to_serve: List[str] = Field(
|
||||||
default_factory=dict,
|
|
||||||
description="""
|
description="""
|
||||||
Provider configurations for each of the APIs provided by this package. This includes configurations for
|
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
|
||||||
the dependencies of these providers as well.
|
)
|
||||||
""",
|
provider_map: Dict[str, ProviderMapEntry] = Field(
|
||||||
|
description="""
|
||||||
|
Provider configurations for each of the APIs provided by this package.
|
||||||
|
|
||||||
|
Given an API, you can specify a single provider or a "routing table". Each entry in the routing
|
||||||
|
table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific.
|
||||||
|
|
||||||
|
As examples:
|
||||||
|
- the "inference" API interprets the routing_key as a "model"
|
||||||
|
- the "memory" API interprets the routing_key as the type of a "memory bank"
|
||||||
|
|
||||||
|
The key may support wild-cards alsothe routing_key to route to the correct provider.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,11 +4,10 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import importlib
|
import importlib
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from .datatypes import ProviderSpec, RemoteProviderSpec
|
from llama_toolchain.core.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
def instantiate_class_type(fully_qualified_name):
|
def instantiate_class_type(fully_qualified_name):
|
||||||
|
@ -18,25 +17,50 @@ def instantiate_class_type(fully_qualified_name):
|
||||||
|
|
||||||
|
|
||||||
# returns a class implementing the protocol corresponding to the Api
|
# returns a class implementing the protocol corresponding to the Api
|
||||||
def instantiate_provider(
|
async def instantiate_provider(
|
||||||
provider_spec: ProviderSpec,
|
provider_spec: ProviderSpec,
|
||||||
provider_config: Dict[str, Any],
|
deps: Dict[str, Any],
|
||||||
deps: Dict[str, ProviderSpec],
|
provider_config: ProviderMapEntry,
|
||||||
):
|
):
|
||||||
module = importlib.import_module(provider_spec.module)
|
module = importlib.import_module(provider_spec.module)
|
||||||
|
|
||||||
config_type = instantiate_class_type(provider_spec.config_class)
|
args = []
|
||||||
if isinstance(provider_spec, RemoteProviderSpec):
|
if isinstance(provider_spec, RemoteProviderSpec):
|
||||||
if provider_spec.adapter:
|
if provider_spec.adapter:
|
||||||
method = "get_adapter_impl"
|
method = "get_adapter_impl"
|
||||||
else:
|
else:
|
||||||
method = "get_client_impl"
|
method = "get_client_impl"
|
||||||
|
|
||||||
|
assert isinstance(provider_config, GenericProviderConfig)
|
||||||
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
|
config = config_type(**provider_config.config)
|
||||||
|
args = [config, deps]
|
||||||
|
elif isinstance(provider_spec, RouterProviderSpec):
|
||||||
|
method = "get_router_impl"
|
||||||
|
|
||||||
|
assert isinstance(provider_config, list)
|
||||||
|
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
|
||||||
|
inner_impls = []
|
||||||
|
for routing_entry in provider_config:
|
||||||
|
impl = await instantiate_provider(
|
||||||
|
inner_specs[routing_entry.provider_id],
|
||||||
|
deps,
|
||||||
|
routing_entry,
|
||||||
|
)
|
||||||
|
inner_impls.append((routing_entry.routing_key, impl))
|
||||||
|
|
||||||
|
config = None
|
||||||
|
args = [inner_impls, deps]
|
||||||
else:
|
else:
|
||||||
method = "get_provider_impl"
|
method = "get_provider_impl"
|
||||||
|
|
||||||
config = config_type(**provider_config)
|
assert isinstance(provider_config, GenericProviderConfig)
|
||||||
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
|
config = config_type(**provider_config.config)
|
||||||
|
args = [config, deps]
|
||||||
|
|
||||||
fn = getattr(module, method)
|
fn = getattr(module, method)
|
||||||
impl = asyncio.run(fn(config, deps))
|
impl = await fn(*args)
|
||||||
impl.__provider_spec__ = provider_spec
|
impl.__provider_spec__ = provider_spec
|
||||||
impl.__provider_config__ = config
|
impl.__provider_config__ = config
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -4,22 +4,16 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from datetime import datetime
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
import yaml
|
|
||||||
|
|
||||||
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
|
|
||||||
from llama_toolchain.common.exec import run_with_pty
|
|
||||||
from llama_toolchain.common.serialize import EnumEncoder
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
|
from llama_toolchain.common.exec import run_with_pty
|
||||||
|
|
||||||
from llama_toolchain.core.datatypes import * # noqa: F403
|
from llama_toolchain.core.datatypes import * # noqa: F403
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
@ -41,7 +35,7 @@ class ApiInput(BaseModel):
|
||||||
provider: str
|
provider: str
|
||||||
|
|
||||||
|
|
||||||
def build_package(build_config: BuildConfig, build_file_path: Path):
|
def build_image(build_config: BuildConfig, build_file_path: Path):
|
||||||
package_deps = Dependencies(
|
package_deps = Dependencies(
|
||||||
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
|
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
|
||||||
pip_packages=SERVER_DEPENDENCIES,
|
pip_packages=SERVER_DEPENDENCIES,
|
||||||
|
@ -49,17 +43,28 @@ def build_package(build_config: BuildConfig, build_file_path: Path):
|
||||||
|
|
||||||
# extend package dependencies based on providers spec
|
# extend package dependencies based on providers spec
|
||||||
all_providers = api_providers()
|
all_providers = api_providers()
|
||||||
for api_str, provider in build_config.distribution_spec.providers.items():
|
for (
|
||||||
|
api_str,
|
||||||
|
provider_or_providers,
|
||||||
|
) in build_config.distribution_spec.providers.items():
|
||||||
providers_for_api = all_providers[Api(api_str)]
|
providers_for_api = all_providers[Api(api_str)]
|
||||||
if provider not in providers_for_api:
|
|
||||||
raise ValueError(
|
|
||||||
f"Provider `{provider}` is not available for API `{api_str}`"
|
|
||||||
)
|
|
||||||
|
|
||||||
provider_spec = providers_for_api[provider]
|
providers = (
|
||||||
package_deps.pip_packages.extend(provider_spec.pip_packages)
|
provider_or_providers
|
||||||
if provider_spec.docker_image:
|
if isinstance(provider_or_providers, list)
|
||||||
raise ValueError("A stack's dependencies cannot have a docker image")
|
else [provider_or_providers]
|
||||||
|
)
|
||||||
|
|
||||||
|
for provider in providers:
|
||||||
|
if provider not in providers_for_api:
|
||||||
|
raise ValueError(
|
||||||
|
f"Provider `{provider}` is not available for API `{api_str}`"
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_spec = providers_for_api[provider]
|
||||||
|
package_deps.pip_packages.extend(provider_spec.pip_packages)
|
||||||
|
if provider_spec.docker_image:
|
||||||
|
raise ValueError("A stack's dependencies cannot have a docker image")
|
||||||
|
|
||||||
if build_config.image_type == ImageType.docker.value:
|
if build_config.image_type == ImageType.docker.value:
|
||||||
script = pkg_resources.resource_filename(
|
script = pkg_resources.resource_filename(
|
||||||
|
|
|
@ -9,6 +9,7 @@ import inspect
|
||||||
import json
|
import json
|
||||||
import signal
|
import signal
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from collections.abc import (
|
from collections.abc import (
|
||||||
AsyncGenerator as AsyncGeneratorABC,
|
AsyncGenerator as AsyncGeneratorABC,
|
||||||
AsyncIterator as AsyncIteratorABC,
|
AsyncIterator as AsyncIteratorABC,
|
||||||
|
@ -44,8 +45,8 @@ from llama_toolchain.telemetry.tracing import (
|
||||||
SpanStatus,
|
SpanStatus,
|
||||||
start_trace,
|
start_trace,
|
||||||
)
|
)
|
||||||
|
from llama_toolchain.core.datatypes import * # noqa: F403
|
||||||
|
|
||||||
from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
|
||||||
from .distribution import api_endpoints, api_providers
|
from .distribution import api_endpoints, api_providers
|
||||||
from .dynamic import instantiate_provider
|
from .dynamic import instantiate_provider
|
||||||
|
|
||||||
|
@ -271,61 +272,80 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
||||||
return [by_id[x] for x in stack]
|
return [by_id[x] for x in stack]
|
||||||
|
|
||||||
|
|
||||||
def resolve_impls(
|
def snake_to_camel(snake_str):
|
||||||
provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
|
return "".join(word.capitalize() for word in snake_str.split("_"))
|
||||||
) -> Dict[Api, Any]:
|
|
||||||
provider_configs = config["providers"]
|
|
||||||
provider_specs = topological_sort(provider_specs.values())
|
|
||||||
|
|
||||||
impls = {}
|
|
||||||
for provider_spec in provider_specs:
|
async def resolve_impls(
|
||||||
api = provider_spec.api
|
provider_map: Dict[str, ProviderMapEntry],
|
||||||
if api.value not in provider_configs:
|
) -> Dict[Api, Any]:
|
||||||
raise ValueError(
|
"""
|
||||||
f"Could not find provider_spec config for {api}. Please add it to the config"
|
Does two things:
|
||||||
|
- flatmaps, sorts and resolves the providers in dependency order
|
||||||
|
- for each API, produces either a (local, passthrough or router) implementation
|
||||||
|
"""
|
||||||
|
all_providers = api_providers()
|
||||||
|
|
||||||
|
specs = {}
|
||||||
|
for api_str, item in provider_map.items():
|
||||||
|
api = Api(api_str)
|
||||||
|
providers = all_providers[api]
|
||||||
|
|
||||||
|
if isinstance(item, GenericProviderConfig):
|
||||||
|
if item.provider_id not in providers:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||||
|
)
|
||||||
|
specs[api] = providers[item.provider_id]
|
||||||
|
else:
|
||||||
|
assert isinstance(item, list)
|
||||||
|
inner_specs = []
|
||||||
|
for rt_entry in item:
|
||||||
|
if rt_entry.provider_id not in providers:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
||||||
|
)
|
||||||
|
inner_specs.append(providers[rt_entry.provider_id])
|
||||||
|
|
||||||
|
specs[api] = RouterProviderSpec(
|
||||||
|
api=api,
|
||||||
|
module=f"llama_toolchain.{api.value.lower()}.router",
|
||||||
|
api_dependencies=[],
|
||||||
|
inner_specs=inner_specs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(provider_spec, InlineProviderSpec):
|
sorted_specs = topological_sort(specs.values())
|
||||||
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
|
||||||
else:
|
impls = {}
|
||||||
deps = {}
|
for spec in sorted_specs:
|
||||||
provider_config = provider_configs[api.value]
|
api = spec.api
|
||||||
impl = instantiate_provider(provider_spec, provider_config, deps)
|
|
||||||
|
deps = {api: impls[api] for api in spec.api_dependencies}
|
||||||
|
impl = await instantiate_provider(spec, deps, provider_map[api.value])
|
||||||
impls[api] = impl
|
impls[api] = impl
|
||||||
|
|
||||||
return impls
|
return impls, specs
|
||||||
|
|
||||||
|
|
||||||
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||||
with open(yaml_config, "r") as fp:
|
with open(yaml_config, "r") as fp:
|
||||||
config = yaml.safe_load(fp)
|
config = StackRunConfig(**yaml.safe_load(fp))
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
all_endpoints = api_endpoints()
|
impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
||||||
all_providers = api_providers()
|
|
||||||
|
|
||||||
provider_specs = {}
|
|
||||||
for api_str, provider_config in config["providers"].items():
|
|
||||||
api = Api(api_str)
|
|
||||||
providers = all_providers[api]
|
|
||||||
provider_id = provider_config["provider_id"]
|
|
||||||
if provider_id not in providers:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
|
||||||
)
|
|
||||||
|
|
||||||
provider_specs[api] = providers[provider_id]
|
|
||||||
|
|
||||||
impls = resolve_impls(provider_specs, config)
|
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
setup_logger(impls[Api.telemetry])
|
setup_logger(impls[Api.telemetry])
|
||||||
|
|
||||||
for provider_spec in provider_specs.values():
|
all_endpoints = api_endpoints()
|
||||||
api = provider_spec.api
|
|
||||||
|
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
|
||||||
|
for api_str in apis_to_serve:
|
||||||
|
api = Api(api_str)
|
||||||
endpoints = all_endpoints[api]
|
endpoints = all_endpoints[api]
|
||||||
impl = impls[api]
|
impl = impls[api]
|
||||||
|
|
||||||
|
provider_spec = specs[api]
|
||||||
if (
|
if (
|
||||||
isinstance(provider_spec, RemoteProviderSpec)
|
isinstance(provider_spec, RemoteProviderSpec)
|
||||||
and provider_spec.adapter is None
|
and provider_spec.adapter is None
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
|
@ -26,16 +26,16 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
|
config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.memory,
|
Api.memory,
|
||||||
adapter=AdapterSpec(
|
AdapterSpec(
|
||||||
adapter_id="chromadb",
|
adapter_id="chromadb",
|
||||||
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
|
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
|
||||||
module="llama_toolchain.memory.adapters.chroma",
|
module="llama_toolchain.memory.adapters.chroma",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.memory,
|
Api.memory,
|
||||||
adapter=AdapterSpec(
|
AdapterSpec(
|
||||||
adapter_id="pgvector",
|
adapter_id="pgvector",
|
||||||
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
|
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
|
||||||
module="llama_toolchain.memory.adapters.pgvector",
|
module="llama_toolchain.memory.adapters.pgvector",
|
||||||
|
|
17
llama_toolchain/memory/router/__init__.py
Normal file
17
llama_toolchain/memory/router/__init__.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
|
from llama_toolchain.core.datatypes import Api
|
||||||
|
|
||||||
|
|
||||||
|
async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
|
||||||
|
from .router import MemoryRouterImpl
|
||||||
|
|
||||||
|
impl = MemoryRouterImpl(inner_impls, deps)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
91
llama_toolchain/memory/router/router.py
Normal file
91
llama_toolchain/memory/router/router.py
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
from llama_toolchain.core.datatypes import Api
|
||||||
|
from llama_toolchain.memory.api import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRouterImpl(Memory):
|
||||||
|
"""Routes to an provider based on the memory bank type"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inner_impls: List[Tuple[str, Any]],
|
||||||
|
deps: List[Api],
|
||||||
|
) -> None:
|
||||||
|
self.deps = deps
|
||||||
|
|
||||||
|
bank_types = [v.value for v in MemoryBankType]
|
||||||
|
|
||||||
|
self.providers = {}
|
||||||
|
for routing_key, provider_impl in inner_impls:
|
||||||
|
if routing_key not in bank_types:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown routing key `{routing_key}` for memory bank type"
|
||||||
|
)
|
||||||
|
self.providers[routing_key] = provider_impl
|
||||||
|
|
||||||
|
self.bank_id_to_type = {}
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
for p in self.providers.values():
|
||||||
|
await p.shutdown()
|
||||||
|
|
||||||
|
def get_provider(self, bank_type):
|
||||||
|
if bank_type not in self.providers:
|
||||||
|
raise ValueError(f"Memory bank type {bank_type} not supported")
|
||||||
|
|
||||||
|
return self.providers[bank_type]
|
||||||
|
|
||||||
|
async def create_memory_bank(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
config: MemoryBankConfig,
|
||||||
|
url: Optional[URL] = None,
|
||||||
|
) -> MemoryBank:
|
||||||
|
provider = self.get_provider(config.type)
|
||||||
|
bank = await provider.create_memory_bank(name, config, url)
|
||||||
|
self.bank_id_to_type[bank.bank_id] = config.type
|
||||||
|
return bank
|
||||||
|
|
||||||
|
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||||
|
bank_type = self.bank_id_to_type.get(bank_id)
|
||||||
|
if not bank_type:
|
||||||
|
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||||
|
|
||||||
|
provider = self.get_provider(bank_type)
|
||||||
|
return await provider.get_memory_bank(bank_id)
|
||||||
|
|
||||||
|
async def insert_documents(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
documents: List[MemoryBankDocument],
|
||||||
|
ttl_seconds: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
bank_type = self.bank_id_to_type.get(bank_id)
|
||||||
|
if not bank_type:
|
||||||
|
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||||
|
|
||||||
|
provider = self.get_provider(bank_type)
|
||||||
|
return await provider.insert_documents(bank_id, documents, ttl_seconds)
|
||||||
|
|
||||||
|
async def query_documents(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
query: InterleavedTextMedia,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
|
bank_type = self.bank_id_to_type.get(bank_id)
|
||||||
|
if not bank_type:
|
||||||
|
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||||
|
|
||||||
|
provider = self.get_provider(bank_type)
|
||||||
|
return await provider.query_documents(bank_id, query, params)
|
Loading…
Add table
Add a link
Reference in a new issue