From b6a3ef51dac417f7b1bbf839ba8da83daf86b00e Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 16 Sep 2024 10:38:11 -0700 Subject: [PATCH] Introduce a "Router" layer for providers Some providers need to be factorized and considered as thin routing layers on top of other providers. Consider two examples: - The inference API should be a routing layer over inference providers, routed using the "model" key - The memory banks API is another instance where various memory bank types will be provided by independent providers (e.g., a vector store is served by Chroma while a keyvalue memory can be served by Redis or PGVector) This commit introduces a generalized routing layer for this purpose. --- llama_toolchain/cli/stack/build.py | 4 +- llama_toolchain/cli/stack/configure.py | 13 +-- llama_toolchain/common/prompt_for_config.py | 25 +++++ llama_toolchain/core/configure.py | 102 ++++++++++++++------ llama_toolchain/core/datatypes.py | 64 ++++++++++-- llama_toolchain/core/dynamic.py | 40 ++++++-- llama_toolchain/core/package.py | 41 ++++---- llama_toolchain/core/server.py | 96 ++++++++++-------- llama_toolchain/memory/client.py | 1 + llama_toolchain/memory/providers.py | 8 +- llama_toolchain/memory/router/__init__.py | 17 ++++ llama_toolchain/memory/router/router.py | 91 +++++++++++++++++ 12 files changed, 384 insertions(+), 118 deletions(-) create mode 100644 llama_toolchain/memory/router/__init__.py create mode 100644 llama_toolchain/memory/router/router.py diff --git a/llama_toolchain/cli/stack/build.py b/llama_toolchain/cli/stack/build.py index 36cd480fc..78e013219 100644 --- a/llama_toolchain/cli/stack/build.py +++ b/llama_toolchain/cli/stack/build.py @@ -46,7 +46,7 @@ class StackBuild(Subcommand): from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR from llama_toolchain.common.serialize import EnumEncoder - from llama_toolchain.core.package import ApiInput, build_package, ImageType + from llama_toolchain.core.package import ApiInput, build_image, 
ImageType from termcolor import cprint # save build.yaml spec for building same distribution again @@ -66,7 +66,7 @@ class StackBuild(Subcommand): to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder)) f.write(yaml.dump(to_write, sort_keys=False)) - build_package(build_config, build_file_path) + build_image(build_config, build_file_path) cprint( f"Build spec configuration saved at {str(build_file_path)}", diff --git a/llama_toolchain/cli/stack/configure.py b/llama_toolchain/cli/stack/configure.py index 4a73f1af4..53c9622e7 100644 --- a/llama_toolchain/cli/stack/configure.py +++ b/llama_toolchain/cli/stack/configure.py @@ -105,13 +105,6 @@ class StackConfigure(Subcommand): image_name = build_config.name.replace("::", "-") run_config_file = builds_dir / f"{image_name}-run.yaml" - api2providers = build_config.distribution_spec.providers - - stub_config = { - api_str: {"provider_id": provider} - for api_str, provider in api2providers.items() - } - if run_config_file.exists(): cprint( f"Configuration already exists for {build_config.name}. 
Will overwrite...", @@ -123,10 +116,12 @@ class StackConfigure(Subcommand): config = StackRunConfig( built_at=datetime.now(), image_name=image_name, - providers=stub_config, + apis_to_serve=[], + provider_map={}, ) - config.providers = configure_api_providers(config.providers) + config = configure_api_providers(config, build_config.distribution_spec) + config.docker_image = ( image_name if build_config.image_type == "docker" else None ) diff --git a/llama_toolchain/common/prompt_for_config.py b/llama_toolchain/common/prompt_for_config.py index 4f92ec7d9..d9d778540 100644 --- a/llama_toolchain/common/prompt_for_config.py +++ b/llama_toolchain/common/prompt_for_config.py @@ -27,6 +27,12 @@ def is_list_of_primitives(field_type): return False +def is_basemodel_without_fields(typ): + return ( + inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0 + ) + + def can_recurse(typ): return ( inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0 @@ -151,6 +157,11 @@ def prompt_for_config( if get_origin(field_type) is Literal: continue + # Skip fields with no type annotations + if is_basemodel_without_fields(field_type): + config_data[field_name] = field_type() + continue + if inspect.isclass(field_type) and issubclass(field_type, Enum): prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):" while True: @@ -254,6 +265,20 @@ def prompt_for_config( print(f"{str(e)}") continue + elif get_origin(field_type) is dict: + try: + value = json.loads(user_input) + if not isinstance(value, dict): + raise ValueError( + "Input must be a JSON-encoded dictionary" + ) + + except json.JSONDecodeError: + print( + "Invalid JSON. Please enter a valid JSON-encoded dict." 
+ ) + continue + # Convert the input to the correct type elif inspect.isclass(field_type) and issubclass( field_type, BaseModel diff --git a/llama_toolchain/core/configure.py b/llama_toolchain/core/configure.py index 7f9aa0140..0e9c41300 100644 --- a/llama_toolchain/core/configure.py +++ b/llama_toolchain/core/configure.py @@ -4,47 +4,87 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any + +from pydantic import BaseModel from llama_toolchain.core.datatypes import * # noqa: F403 from termcolor import cprint from llama_toolchain.common.prompt_for_config import prompt_for_config -from llama_toolchain.core.distribution import api_providers +from llama_toolchain.core.distribution import api_providers, stack_apis from llama_toolchain.core.dynamic import instantiate_class_type -def configure_api_providers(existing_configs: Dict[str, Any]) -> None: +# These are hacks so we can re-use the `prompt_for_config` utility +# This needs a bunch of work to be made very user friendly. 
+class ReqApis(BaseModel): + apis_to_serve: List[str] + + +def make_routing_entry_type(config_class: Any): + class BaseModelWithConfig(BaseModel): + routing_key: str + config: config_class + + return BaseModelWithConfig + + +# TODO: make sure we can deal with existing configuration values correctly +# instead of just overwriting them +def configure_api_providers( + config: StackRunConfig, spec: DistributionSpec +) -> StackRunConfig: + cprint("Configuring APIs to serve...", "white", attrs=["bold"]) + print("Enter comma-separated list of APIs to serve:") + + apis = config.apis_to_serve or list(spec.providers.keys()) + apis = [a for a in apis if a != "telemetry"] + req_apis = ReqApis( + apis_to_serve=apis, + ) + req_apis = prompt_for_config(ReqApis, req_apis) + print("") + + apis = [v.value for v in stack_apis()] all_providers = api_providers() - provider_configs = {} - for api_str, stub_config in existing_configs.items(): + apis_to_serve = req_apis.apis_to_serve + ["telemetry"] + for api_str in apis_to_serve: + if api_str not in apis: + raise ValueError(f"Unknown API `{api_str}`") + + cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"]) api = Api(api_str) - providers = all_providers[api] - provider_id = stub_config["provider_id"] - if provider_id not in providers: - raise ValueError( - f"Unknown provider `{provider_id}` is not available for API `{api_str}`" + if isinstance(spec.providers[api_str], list): + print( + "You have specified multiple providers for this API. We will configure a routing table now. 
For each provider, provide a routing key followed by provider configuration.\n" + ) + routing_entries = [] + for p in spec.providers[api_str]: + print(f"Configuring provider `{p}`...") + provider_spec = all_providers[api][p] + config_type = instantiate_class_type(provider_spec.config_class) + + wrapper_type = make_routing_entry_type(config_type) + rt_entry = prompt_for_config(wrapper_type, None) + + # TODO: we need to validate the routing keys + routing_entries.append( + ProviderRoutingEntry( + provider_id=p, + routing_key=rt_entry.routing_key, + config=rt_entry.config.dict(), + ) + ) + config.provider_map[api_str] = routing_entries + else: + provider_spec = all_providers[api][spec.providers[api_str]] + config_type = instantiate_class_type(provider_spec.config_class) + cfg = prompt_for_config(config_type, None) + config.provider_map[api_str] = GenericProviderConfig( + provider_id=spec.providers[api_str], + config=cfg.dict(), ) - provider_spec = providers[provider_id] - cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"]) - config_type = instantiate_class_type(provider_spec.config_class) - - try: - existing_provider_config = config_type(**stub_config) - except Exception: - existing_provider_config = None - - provider_config = prompt_for_config( - config_type, - existing_provider_config, - ) - print("") - - provider_configs[api_str] = { - "provider_id": provider_id, - **provider_config.dict(), - } - - return provider_configs + return config diff --git a/llama_toolchain/core/datatypes.py b/llama_toolchain/core/datatypes.py index f523e0308..2821bf403 100644 --- a/llama_toolchain/core/datatypes.py +++ b/llama_toolchain/core/datatypes.py @@ -6,7 +6,7 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from llama_models.schema_utils import json_schema_type @@ -43,6 +43,33 @@ class ProviderSpec(BaseModel): ) +@json_schema_type +class 
RouterProviderSpec(ProviderSpec): + provider_id: str = "router" + config_class: str = "" + + docker_image: Optional[str] = None + + inner_specs: List[ProviderSpec] + module: str = Field( + ..., + description=""" +Fully-qualified name of the module to import. The module is expected to have: + + - `get_router_impl(config, provider_specs, deps)`: returns the router implementation +""", + ) + + @property + def pip_packages(self) -> List[str]: + raise AssertionError("Should not be called on RouterProviderSpec") + + +class GenericProviderConfig(BaseModel): + provider_id: str + config: Dict[str, Any] + + @json_schema_type class AdapterSpec(BaseModel): adapter_id: str = Field( @@ -156,12 +183,23 @@ class DistributionSpec(BaseModel): description="Description of the distribution", ) docker_image: Optional[str] = None - providers: Dict[str, str] = Field( + providers: Dict[str, Union[str, List[str]]] = Field( default_factory=dict, - description="Provider Types for each of the APIs provided by this distribution", + description=""" +Provider Types for each of the APIs provided by this distribution. If you +select multiple providers, you should provide an appropriate 'routing_map' +in the runtime configuration to help route to the correct provider.""", ) +@json_schema_type +class ProviderRoutingEntry(GenericProviderConfig): + routing_key: str + + +ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]] + + @json_schema_type class StackRunConfig(BaseModel): built_at: datetime @@ -181,12 +219,22 @@ this could be just a hash default=None, description="Reference to the conda environment if this package refers to a conda environment", ) - providers: Dict[str, Any] = Field( - default_factory=dict, + apis_to_serve: List[str] = Field( description=""" -Provider configurations for each of the APIs provided by this package. This includes configurations for -the dependencies of these providers as well. -""", +The list of APIs to serve. 
If not specified, all APIs specified in the provider_map will be served""", + ) + provider_map: Dict[str, ProviderMapEntry] = Field( + description=""" +Provider configurations for each of the APIs provided by this package. + +Given an API, you can specify a single provider or a "routing table". Each entry in the routing +table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific. + +As examples: +- the "inference" API interprets the routing_key as a "model" +- the "memory" API interprets the routing_key as the type of a "memory bank" + +The key may support wild-cards; the routing_key is used to route to the correct provider.""", ) diff --git a/llama_toolchain/core/dynamic.py b/llama_toolchain/core/dynamic.py index adb9b5dac..42c0646da 100644 --- a/llama_toolchain/core/dynamic.py +++ b/llama_toolchain/core/dynamic.py @@ -4,11 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import asyncio import importlib from typing import Any, Dict -from .datatypes import ProviderSpec, RemoteProviderSpec +from llama_toolchain.core.datatypes import * # noqa: F403 def instantiate_class_type(fully_qualified_name): @@ -18,25 +17,50 @@ def instantiate_class_type(fully_qualified_name): # returns a class implementing the protocol corresponding to the Api -def instantiate_provider( +async def instantiate_provider( provider_spec: ProviderSpec, - provider_config: Dict[str, Any], - deps: Dict[str, ProviderSpec], + deps: Dict[str, Any], + provider_config: ProviderMapEntry, ): module = importlib.import_module(provider_spec.module) - config_type = instantiate_class_type(provider_spec.config_class) + args = [] if isinstance(provider_spec, RemoteProviderSpec): if provider_spec.adapter: method = "get_adapter_impl" else: method = "get_client_impl" + + assert isinstance(provider_config, GenericProviderConfig) + config_type = instantiate_class_type(provider_spec.config_class) + config = 
config_type(**provider_config.config) + args = [config, deps] + elif isinstance(provider_spec, RouterProviderSpec): + method = "get_router_impl" + + assert isinstance(provider_config, list) + inner_specs = {x.provider_id: x for x in provider_spec.inner_specs} + inner_impls = [] + for routing_entry in provider_config: + impl = await instantiate_provider( + inner_specs[routing_entry.provider_id], + deps, + routing_entry, + ) + inner_impls.append((routing_entry.routing_key, impl)) + + config = None + args = [inner_impls, deps] else: method = "get_provider_impl" - config = config_type(**provider_config) + assert isinstance(provider_config, GenericProviderConfig) + config_type = instantiate_class_type(provider_spec.config_class) + config = config_type(**provider_config.config) + args = [config, deps] + fn = getattr(module, method) - impl = asyncio.run(fn(config, deps)) + impl = await fn(*args) impl.__provider_spec__ = provider_spec impl.__provider_config__ = config return impl diff --git a/llama_toolchain/core/package.py b/llama_toolchain/core/package.py index 7987384e2..37dac091d 100644 --- a/llama_toolchain/core/package.py +++ b/llama_toolchain/core/package.py @@ -4,22 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import json -import os -from datetime import datetime from enum import Enum from typing import List, Optional import pkg_resources -import yaml - -from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR -from llama_toolchain.common.exec import run_with_pty -from llama_toolchain.common.serialize import EnumEncoder from pydantic import BaseModel from termcolor import cprint +from llama_toolchain.common.exec import run_with_pty + from llama_toolchain.core.datatypes import * # noqa: F403 from pathlib import Path @@ -41,7 +35,7 @@ class ApiInput(BaseModel): provider: str -def build_package(build_config: BuildConfig, build_file_path: Path): +def build_image(build_config: BuildConfig, build_file_path: Path): package_deps = Dependencies( docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim", pip_packages=SERVER_DEPENDENCIES, @@ -49,17 +43,28 @@ def build_package(build_config: BuildConfig, build_file_path: Path): # extend package dependencies based on providers spec all_providers = api_providers() - for api_str, provider in build_config.distribution_spec.providers.items(): + for ( + api_str, + provider_or_providers, + ) in build_config.distribution_spec.providers.items(): providers_for_api = all_providers[Api(api_str)] - if provider not in providers_for_api: - raise ValueError( - f"Provider `{provider}` is not available for API `{api_str}`" - ) - provider_spec = providers_for_api[provider] - package_deps.pip_packages.extend(provider_spec.pip_packages) - if provider_spec.docker_image: - raise ValueError("A stack's dependencies cannot have a docker image") + providers = ( + provider_or_providers + if isinstance(provider_or_providers, list) + else [provider_or_providers] + ) + + for provider in providers: + if provider not in providers_for_api: + raise ValueError( + f"Provider `{provider}` is not available for API `{api_str}`" + ) + + provider_spec = providers_for_api[provider] + package_deps.pip_packages.extend(provider_spec.pip_packages) + if 
provider_spec.docker_image: + raise ValueError("A stack's dependencies cannot have a docker image") if build_config.image_type == ImageType.docker.value: script = pkg_resources.resource_filename( diff --git a/llama_toolchain/core/server.py b/llama_toolchain/core/server.py index 7082ec765..70273be16 100644 --- a/llama_toolchain/core/server.py +++ b/llama_toolchain/core/server.py @@ -9,6 +9,7 @@ import inspect import json import signal import traceback + from collections.abc import ( AsyncGenerator as AsyncGeneratorABC, AsyncIterator as AsyncIteratorABC, @@ -44,8 +45,8 @@ from llama_toolchain.telemetry.tracing import ( SpanStatus, start_trace, ) +from llama_toolchain.core.datatypes import * # noqa: F403 -from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec from .distribution import api_endpoints, api_providers from .dynamic import instantiate_provider @@ -271,61 +272,80 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]: return [by_id[x] for x in stack] -def resolve_impls( - provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any] -) -> Dict[Api, Any]: - provider_configs = config["providers"] - provider_specs = topological_sort(provider_specs.values()) +def snake_to_camel(snake_str): + return "".join(word.capitalize() for word in snake_str.split("_")) - impls = {} - for provider_spec in provider_specs: - api = provider_spec.api - if api.value not in provider_configs: - raise ValueError( - f"Could not find provider_spec config for {api}. 
Please add it to the config" + + +async def resolve_impls( + provider_map: Dict[str, ProviderMapEntry], +) -> Dict[Api, Any]: + """ + Does two things: + - flatmaps, sorts and resolves the providers in dependency order + - for each API, produces either a (local, passthrough or router) implementation + """ + all_providers = api_providers() + + specs = {} + for api_str, item in provider_map.items(): + api = Api(api_str) + providers = all_providers[api] + + if isinstance(item, GenericProviderConfig): + if item.provider_id not in providers: + raise ValueError( + f"Unknown provider `{item.provider_id}` is not available for API `{api}`" + ) + specs[api] = providers[item.provider_id] + else: + assert isinstance(item, list) + inner_specs = [] + for rt_entry in item: + if rt_entry.provider_id not in providers: + raise ValueError( + f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`" + ) + inner_specs.append(providers[rt_entry.provider_id]) + + specs[api] = RouterProviderSpec( + api=api, + module=f"llama_toolchain.{api.value.lower()}.router", + api_dependencies=[], + inner_specs=inner_specs, ) - if isinstance(provider_spec, InlineProviderSpec): - deps = {api: impls[api] for api in provider_spec.api_dependencies} - else: - deps = {} - provider_config = provider_configs[api.value] - impl = instantiate_provider(provider_spec, provider_config, deps) + sorted_specs = topological_sort(specs.values()) + + impls = {} + for spec in sorted_specs: + api = spec.api + + deps = {api: impls[api] for api in spec.api_dependencies} + impl = await instantiate_provider(spec, deps, provider_map[api.value]) impls[api] = impl - return impls + return impls, specs def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): with open(yaml_config, "r") as fp: - config = yaml.safe_load(fp) + config = StackRunConfig(**yaml.safe_load(fp)) app = FastAPI() - all_endpoints = api_endpoints() - all_providers = api_providers() - - provider_specs = {} - for api_str, 
provider_config in config["providers"].items(): - api = Api(api_str) - providers = all_providers[api] - provider_id = provider_config["provider_id"] - if provider_id not in providers: - raise ValueError( - f"Unknown provider `{provider_id}` is not available for API `{api}`" - ) - - provider_specs[api] = providers[provider_id] - - impls = resolve_impls(provider_specs, config) + impls, specs = asyncio.run(resolve_impls(config.provider_map)) if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) - for provider_spec in provider_specs.values(): - api = provider_spec.api + all_endpoints = api_endpoints() + + apis_to_serve = config.apis_to_serve or list(config.provider_map.keys()) + for api_str in apis_to_serve: + api = Api(api_str) endpoints = all_endpoints[api] impl = impls[api] + provider_spec = specs[api] if ( isinstance(provider_spec, RemoteProviderSpec) and provider_spec.adapter is None diff --git a/llama_toolchain/memory/client.py b/llama_toolchain/memory/client.py index 5f74219da..c2c04b213 100644 --- a/llama_toolchain/memory/client.py +++ b/llama_toolchain/memory/client.py @@ -6,6 +6,7 @@ import asyncio import json +import os from pathlib import Path from typing import Any, Dict, List, Optional diff --git a/llama_toolchain/memory/providers.py b/llama_toolchain/memory/providers.py index d3336278a..cf443f5f3 100644 --- a/llama_toolchain/memory/providers.py +++ b/llama_toolchain/memory/providers.py @@ -26,16 +26,16 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig", ), remote_provider_spec( - api=Api.memory, - adapter=AdapterSpec( + Api.memory, + AdapterSpec( adapter_id="chromadb", pip_packages=EMBEDDING_DEPS + ["chromadb-client"], module="llama_toolchain.memory.adapters.chroma", ), ), remote_provider_spec( - api=Api.memory, - adapter=AdapterSpec( + Api.memory, + AdapterSpec( adapter_id="pgvector", pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"], 
module="llama_toolchain.memory.adapters.pgvector", diff --git a/llama_toolchain/memory/router/__init__.py b/llama_toolchain/memory/router/__init__.py new file mode 100644 index 000000000..25c5ac2a8 --- /dev/null +++ b/llama_toolchain/memory/router/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, List, Tuple + +from llama_toolchain.core.datatypes import Api + + +async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]): + from .router import MemoryRouterImpl + + impl = MemoryRouterImpl(inner_impls, deps) + await impl.initialize() + return impl diff --git a/llama_toolchain/memory/router/router.py b/llama_toolchain/memory/router/router.py new file mode 100644 index 000000000..b415fbb96 --- /dev/null +++ b/llama_toolchain/memory/router/router.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any, Dict, List, Tuple + +from llama_toolchain.core.datatypes import Api +from llama_toolchain.memory.api import * # noqa: F403 + + +class MemoryRouterImpl(Memory): + """Routes to a provider based on the memory bank type""" + + def __init__( + self, + inner_impls: List[Tuple[str, Any]], + deps: List[Api], + ) -> None: + self.deps = deps + + bank_types = [v.value for v in MemoryBankType] + + self.providers = {} + for routing_key, provider_impl in inner_impls: + if routing_key not in bank_types: + raise ValueError( + f"Unknown routing key `{routing_key}` for memory bank type" + ) + self.providers[routing_key] = provider_impl + + self.bank_id_to_type = {} + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + for p in self.providers.values(): + await p.shutdown() + + def get_provider(self, bank_type): + if bank_type not in self.providers: + raise ValueError(f"Memory bank type {bank_type} not supported") + + return self.providers[bank_type] + + async def create_memory_bank( + self, + name: str, + config: MemoryBankConfig, + url: Optional[URL] = None, + ) -> MemoryBank: + provider = self.get_provider(config.type) + bank = await provider.create_memory_bank(name, config, url) + self.bank_id_to_type[bank.bank_id] = config.type + return bank + + async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: + bank_type = self.bank_id_to_type.get(bank_id) + if not bank_type: + raise ValueError(f"Could not find bank type for {bank_id}") + + provider = self.get_provider(bank_type) + return await provider.get_memory_bank(bank_id) + + async def insert_documents( + self, + bank_id: str, + documents: List[MemoryBankDocument], + ttl_seconds: Optional[int] = None, + ) -> None: + bank_type = self.bank_id_to_type.get(bank_id) + if not bank_type: + raise ValueError(f"Could not find bank type for {bank_id}") + + provider = self.get_provider(bank_type) + return await provider.insert_documents(bank_id, documents, ttl_seconds) 
+ + async def query_documents( + self, + bank_id: str, + query: InterleavedTextMedia, + params: Optional[Dict[str, Any]] = None, + ) -> QueryDocumentsResponse: + bank_type = self.bank_id_to_type.get(bank_id) + if not bank_type: + raise ValueError(f"Could not find bank type for {bank_id}") + + provider = self.get_provider(bank_type) + return await provider.query_documents(bank_id, query, params)