Revert "migrate router for memory wip"

This reverts commit 08379f5214.
This commit is contained in:
Xi Yan 2024-09-21 12:42:09 -07:00
parent 3939611676
commit cf8bd10989
8 changed files with 45 additions and 213 deletions

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Tuple
from llama_stack.distribution.datatypes import Api
async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
from .memory import MemoryRouterImpl
impl = MemoryRouterImpl(inner_impls, deps)
await impl.initialize()
return impl

View file

@ -1,91 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Tuple
from llama_stack.distribution.datatypes import Api
from llama_stack.apis.memory import * # noqa: F403
class MemoryRouterImpl(Memory):
"""Routes to an provider based on the memory bank type"""
def __init__(
self,
inner_impls: List[Tuple[str, Any]],
deps: List[Api],
) -> None:
self.deps = deps
bank_types = [v.value for v in MemoryBankType]
self.providers = {}
for routing_key, provider_impl in inner_impls:
if routing_key not in bank_types:
raise ValueError(
f"Unknown routing key `{routing_key}` for memory bank type"
)
self.providers[routing_key] = provider_impl
self.bank_id_to_type = {}
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
for p in self.providers.values():
await p.shutdown()
def get_provider(self, bank_type):
if bank_type not in self.providers:
raise ValueError(f"Memory bank type {bank_type} not supported")
return self.providers[bank_type]
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
provider = self.get_provider(config.type)
bank = await provider.create_memory_bank(name, config, url)
self.bank_id_to_type[bank.bank_id] = config.type
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
provider = self.get_provider(bank_type)
return await provider.get_memory_bank(bank_id)
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
provider = self.get_provider(bank_type)
return await provider.insert_documents(bank_id, documents, ttl_seconds)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
provider = self.get_provider(bank_type)
return await provider.query_documents(bank_id, query, params)

View file

@ -291,56 +291,7 @@ async def resolve_impls_with_routing(
stack_run_config: StackRunConfig, stack_run_config: StackRunConfig,
) -> Dict[Api, Any]: ) -> Dict[Api, Any]:
all_providers = api_providers() raise NotImplementedError("This is not implemented yet")
specs = {}
for api_str in stack_run_config.apis_to_serve:
api = Api(api_str)
providers = all_providers[api]
# check for regular providers without routing
if api_str in stack_run_config.provider_map:
provider_map_entry = stack_run_config.provider_map[api_str]
if provider_map_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
specs[api] = providers[provider_map_entry.provider_id]
# check for routing table, we need to pass routing table to the router implementation
if api_str in stack_run_config.provider_routing_table:
router_entry = stack_run_config.provider_routing_table[api_str]
inner_specs = []
for rt_entry in router_entry:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_id])
specs[api] = RouterProviderSpec(
api=api,
module=f"llama_stack.distribution.routers.{api.value.lower()}",
api_dependencies=[],
inner_specs=inner_specs,
)
sorted_specs = topological_sort(specs.values())
impls = {}
for spec in sorted_specs:
api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies}
if api.value in stack_run_config.provider_map:
provider_config = stack_run_config.provider_map[api.value]
elif api.value in stack_run_config.provider_routing_table:
provider_config = stack_run_config.provider_routing_table[api.value]
else:
raise ValueError(f"Cannot find provider_config for Api {api.value}")
impl = await instantiate_provider(spec, deps, provider_config)
impls[api] = impl
return impls, specs
async def resolve_impls( async def resolve_impls(
@ -405,7 +356,6 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
if config.provider_routing_table is not None: if config.provider_routing_table is not None:
impls, specs = asyncio.run(resolve_impls_with_routing(config)) impls, specs = asyncio.run(resolve_impls_with_routing(config))
else: else:
# keeping this for backwards compatibility,could
impls, specs = asyncio.run(resolve_impls(config.provider_map)) impls, specs = asyncio.run(resolve_impls(config.provider_map))
if Api.telemetry in impls: if Api.telemetry in impls:

View file

@ -33,7 +33,6 @@ async def instantiate_provider(
assert isinstance(provider_config, GenericProviderConfig) assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)
print("!!!", provider_config)
config = config_type(**provider_config.config) config = config_type(**provider_config.config)
args = [config, deps] args = [config, deps]
elif isinstance(provider_spec, RouterProviderSpec): elif isinstance(provider_spec, RouterProviderSpec):

View file

@ -3,27 +3,26 @@ image_name: local
docker_image: null docker_image: null
conda_env: local conda_env: local
apis_to_serve: apis_to_serve:
# - inference - inference
- memory - memory
- telemetry
provider_map: provider_map:
telemetry: # use builtin-router as dummy field
provider_id: meta-reference memory: builtin-router
config: {} inference: builtin-router
provider_routing_table: provider_routing_table:
# inference: inference:
# - routing_key: Meta-Llama3.1-8B-Instruct - routing_key: Meta-Llama3.1-8B-Instruct
# provider_id: meta-reference provider_id: meta-reference
# config: config:
# model: Meta-Llama3.1-8B-Instruct model: Meta-Llama3.1-8B-Instruct
# quantization: null quantization: null
# torch_seed: null torch_seed: null
# max_seq_len: 4096 max_seq_len: 4096
# max_batch_size: 1 max_batch_size: 1
# - routing_key: Meta-Llama3.1-8B - routing_key: Meta-Llama3.1-8B
# provider_id: remote::ollama provider_id: remote::ollama
# config: config:
# url: http:ollama-url-1.com url: http:ollama-url-1.com
memory: memory:
- routing_key: keyvalue - routing_key: keyvalue
provider_id: remote::pgvector provider_id: remote::pgvector
@ -32,7 +31,6 @@ provider_routing_table:
port: 5432 port: 5432
db: vectordb db: vectordb
user: vectoruser user: vectoruser
password: xxxx
- routing_key: vector - routing_key: vector
provider_id: meta-reference provider_id: meta-reference
config: {} config: {}

View file

@ -128,36 +128,35 @@ class PGVectorMemoryAdapter(Memory):
self.cache = {} self.cache = {}
async def initialize(self) -> None: async def initialize(self) -> None:
print("Init PGVector!") try:
# try: self.conn = psycopg2.connect(
# self.conn = psycopg2.connect( host=self.config.host,
# host=self.config.host, port=self.config.port,
# port=self.config.port, database=self.config.db,
# database=self.config.db, user=self.config.user,
# user=self.config.user, password=self.config.password,
# password=self.config.password, )
# ) self.cursor = self.conn.cursor()
# self.cursor = self.conn.cursor()
# version = check_extension_version(self.cursor) version = check_extension_version(self.cursor)
# if version: if version:
# print(f"Vector extension version: {version}") print(f"Vector extension version: {version}")
# else: else:
# raise RuntimeError("Vector extension is not installed.") raise RuntimeError("Vector extension is not installed.")
# self.cursor.execute( self.cursor.execute(
# """ """
# CREATE TABLE IF NOT EXISTS metadata_store ( CREATE TABLE IF NOT EXISTS metadata_store (
# key TEXT PRIMARY KEY, key TEXT PRIMARY KEY,
# data JSONB data JSONB
# ) )
# """ """
# ) )
# except Exception as e: except Exception as e:
# import traceback import traceback
# traceback.print_exc() traceback.print_exc()
# raise RuntimeError("Could not connect to PGVector database server") from e raise RuntimeError("Could not connect to PGVector database server") from e
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass

View file

@ -68,8 +68,7 @@ class FaissMemoryImpl(Memory):
self.config = config self.config = config
self.cache = {} self.cache = {}
async def initialize(self) -> None: async def initialize(self) -> None: ...
print("INIT meta-reference")
async def shutdown(self) -> None: ... async def shutdown(self) -> None: ...