From 08379f521449f8a208ed2b83ad39085f4063549e Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 20 Sep 2024 12:19:33 -0700 Subject: [PATCH] migrate router for memory wip --- llama_stack/distribution/routers/__init__.py | 5 + .../distribution/routers/memory/__init__.py | 17 ++++ .../distribution/routers/memory/memory.py | 91 +++++++++++++++++++ llama_stack/distribution/server/server.py | 52 ++++++++++- llama_stack/distribution/utils/dynamic.py | 1 + llama_stack/examples/router-table-run.yaml | 36 ++++---- .../adapters/memory/pgvector/pgvector.py | 53 +++++------ .../impls/meta_reference/memory/faiss.py | 3 +- 8 files changed, 213 insertions(+), 45 deletions(-) create mode 100644 llama_stack/distribution/routers/__init__.py create mode 100644 llama_stack/distribution/routers/memory/__init__.py create mode 100644 llama_stack/distribution/routers/memory/memory.py diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/distribution/routers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/distribution/routers/memory/__init__.py b/llama_stack/distribution/routers/memory/__init__.py new file mode 100644 index 000000000..d4dbbb1d4 --- /dev/null +++ b/llama_stack/distribution/routers/memory/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, List, Tuple + +from llama_stack.distribution.datatypes import Api + + +async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]): + from .memory import MemoryRouterImpl + + impl = MemoryRouterImpl(inner_impls, deps) + await impl.initialize() + return impl diff --git a/llama_stack/distribution/routers/memory/memory.py b/llama_stack/distribution/routers/memory/memory.py new file mode 100644 index 000000000..b96cde626 --- /dev/null +++ b/llama_stack/distribution/routers/memory/memory.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, List, Tuple + +from llama_stack.distribution.datatypes import Api +from llama_stack.apis.memory import * # noqa: F403 + + +class MemoryRouterImpl(Memory): + """Routes to an provider based on the memory bank type""" + + def __init__( + self, + inner_impls: List[Tuple[str, Any]], + deps: List[Api], + ) -> None: + self.deps = deps + + bank_types = [v.value for v in MemoryBankType] + + self.providers = {} + for routing_key, provider_impl in inner_impls: + if routing_key not in bank_types: + raise ValueError( + f"Unknown routing key `{routing_key}` for memory bank type" + ) + self.providers[routing_key] = provider_impl + + self.bank_id_to_type = {} + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + for p in self.providers.values(): + await p.shutdown() + + def get_provider(self, bank_type): + if bank_type not in self.providers: + raise ValueError(f"Memory bank type {bank_type} not supported") + + return self.providers[bank_type] + + async def create_memory_bank( + self, + name: str, + config: MemoryBankConfig, + url: Optional[URL] = None, + ) -> MemoryBank: + provider = self.get_provider(config.type) + bank = await provider.create_memory_bank(name, config, url) + self.bank_id_to_type[bank.bank_id] = config.type + return bank + + async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: + bank_type = self.bank_id_to_type.get(bank_id) + if not bank_type: + raise ValueError(f"Could not find bank type for {bank_id}") + + provider = self.get_provider(bank_type) + return await provider.get_memory_bank(bank_id) + + async def insert_documents( + self, + bank_id: str, + documents: List[MemoryBankDocument], + ttl_seconds: Optional[int] = None, + ) -> None: + bank_type = self.bank_id_to_type.get(bank_id) + if not bank_type: + raise ValueError(f"Could not find bank type for {bank_id}") + + provider = self.get_provider(bank_type) + return await provider.insert_documents(bank_id, documents, ttl_seconds) + + async def query_documents( + self, + bank_id: str, + query: InterleavedTextMedia, + params: Optional[Dict[str, Any]] = None, + ) -> QueryDocumentsResponse: + bank_type = self.bank_id_to_type.get(bank_id) + if not bank_type: + raise ValueError(f"Could not find bank type for {bank_id}") + + provider = self.get_provider(bank_type) + return await provider.query_documents(bank_id, query, params) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 4c9e97899..57836caaa 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -291,7 +291,56 @@ async def resolve_impls_with_routing( stack_run_config: StackRunConfig, ) -> Dict[Api, Any]: - raise NotImplementedError("This is not implemented yet") + all_providers = api_providers() + specs = {} + + for api_str in stack_run_config.apis_to_serve: + api = Api(api_str) + providers = all_providers[api] + + # check for regular providers without routing + if api_str in stack_run_config.provider_map: + provider_map_entry = stack_run_config.provider_map[api_str] + if provider_map_entry.provider_id not in providers: + raise ValueError( + f"Unknown provider `{provider_id}` is not available for API `{api}`" + ) + specs[api] = providers[provider_map_entry.provider_id] + + # check for routing table, we need to pass routing table to the router implementation + if api_str in stack_run_config.provider_routing_table: + router_entry = stack_run_config.provider_routing_table[api_str] + inner_specs = [] + for rt_entry in router_entry: + if rt_entry.provider_id not in providers: + raise ValueError( + f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`" + ) + inner_specs.append(providers[rt_entry.provider_id]) + + specs[api] = RouterProviderSpec( + api=api, + module=f"llama_stack.distribution.routers.{api.value.lower()}", + api_dependencies=[], + inner_specs=inner_specs, + ) + + sorted_specs = topological_sort(specs.values()) + + impls = {} + for spec in sorted_specs: + api = spec.api + deps = {api: impls[api] for api in spec.api_dependencies} + if api.value in stack_run_config.provider_map: + provider_config = stack_run_config.provider_map[api.value] + elif api.value in stack_run_config.provider_routing_table: + provider_config = stack_run_config.provider_routing_table[api.value] + else: + raise ValueError(f"Cannot find provider_config for Api {api.value}") + impl = await instantiate_provider(spec, deps, provider_config) + impls[api] = impl + + return impls, specs async def resolve_impls( @@ -356,6 +405,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): if config.provider_routing_table is not None: impls, specs = asyncio.run(resolve_impls_with_routing(config)) else: + # keeping this for backwards compatibility,could impls, specs = asyncio.run(resolve_impls(config.provider_map)) if Api.telemetry in impls: diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py index 002a738ae..7e1e8d253 100644 --- a/llama_stack/distribution/utils/dynamic.py +++ b/llama_stack/distribution/utils/dynamic.py @@ -33,6 +33,7 @@ async def instantiate_provider( assert isinstance(provider_config, GenericProviderConfig) config_type = instantiate_class_type(provider_spec.config_class) + print("!!!", provider_config) config = config_type(**provider_config.config) args = [config, deps] elif isinstance(provider_spec, RouterProviderSpec): diff --git a/llama_stack/examples/router-table-run.yaml b/llama_stack/examples/router-table-run.yaml index a94c883cb..379ccebe3 100644 --- a/llama_stack/examples/router-table-run.yaml +++ b/llama_stack/examples/router-table-run.yaml @@ -3,26 +3,27 @@ image_name: local docker_image: null conda_env: local apis_to_serve: -- inference +# - inference - memory +- telemetry provider_map: - # use builtin-router as dummy field - memory: builtin-router - inference: builtin-router + telemetry: + provider_id: meta-reference + config: {} provider_routing_table: - inference: - - routing_key: Meta-Llama3.1-8B-Instruct - provider_id: meta-reference - config: - model: Meta-Llama3.1-8B-Instruct - quantization: null - torch_seed: null - max_seq_len: 4096 - max_batch_size: 1 - - routing_key: Meta-Llama3.1-8B - provider_id: remote::ollama - config: - url: http:ollama-url-1.com + # inference: + # - routing_key: Meta-Llama3.1-8B-Instruct + # provider_id: meta-reference + # config: + # model: Meta-Llama3.1-8B-Instruct + # quantization: null + # torch_seed: null + # max_seq_len: 4096 + # max_batch_size: 1 + # - routing_key: Meta-Llama3.1-8B + # provider_id: remote::ollama + # config: + # url: http:ollama-url-1.com memory: - routing_key: keyvalue provider_id: remote::pgvector @@ -31,6 +32,7 @@ provider_routing_table: port: 5432 db: vectordb user: vectoruser + password: xxxx - routing_key: vector provider_id: meta-reference config: {} diff --git a/llama_stack/providers/adapters/memory/pgvector/pgvector.py b/llama_stack/providers/adapters/memory/pgvector/pgvector.py index a5c84a1b2..0907af2e3 100644 --- a/llama_stack/providers/adapters/memory/pgvector/pgvector.py +++ b/llama_stack/providers/adapters/memory/pgvector/pgvector.py @@ -128,35 +128,36 @@ class PGVectorMemoryAdapter(Memory): self.cache = {} async def initialize(self) -> None: - try: - self.conn = psycopg2.connect( - host=self.config.host, - port=self.config.port, - database=self.config.db, - user=self.config.user, - password=self.config.password, - ) - self.cursor = self.conn.cursor() + print("Init PGVector!") + # try: + # self.conn = psycopg2.connect( + # host=self.config.host, + # port=self.config.port, + # database=self.config.db, + # user=self.config.user, + # password=self.config.password, + # ) + # self.cursor = self.conn.cursor() - version = check_extension_version(self.cursor) - if version: - print(f"Vector extension version: {version}") - else: - raise RuntimeError("Vector extension is not installed.") + # version = check_extension_version(self.cursor) + # if version: + # print(f"Vector extension version: {version}") + # else: + # raise RuntimeError("Vector extension is not installed.") - self.cursor.execute( - """ - CREATE TABLE IF NOT EXISTS metadata_store ( - key TEXT PRIMARY KEY, - data JSONB - ) - """ - ) - except Exception as e: - import traceback + # self.cursor.execute( + # """ + # CREATE TABLE IF NOT EXISTS metadata_store ( + # key TEXT PRIMARY KEY, + # data JSONB + # ) + # """ + # ) + # except Exception as e: + # import traceback - traceback.print_exc() - raise RuntimeError("Could not connect to PGVector database server") from e + # traceback.print_exc() + # raise RuntimeError("Could not connect to PGVector database server") from e async def shutdown(self) -> None: pass diff --git a/llama_stack/providers/impls/meta_reference/memory/faiss.py b/llama_stack/providers/impls/meta_reference/memory/faiss.py index ee716430e..0d31d40e9 100644 --- a/llama_stack/providers/impls/meta_reference/memory/faiss.py +++ b/llama_stack/providers/impls/meta_reference/memory/faiss.py @@ -68,7 +68,8 @@ class FaissMemoryImpl(Memory): self.config = config self.cache = {} - async def initialize(self) -> None: ... + async def initialize(self) -> None: + print("INIT meta-reference") async def shutdown(self) -> None: ...