mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Introduce a "Router" layer for providers
Some providers need to be factored out and treated as thin routing layers on top of other providers. Consider two examples: (1) the inference API should be a routing layer over inference providers, routed using the "model" key; (2) the memory banks API is another instance, where the various memory bank types are served by independent providers (e.g., a vector store is served by Chroma, while a key-value memory can be served by Redis or PGVector). This commit introduces a generalized routing layer for this purpose.
This commit is contained in:
parent
5c1f2616b5
commit
b6a3ef51da
12 changed files with 384 additions and 118 deletions
|
@ -46,7 +46,7 @@ class StackBuild(Subcommand):
|
|||
|
||||
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_toolchain.common.serialize import EnumEncoder
|
||||
from llama_toolchain.core.package import ApiInput, build_package, ImageType
|
||||
from llama_toolchain.core.package import ApiInput, build_image, ImageType
|
||||
from termcolor import cprint
|
||||
|
||||
# save build.yaml spec for building same distribution again
|
||||
|
@ -66,7 +66,7 @@ class StackBuild(Subcommand):
|
|||
to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder))
|
||||
f.write(yaml.dump(to_write, sort_keys=False))
|
||||
|
||||
build_package(build_config, build_file_path)
|
||||
build_image(build_config, build_file_path)
|
||||
|
||||
cprint(
|
||||
f"Build spec configuration saved at {str(build_file_path)}",
|
||||
|
|
|
@ -105,13 +105,6 @@ class StackConfigure(Subcommand):
|
|||
image_name = build_config.name.replace("::", "-")
|
||||
run_config_file = builds_dir / f"{image_name}-run.yaml"
|
||||
|
||||
api2providers = build_config.distribution_spec.providers
|
||||
|
||||
stub_config = {
|
||||
api_str: {"provider_id": provider}
|
||||
for api_str, provider in api2providers.items()
|
||||
}
|
||||
|
||||
if run_config_file.exists():
|
||||
cprint(
|
||||
f"Configuration already exists for {build_config.name}. Will overwrite...",
|
||||
|
@ -123,10 +116,12 @@ class StackConfigure(Subcommand):
|
|||
config = StackRunConfig(
|
||||
built_at=datetime.now(),
|
||||
image_name=image_name,
|
||||
providers=stub_config,
|
||||
apis_to_serve=[],
|
||||
provider_map={},
|
||||
)
|
||||
|
||||
config.providers = configure_api_providers(config.providers)
|
||||
config = configure_api_providers(config, build_config.distribution_spec)
|
||||
|
||||
config.docker_image = (
|
||||
image_name if build_config.image_type == "docker" else None
|
||||
)
|
||||
|
|
|
@ -27,6 +27,12 @@ def is_list_of_primitives(field_type):
|
|||
return False
|
||||
|
||||
|
||||
def is_basemodel_without_fields(typ):
    """Return True when ``typ`` is a ``BaseModel`` subclass declaring no fields.

    Anything that is not a class, or is a class outside the ``BaseModel``
    hierarchy, yields False.
    """
    if not inspect.isclass(typ):
        return False
    if not issubclass(typ, BaseModel):
        return False
    return len(typ.__fields__) == 0
|
||||
|
||||
|
||||
def can_recurse(typ):
|
||||
return (
|
||||
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
|
||||
|
@ -151,6 +157,11 @@ def prompt_for_config(
|
|||
if get_origin(field_type) is Literal:
|
||||
continue
|
||||
|
||||
# Skip fields with no type annotations
|
||||
if is_basemodel_without_fields(field_type):
|
||||
config_data[field_name] = field_type()
|
||||
continue
|
||||
|
||||
if inspect.isclass(field_type) and issubclass(field_type, Enum):
|
||||
prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):"
|
||||
while True:
|
||||
|
@ -254,6 +265,20 @@ def prompt_for_config(
|
|||
print(f"{str(e)}")
|
||||
continue
|
||||
|
||||
elif get_origin(field_type) is dict:
|
||||
try:
|
||||
value = json.loads(user_input)
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(
|
||||
"Input must be a JSON-encoded dictionary"
|
||||
)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(
|
||||
"Invalid JSON. Please enter a valid JSON-encoded dict."
|
||||
)
|
||||
continue
|
||||
|
||||
# Convert the input to the correct type
|
||||
elif inspect.isclass(field_type) and issubclass(
|
||||
field_type, BaseModel
|
||||
|
|
|
@ -4,47 +4,87 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_toolchain.core.datatypes import * # noqa: F403
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.common.prompt_for_config import prompt_for_config
|
||||
from llama_toolchain.core.distribution import api_providers
|
||||
from llama_toolchain.core.distribution import api_providers, stack_apis
|
||||
from llama_toolchain.core.dynamic import instantiate_class_type
|
||||
|
||||
|
||||
def configure_api_providers(existing_configs: Dict[str, Any]) -> None:
|
||||
# These are hacks so we can re-use the `prompt_for_config` utility
|
||||
# This needs a bunch of work to be made very user friendly.
|
||||
class ReqApis(BaseModel):
    """Single-field model holding the list of API names to serve."""

    # Names of the APIs to serve.
    apis_to_serve: List[str]
|
||||
|
||||
|
||||
def make_routing_entry_type(config_class: Any):
    """Build a model type pairing a ``routing_key`` with a ``config_class`` config.

    A fresh class is defined on every call so that its ``config`` field is
    annotated with the class passed in.
    """

    class BaseModelWithConfig(BaseModel):
        # Key used to route to the provider this entry configures.
        routing_key: str
        # Provider configuration, typed with the caller-supplied class.
        config: config_class

    entry_type = BaseModelWithConfig
    return entry_type
|
||||
|
||||
|
||||
# TODO: make sure we can deal with existing configuration values correctly
|
||||
# instead of just overwriting them
|
||||
def configure_api_providers(
|
||||
config: StackRunConfig, spec: DistributionSpec
|
||||
) -> StackRunConfig:
|
||||
cprint("Configuring APIs to serve...", "white", attrs=["bold"])
|
||||
print("Enter comma-separated list of APIs to serve:")
|
||||
|
||||
apis = config.apis_to_serve or list(spec.providers.keys())
|
||||
apis = [a for a in apis if a != "telemetry"]
|
||||
req_apis = ReqApis(
|
||||
apis_to_serve=apis,
|
||||
)
|
||||
req_apis = prompt_for_config(ReqApis, req_apis)
|
||||
print("")
|
||||
|
||||
apis = [v.value for v in stack_apis()]
|
||||
all_providers = api_providers()
|
||||
|
||||
provider_configs = {}
|
||||
for api_str, stub_config in existing_configs.items():
|
||||
apis_to_serve = req_apis.apis_to_serve + ["telemetry"]
|
||||
for api_str in apis_to_serve:
|
||||
if api_str not in apis:
|
||||
raise ValueError(f"Unknown API `{api_str}`")
|
||||
|
||||
cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"])
|
||||
api = Api(api_str)
|
||||
providers = all_providers[api]
|
||||
provider_id = stub_config["provider_id"]
|
||||
if provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
|
||||
if isinstance(spec.providers[api_str], list):
|
||||
print(
|
||||
"You have specified multiple providers for this API. We will configure a routing table now. For each provider, provide a routing key followed by provider configuration.\n"
|
||||
)
|
||||
routing_entries = []
|
||||
for p in spec.providers[api_str]:
|
||||
print(f"Configuring provider `{p}`...")
|
||||
provider_spec = all_providers[api][p]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
|
||||
wrapper_type = make_routing_entry_type(config_type)
|
||||
rt_entry = prompt_for_config(wrapper_type, None)
|
||||
|
||||
# TODO: we need to validate the routing keys
|
||||
routing_entries.append(
|
||||
ProviderRoutingEntry(
|
||||
provider_id=p,
|
||||
routing_key=rt_entry.routing_key,
|
||||
config=rt_entry.config.dict(),
|
||||
)
|
||||
)
|
||||
config.provider_map[api_str] = routing_entries
|
||||
else:
|
||||
provider_spec = all_providers[api][spec.providers[api_str]]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
cfg = prompt_for_config(config_type, None)
|
||||
config.provider_map[api_str] = GenericProviderConfig(
|
||||
provider_id=spec.providers[api_str],
|
||||
config=cfg.dict(),
|
||||
)
|
||||
|
||||
provider_spec = providers[provider_id]
|
||||
cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"])
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
|
||||
try:
|
||||
existing_provider_config = config_type(**stub_config)
|
||||
except Exception:
|
||||
existing_provider_config = None
|
||||
|
||||
provider_config = prompt_for_config(
|
||||
config_type,
|
||||
existing_provider_config,
|
||||
)
|
||||
print("")
|
||||
|
||||
provider_configs[api_str] = {
|
||||
"provider_id": provider_id,
|
||||
**provider_config.dict(),
|
||||
}
|
||||
|
||||
return provider_configs
|
||||
return config
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
|
@ -43,6 +43,33 @@ class ProviderSpec(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
class RouterProviderSpec(ProviderSpec):
    """Provider spec for a routing layer that sits on top of other providers."""

    provider_id: str = "router"
    # A router spec carries no configuration class of its own.
    config_class: str = ""

    docker_image: Optional[str] = None

    # Specs of the providers wrapped by this router.
    inner_specs: List[ProviderSpec]
    # The description below names the arguments the router factory is
    # expected to accept: the (routing_key, impl) pairs and the dependencies.
    module: str = Field(
        ...,
        description="""
Fully-qualified name of the module to import. The module is expected to have:

- `get_router_impl(inner_impls, deps)`: returns the router implementation
""",
    )

    @property
    def pip_packages(self) -> List[str]:
        # Deliberately unsupported on a router spec; always raises.
        raise AssertionError("Should not be called on RouterProviderSpec")
|
||||
|
||||
|
||||
class GenericProviderConfig(BaseModel):
    """A provider selection together with its untyped configuration values."""

    # Identifier of the provider this configuration belongs to.
    provider_id: str
    # Raw configuration values, kept as a plain dictionary.
    config: Dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
adapter_id: str = Field(
|
||||
|
@ -156,12 +183,23 @@ class DistributionSpec(BaseModel):
|
|||
description="Description of the distribution",
|
||||
)
|
||||
docker_image: Optional[str] = None
|
||||
providers: Dict[str, str] = Field(
|
||||
providers: Dict[str, Union[str, List[str]]] = Field(
|
||||
default_factory=dict,
|
||||
description="Provider Types for each of the APIs provided by this distribution",
|
||||
description="""
|
||||
Provider Types for each of the APIs provided by this distribution. If you
|
||||
select multiple providers, you should provide an appropriate 'routing_map'
|
||||
in the runtime configuration to help route to the correct provider.""",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
class ProviderRoutingEntry(GenericProviderConfig):
    """A provider configuration extended with the key that routes to it."""

    # Key that selects this provider within a routing table.
    routing_key: str
|
||||
|
||||
|
||||
ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class StackRunConfig(BaseModel):
|
||||
built_at: datetime
|
||||
|
@ -181,12 +219,22 @@ this could be just a hash
|
|||
default=None,
|
||||
description="Reference to the conda environment if this package refers to a conda environment",
|
||||
)
|
||||
providers: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
apis_to_serve: List[str] = Field(
|
||||
description="""
|
||||
Provider configurations for each of the APIs provided by this package. This includes configurations for
|
||||
the dependencies of these providers as well.
|
||||
""",
|
||||
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
|
||||
)
|
||||
provider_map: Dict[str, ProviderMapEntry] = Field(
|
||||
description="""
|
||||
Provider configurations for each of the APIs provided by this package.
|
||||
|
||||
Given an API, you can specify a single provider or a "routing table". Each entry in the routing
|
||||
table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific.
|
||||
|
||||
As examples:
|
||||
- the "inference" API interprets the routing_key as a "model"
|
||||
- the "memory" API interprets the routing_key as the type of a "memory bank"
|
||||
|
||||
The key may support wild-cards alsothe routing_key to route to the correct provider.""",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -4,11 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
from typing import Any, Dict
|
||||
|
||||
from .datatypes import ProviderSpec, RemoteProviderSpec
|
||||
from llama_toolchain.core.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def instantiate_class_type(fully_qualified_name):
|
||||
|
@ -18,25 +17,50 @@ def instantiate_class_type(fully_qualified_name):
|
|||
|
||||
|
||||
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
    provider_spec: ProviderSpec,
    deps: Dict[str, Any],
    provider_config: ProviderMapEntry,
):
    """Import the provider's module and build its implementation.

    The factory looked up on the module depends on the kind of spec:
    ``get_adapter_impl`` / ``get_client_impl`` for remote specs,
    ``get_router_impl`` for router specs (after recursively instantiating
    every routed provider), and ``get_provider_impl`` otherwise.

    The returned implementation is tagged with ``__provider_spec__`` and
    ``__provider_config__`` (``None`` for routers).
    """
    module = importlib.import_module(provider_spec.module)

    def _typed_config():
        # Single-provider case: expand the raw dict into the typed config.
        assert isinstance(provider_config, GenericProviderConfig)
        config_cls = instantiate_class_type(provider_spec.config_class)
        return config_cls(**provider_config.config)

    if isinstance(provider_spec, RemoteProviderSpec):
        entry_point = (
            "get_adapter_impl" if provider_spec.adapter else "get_client_impl"
        )
        config = _typed_config()
        call_args = [config, deps]
    elif isinstance(provider_spec, RouterProviderSpec):
        entry_point = "get_router_impl"

        assert isinstance(provider_config, list)
        specs_by_id = {s.provider_id: s for s in provider_spec.inner_specs}
        routed = []
        for entry in provider_config:
            # Each routing entry is itself a provider config; build it first.
            inner = await instantiate_provider(
                specs_by_id[entry.provider_id],
                deps,
                entry,
            )
            routed.append((entry.routing_key, inner))

        config = None
        call_args = [routed, deps]
    else:
        entry_point = "get_provider_impl"
        config = _typed_config()
        call_args = [config, deps]

    factory = getattr(module, entry_point)
    instance = await factory(*call_args)
    instance.__provider_spec__ = provider_spec
    instance.__provider_config__ = config
    return instance
|
||||
|
|
|
@ -4,22 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
import pkg_resources
|
||||
import yaml
|
||||
|
||||
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
|
||||
from llama_toolchain.common.exec import run_with_pty
|
||||
from llama_toolchain.common.serialize import EnumEncoder
|
||||
from pydantic import BaseModel
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.common.exec import run_with_pty
|
||||
|
||||
from llama_toolchain.core.datatypes import * # noqa: F403
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -41,7 +35,7 @@ class ApiInput(BaseModel):
|
|||
provider: str
|
||||
|
||||
|
||||
def build_package(build_config: BuildConfig, build_file_path: Path):
|
||||
def build_image(build_config: BuildConfig, build_file_path: Path):
|
||||
package_deps = Dependencies(
|
||||
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
|
||||
pip_packages=SERVER_DEPENDENCIES,
|
||||
|
@ -49,17 +43,28 @@ def build_package(build_config: BuildConfig, build_file_path: Path):
|
|||
|
||||
# extend package dependencies based on providers spec
|
||||
all_providers = api_providers()
|
||||
for api_str, provider in build_config.distribution_spec.providers.items():
|
||||
for (
|
||||
api_str,
|
||||
provider_or_providers,
|
||||
) in build_config.distribution_spec.providers.items():
|
||||
providers_for_api = all_providers[Api(api_str)]
|
||||
if provider not in providers_for_api:
|
||||
raise ValueError(
|
||||
f"Provider `{provider}` is not available for API `{api_str}`"
|
||||
)
|
||||
|
||||
provider_spec = providers_for_api[provider]
|
||||
package_deps.pip_packages.extend(provider_spec.pip_packages)
|
||||
if provider_spec.docker_image:
|
||||
raise ValueError("A stack's dependencies cannot have a docker image")
|
||||
providers = (
|
||||
provider_or_providers
|
||||
if isinstance(provider_or_providers, list)
|
||||
else [provider_or_providers]
|
||||
)
|
||||
|
||||
for provider in providers:
|
||||
if provider not in providers_for_api:
|
||||
raise ValueError(
|
||||
f"Provider `{provider}` is not available for API `{api_str}`"
|
||||
)
|
||||
|
||||
provider_spec = providers_for_api[provider]
|
||||
package_deps.pip_packages.extend(provider_spec.pip_packages)
|
||||
if provider_spec.docker_image:
|
||||
raise ValueError("A stack's dependencies cannot have a docker image")
|
||||
|
||||
if build_config.image_type == ImageType.docker.value:
|
||||
script = pkg_resources.resource_filename(
|
||||
|
|
|
@ -9,6 +9,7 @@ import inspect
|
|||
import json
|
||||
import signal
|
||||
import traceback
|
||||
|
||||
from collections.abc import (
|
||||
AsyncGenerator as AsyncGeneratorABC,
|
||||
AsyncIterator as AsyncIteratorABC,
|
||||
|
@ -44,8 +45,8 @@ from llama_toolchain.telemetry.tracing import (
|
|||
SpanStatus,
|
||||
start_trace,
|
||||
)
|
||||
from llama_toolchain.core.datatypes import * # noqa: F403
|
||||
|
||||
from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
from .distribution import api_endpoints, api_providers
|
||||
from .dynamic import instantiate_provider
|
||||
|
||||
|
@ -271,61 +272,80 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
|||
return [by_id[x] for x in stack]
|
||||
|
||||
|
||||
def resolve_impls(
|
||||
provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
|
||||
) -> Dict[Api, Any]:
|
||||
provider_configs = config["providers"]
|
||||
provider_specs = topological_sort(provider_specs.values())
|
||||
def snake_to_camel(snake_str):
    """Capitalize each underscore-separated word and join them together."""
    return "".join(map(str.capitalize, snake_str.split("_")))
|
||||
|
||||
impls = {}
|
||||
for provider_spec in provider_specs:
|
||||
api = provider_spec.api
|
||||
if api.value not in provider_configs:
|
||||
raise ValueError(
|
||||
f"Could not find provider_spec config for {api}. Please add it to the config"
|
||||
|
||||
# NOTE(review): the return annotation does not reflect the (impls, specs)
# pair that is actually returned; update it once `Tuple` is in scope here.
async def resolve_impls(
    provider_map: Dict[str, ProviderMapEntry],
) -> Dict[Api, Any]:
    """
    Does two things:
    - flatmaps, sorts and resolves the providers in dependency order
    - for each API, produces either a (local, passthrough or router) implementation

    Returns a pair ``(impls, specs)``: the instantiated implementation and
    the provider spec used for it, both keyed by ``Api``.

    Raises ``ValueError`` when a configured provider id is not available
    for its API.
    """
    all_providers = api_providers()

    specs = {}
    for api_str, item in provider_map.items():
        api = Api(api_str)
        providers = all_providers[api]

        if isinstance(item, GenericProviderConfig):
            if item.provider_id not in providers:
                # Report the id from the config entry itself; the bare name
                # `provider_id` is not defined in this function.
                raise ValueError(
                    f"Unknown provider `{item.provider_id}` is not available for API `{api}`"
                )
            specs[api] = providers[item.provider_id]
        else:
            # A list of routing entries: wrap the inner specs in a router spec.
            assert isinstance(item, list)
            inner_specs = []
            for rt_entry in item:
                if rt_entry.provider_id not in providers:
                    raise ValueError(
                        f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
                    )
                inner_specs.append(providers[rt_entry.provider_id])

            specs[api] = RouterProviderSpec(
                api=api,
                module=f"llama_toolchain.{api.value.lower()}.router",
                api_dependencies=[],
                inner_specs=inner_specs,
            )

    sorted_specs = topological_sort(specs.values())

    impls = {}
    for spec in sorted_specs:
        api = spec.api

        # Dependencies are already instantiated thanks to the sorted order.
        deps = {dep: impls[dep] for dep in spec.api_dependencies}
        impl = await instantiate_provider(spec, deps, provider_map[api.value])
        impls[api] = impl

    return impls, specs
|
||||
|
||||
|
||||
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||
with open(yaml_config, "r") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
config = StackRunConfig(**yaml.safe_load(fp))
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
all_endpoints = api_endpoints()
|
||||
all_providers = api_providers()
|
||||
|
||||
provider_specs = {}
|
||||
for api_str, provider_config in config["providers"].items():
|
||||
api = Api(api_str)
|
||||
providers = all_providers[api]
|
||||
provider_id = provider_config["provider_id"]
|
||||
if provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
|
||||
provider_specs[api] = providers[provider_id]
|
||||
|
||||
impls = resolve_impls(provider_specs, config)
|
||||
impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
for provider_spec in provider_specs.values():
|
||||
api = provider_spec.api
|
||||
all_endpoints = api_endpoints()
|
||||
|
||||
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
endpoints = all_endpoints[api]
|
||||
impl = impls[api]
|
||||
|
||||
provider_spec = specs[api]
|
||||
if (
|
||||
isinstance(provider_spec, RemoteProviderSpec)
|
||||
and provider_spec.adapter is None
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
|
|
@ -26,16 +26,16 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.memory,
|
||||
adapter=AdapterSpec(
|
||||
Api.memory,
|
||||
AdapterSpec(
|
||||
adapter_id="chromadb",
|
||||
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
|
||||
module="llama_toolchain.memory.adapters.chroma",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.memory,
|
||||
adapter=AdapterSpec(
|
||||
Api.memory,
|
||||
AdapterSpec(
|
||||
adapter_id="pgvector",
|
||||
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
|
||||
module="llama_toolchain.memory.adapters.pgvector",
|
||||
|
|
17
llama_toolchain/memory/router/__init__.py
Normal file
17
llama_toolchain/memory/router/__init__.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from llama_toolchain.core.datatypes import Api
|
||||
|
||||
|
||||
async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
    """Create the memory router over ``(routing_key, impl)`` pairs and initialize it."""
    # Imported inside the factory rather than at module import time.
    from .router import MemoryRouterImpl

    router = MemoryRouterImpl(inner_impls, deps)
    await router.initialize()
    return router
|
91
llama_toolchain/memory/router/router.py
Normal file
91
llama_toolchain/memory/router/router.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from llama_toolchain.core.datatypes import Api
|
||||
from llama_toolchain.memory.api import * # noqa: F403
|
||||
|
||||
|
||||
class MemoryRouterImpl(Memory):
    """Memory API implementation that delegates to a provider chosen by bank type."""

    def __init__(
        self,
        inner_impls: List[Tuple[str, Any]],
        deps: List[Api],
    ) -> None:
        """Register one provider per routing key.

        Raises ``ValueError`` if a routing key is not a ``MemoryBankType`` value.
        """
        self.deps = deps

        known_types = [t.value for t in MemoryBankType]

        # routing key (a memory bank type value) -> provider implementation
        self.providers = {}
        for key, impl in inner_impls:
            if key not in known_types:
                raise ValueError(
                    f"Unknown routing key `{key}` for memory bank type"
                )
            self.providers[key] = impl

        # bank_id -> bank type, filled in by create_memory_bank
        self.bank_id_to_type = {}

    async def initialize(self) -> None:
        """No-op."""
        pass

    async def shutdown(self) -> None:
        """Shut down every registered provider, one after another."""
        for impl in self.providers.values():
            await impl.shutdown()

    def get_provider(self, bank_type):
        """Return the provider registered for ``bank_type`` or raise ``ValueError``."""
        if bank_type in self.providers:
            return self.providers[bank_type]

        raise ValueError(f"Memory bank type {bank_type} not supported")

    def _provider_for_bank(self, bank_id: str):
        """Look up the provider owning ``bank_id`` via the recorded bank type."""
        bank_type = self.bank_id_to_type.get(bank_id)
        if not bank_type:
            raise ValueError(f"Could not find bank type for {bank_id}")

        return self.get_provider(bank_type)

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        """Create the bank on the provider for ``config.type`` and remember its type."""
        target = self.get_provider(config.type)
        created = await target.create_memory_bank(name, config, url)
        self.bank_id_to_type[created.bank_id] = config.type
        return created

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        """Fetch a bank from its owning provider; unknown ids raise ``ValueError``."""
        target = self._provider_for_bank(bank_id)
        return await target.get_memory_bank(bank_id)

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        """Forward the insert to the provider owning ``bank_id``."""
        target = self._provider_for_bank(bank_id)
        return await target.insert_documents(bank_id, documents, ttl_seconds)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        """Forward the query to the provider owning ``bank_id``."""
        target = self._provider_for_bank(bank_id)
        return await target.query_documents(bank_id, query, params)
|
Loading…
Add table
Add a link
Reference in a new issue