From cf8bd10989b39d081e27b46a8b0b4225cac72eb9 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sat, 21 Sep 2024 12:42:09 -0700 Subject: [PATCH] Revert "migrate router for memory wip" This reverts commit 08379f521449f8a208ed2b83ad39085f4063549e. --- llama_stack/distribution/routers/__init__.py | 5 - .../distribution/routers/memory/__init__.py | 17 ---- .../distribution/routers/memory/memory.py | 91 ------------------- llama_stack/distribution/server/server.py | 52 +---------- llama_stack/distribution/utils/dynamic.py | 1 - llama_stack/examples/router-table-run.yaml | 36 ++++---- .../adapters/memory/pgvector/pgvector.py | 53 ++++++----- .../impls/meta_reference/memory/faiss.py | 3 +- 8 files changed, 45 insertions(+), 213 deletions(-) delete mode 100644 llama_stack/distribution/routers/__init__.py delete mode 100644 llama_stack/distribution/routers/memory/__init__.py delete mode 100644 llama_stack/distribution/routers/memory/memory.py diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/distribution/routers/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/distribution/routers/memory/__init__.py b/llama_stack/distribution/routers/memory/__init__.py deleted file mode 100644 index d4dbbb1d4..000000000 --- a/llama_stack/distribution/routers/memory/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, List, Tuple - -from llama_stack.distribution.datatypes import Api - - -async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]): - from .memory import MemoryRouterImpl - - impl = MemoryRouterImpl(inner_impls, deps) - await impl.initialize() - return impl diff --git a/llama_stack/distribution/routers/memory/memory.py b/llama_stack/distribution/routers/memory/memory.py deleted file mode 100644 index b96cde626..000000000 --- a/llama_stack/distribution/routers/memory/memory.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, Dict, List, Tuple - -from llama_stack.distribution.datatypes import Api -from llama_stack.apis.memory import * # noqa: F403 - - -class MemoryRouterImpl(Memory): - """Routes to an provider based on the memory bank type""" - - def __init__( - self, - inner_impls: List[Tuple[str, Any]], - deps: List[Api], - ) -> None: - self.deps = deps - - bank_types = [v.value for v in MemoryBankType] - - self.providers = {} - for routing_key, provider_impl in inner_impls: - if routing_key not in bank_types: - raise ValueError( - f"Unknown routing key `{routing_key}` for memory bank type" - ) - self.providers[routing_key] = provider_impl - - self.bank_id_to_type = {} - - async def initialize(self) -> None: - pass - - async def shutdown(self) -> None: - for p in self.providers.values(): - await p.shutdown() - - def get_provider(self, bank_type): - if bank_type not in self.providers: - raise ValueError(f"Memory bank type {bank_type} not supported") - - return self.providers[bank_type] - - async def create_memory_bank( - self, - name: str, - config: MemoryBankConfig, - url: Optional[URL] = None, - ) -> MemoryBank: - provider = self.get_provider(config.type) - bank = await provider.create_memory_bank(name, config, url) - self.bank_id_to_type[bank.bank_id] = config.type - return bank - - async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: - bank_type = self.bank_id_to_type.get(bank_id) - if not bank_type: - raise ValueError(f"Could not find bank type for {bank_id}") - - provider = self.get_provider(bank_type) - return await provider.get_memory_bank(bank_id) - - async def insert_documents( - self, - bank_id: str, - documents: List[MemoryBankDocument], - ttl_seconds: Optional[int] = None, - ) -> None: - bank_type = self.bank_id_to_type.get(bank_id) - if not bank_type: - raise ValueError(f"Could not find bank type for {bank_id}") - - provider = self.get_provider(bank_type) - return await provider.insert_documents(bank_id, documents, ttl_seconds) - - async def query_documents( - self, - bank_id: str, - query: InterleavedTextMedia, - params: Optional[Dict[str, Any]] = None, - ) -> QueryDocumentsResponse: - bank_type = self.bank_id_to_type.get(bank_id) - if not bank_type: - raise ValueError(f"Could not find bank type for {bank_id}") - - provider = self.get_provider(bank_type) - return await provider.query_documents(bank_id, query, params) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 57836caaa..4c9e97899 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -291,56 +291,7 @@ async def resolve_impls_with_routing( stack_run_config: StackRunConfig, ) -> Dict[Api, Any]: - all_providers = api_providers() - specs = {} - - for api_str in stack_run_config.apis_to_serve: - api = Api(api_str) - providers = all_providers[api] - - # check for regular providers without routing - if api_str in stack_run_config.provider_map: - provider_map_entry = stack_run_config.provider_map[api_str] - if provider_map_entry.provider_id not in providers: - raise ValueError( - f"Unknown provider `{provider_id}` is not available for API `{api}`" - ) - specs[api] = providers[provider_map_entry.provider_id] - - # check for routing table, we need to pass routing table to the router implementation - if api_str in stack_run_config.provider_routing_table: - router_entry = stack_run_config.provider_routing_table[api_str] - inner_specs = [] - for rt_entry in router_entry: - if rt_entry.provider_id not in providers: - raise ValueError( - f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`" - ) - inner_specs.append(providers[rt_entry.provider_id]) - - specs[api] = RouterProviderSpec( - api=api, - module=f"llama_stack.distribution.routers.{api.value.lower()}", - api_dependencies=[], - inner_specs=inner_specs, - ) - - sorted_specs = topological_sort(specs.values()) - - impls = {} - for spec in sorted_specs: - api = spec.api - deps = {api: impls[api] for api in spec.api_dependencies} - if api.value in stack_run_config.provider_map: - provider_config = stack_run_config.provider_map[api.value] - elif api.value in stack_run_config.provider_routing_table: - provider_config = stack_run_config.provider_routing_table[api.value] - else: - raise ValueError(f"Cannot find provider_config for Api {api.value}") - impl = await instantiate_provider(spec, deps, provider_config) - impls[api] = impl - - return impls, specs + raise NotImplementedError("This is not implemented yet") async def resolve_impls( @@ -405,7 +356,6 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): if config.provider_routing_table is not None: impls, specs = asyncio.run(resolve_impls_with_routing(config)) else: - # keeping this for backwards compatibility,could impls, specs = asyncio.run(resolve_impls(config.provider_map)) if Api.telemetry in impls: diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py index 7e1e8d253..002a738ae 100644 --- a/llama_stack/distribution/utils/dynamic.py +++ b/llama_stack/distribution/utils/dynamic.py @@ -33,7 +33,6 @@ async def instantiate_provider( assert isinstance(provider_config, GenericProviderConfig) config_type = instantiate_class_type(provider_spec.config_class) - print("!!!", provider_config) config = config_type(**provider_config.config) args = [config, deps] elif isinstance(provider_spec, RouterProviderSpec): diff --git a/llama_stack/examples/router-table-run.yaml b/llama_stack/examples/router-table-run.yaml index 379ccebe3..a94c883cb 100644 --- a/llama_stack/examples/router-table-run.yaml +++ b/llama_stack/examples/router-table-run.yaml @@ -3,27 +3,26 @@ image_name: local docker_image: null conda_env: local apis_to_serve: -# - inference +- inference - memory -- telemetry provider_map: - telemetry: - provider_id: meta-reference - config: {} + # use builtin-router as dummy field + memory: builtin-router + inference: builtin-router provider_routing_table: - # inference: - # - routing_key: Meta-Llama3.1-8B-Instruct - # provider_id: meta-reference - # config: - # model: Meta-Llama3.1-8B-Instruct - # quantization: null - # torch_seed: null - # max_seq_len: 4096 - # max_batch_size: 1 - # - routing_key: Meta-Llama3.1-8B - # provider_id: remote::ollama - # config: - # url: http:ollama-url-1.com + inference: + - routing_key: Meta-Llama3.1-8B-Instruct + provider_id: meta-reference + config: + model: Meta-Llama3.1-8B-Instruct + quantization: null + torch_seed: null + max_seq_len: 4096 + max_batch_size: 1 + - routing_key: Meta-Llama3.1-8B + provider_id: remote::ollama + config: + url: http:ollama-url-1.com memory: - routing_key: keyvalue provider_id: remote::pgvector @@ -32,7 +31,6 @@ provider_routing_table: port: 5432 db: vectordb user: vectoruser - password: xxxx - routing_key: vector provider_id: meta-reference config: {} diff --git a/llama_stack/providers/adapters/memory/pgvector/pgvector.py b/llama_stack/providers/adapters/memory/pgvector/pgvector.py index 0907af2e3..a5c84a1b2 100644 --- a/llama_stack/providers/adapters/memory/pgvector/pgvector.py +++ b/llama_stack/providers/adapters/memory/pgvector/pgvector.py @@ -128,36 +128,35 @@ class PGVectorMemoryAdapter(Memory): self.cache = {} async def initialize(self) -> None: - print("Init PGVector!") - # try: - # self.conn = psycopg2.connect( - # host=self.config.host, - # port=self.config.port, - # database=self.config.db, - # user=self.config.user, - # password=self.config.password, - # ) - # self.cursor = self.conn.cursor() + try: + self.conn = psycopg2.connect( + host=self.config.host, + port=self.config.port, + database=self.config.db, + user=self.config.user, + password=self.config.password, + ) + self.cursor = self.conn.cursor() - # version = check_extension_version(self.cursor) - # if version: - # print(f"Vector extension version: {version}") - # else: - # raise RuntimeError("Vector extension is not installed.") + version = check_extension_version(self.cursor) + if version: + print(f"Vector extension version: {version}") + else: + raise RuntimeError("Vector extension is not installed.") - # self.cursor.execute( - # """ - # CREATE TABLE IF NOT EXISTS metadata_store ( - # key TEXT PRIMARY KEY, - # data JSONB - # ) - # """ - # ) - # except Exception as e: - # import traceback + self.cursor.execute( + """ + CREATE TABLE IF NOT EXISTS metadata_store ( + key TEXT PRIMARY KEY, + data JSONB + ) + """ + ) + except Exception as e: + import traceback - # traceback.print_exc() - # raise RuntimeError("Could not connect to PGVector database server") from e + traceback.print_exc() + raise RuntimeError("Could not connect to PGVector database server") from e async def shutdown(self) -> None: pass diff --git a/llama_stack/providers/impls/meta_reference/memory/faiss.py b/llama_stack/providers/impls/meta_reference/memory/faiss.py index 0d31d40e9..ee716430e 100644 --- a/llama_stack/providers/impls/meta_reference/memory/faiss.py +++ b/llama_stack/providers/impls/meta_reference/memory/faiss.py @@ -68,8 +68,7 @@ class FaissMemoryImpl(Memory): self.config = config self.cache = {} - async def initialize(self) -> None: - print("INIT meta-reference") + async def initialize(self) -> None: ... async def shutdown(self) -> None: ...