From 665ab1f812efa6982acac1c0142b56dd16e94e48 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sat, 21 Sep 2024 12:42:09 -0700 Subject: [PATCH] Revert "delete router from providers" This reverts commit d8fab77a4f71ec1051a17a9d75bdbfd398679c1c. --- llama_stack/distribution/server/server.py | 18 +++- llama_stack/providers/routers/__init__.py | 5 + .../providers/routers/memory/__init__.py | 17 ++++ .../providers/routers/memory/memory.py | 91 +++++++++++++++++++ 4 files changed, 128 insertions(+), 3 deletions(-) create mode 100644 llama_stack/providers/routers/__init__.py create mode 100644 llama_stack/providers/routers/memory/__init__.py create mode 100644 llama_stack/providers/routers/memory/memory.py diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index aef007dce..57836caaa 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -365,8 +365,20 @@ async def resolve_impls( ) specs[api] = providers[item.provider_id] else: - raise ValueError( - f"Please define routing table in provider_routing_table of run config" + assert isinstance(item, list) + inner_specs = [] + for rt_entry in item: + if rt_entry.provider_id not in providers: + raise ValueError( + f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`" + ) + inner_specs.append(providers[rt_entry.provider_id]) + + specs[api] = RouterProviderSpec( + api=api, + module=f"llama_stack.providers.routers.{api.value.lower()}", + api_dependencies=[], + inner_specs=inner_specs, ) sorted_specs = topological_sort(specs.values()) @@ -393,7 +405,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): if config.provider_routing_table is not None: impls, specs = asyncio.run(resolve_impls_with_routing(config)) else: - # keeping this for backwards compatibility + # keeping this for backwards compatibility,could impls, specs = asyncio.run(resolve_impls(config.provider_map)) if Api.telemetry in 
# File: llama_stack/providers/routers/memory/__init__.py
from typing import Any, List, Tuple

from llama_stack.distribution.datatypes import Api


async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
    """Build the memory router and return it once it has been initialized.

    Args:
        inner_impls: ``(routing_key, provider_impl)`` pairs handed straight
            to ``MemoryRouterImpl``.
        deps: API dependencies, forwarded unchanged to ``MemoryRouterImpl``.

    Returns:
        The ``MemoryRouterImpl`` instance, after ``initialize()`` has been
        awaited on it.
    """
    # The implementation module is imported at call time, not at package
    # import time.
    from .memory import MemoryRouterImpl

    router = MemoryRouterImpl(inner_impls, deps)
    await router.initialize()
    return router
# File: llama_stack/providers/routers/memory/memory.py
from typing import Any, Dict, List, Tuple

from llama_stack.distribution.datatypes import Api
from llama_stack.apis.memory import *  # noqa: F403


class MemoryRouterImpl(Memory):
    """Routes Memory API calls to a provider based on the memory bank type.

    Each inner provider is registered under a routing key that must be one of
    the ``MemoryBankType`` values. The router remembers, in memory, which bank
    type every bank it created belongs to, and uses that to dispatch later
    per-bank calls to the matching provider.
    """

    def __init__(
        self,
        inner_impls: List[Tuple[str, Any]],
        deps: List[Api],
    ) -> None:
        """Register the inner providers by routing key.

        Args:
            inner_impls: ``(routing_key, provider_impl)`` pairs; every routing
                key must equal the value of some ``MemoryBankType`` member.
            deps: API dependencies; stored on the instance as ``self.deps``.

        Raises:
            ValueError: If a routing key is not a known memory bank type.
        """
        self.deps = deps

        bank_types = [v.value for v in MemoryBankType]

        # routing key (bank type value) -> provider implementation
        self.providers = {}
        for routing_key, provider_impl in inner_impls:
            if routing_key not in bank_types:
                raise ValueError(
                    f"Unknown routing key `{routing_key}` for memory bank type"
                )
            self.providers[routing_key] = provider_impl

        # bank_id -> bank type; filled only by create_memory_bank(), so banks
        # that were not created through this instance cannot be routed.
        self.bank_id_to_type = {}

    async def initialize(self) -> None:
        """No-op: the router itself has nothing to set up."""
        pass

    async def shutdown(self) -> None:
        """Await ``shutdown()`` on every registered provider, one at a time."""
        for p in self.providers.values():
            await p.shutdown()

    def get_provider(self, bank_type):
        """Return the provider registered for ``bank_type``.

        Raises:
            ValueError: If no provider is registered for ``bank_type``.
        """
        if bank_type not in self.providers:
            raise ValueError(f"Memory bank type {bank_type} not supported")

        return self.providers[bank_type]

    def _provider_for_bank(self, bank_id: str):
        """Resolve the provider that owns ``bank_id``.

        Single home for the bank-id -> bank-type -> provider lookup shared by
        every per-bank operation.

        Raises:
            ValueError: If ``bank_id`` was never recorded by
                ``create_memory_bank`` on this instance, or if its bank type
                has no registered provider.
        """
        bank_type = self.bank_id_to_type.get(bank_id)
        if not bank_type:
            raise ValueError(f"Could not find bank type for {bank_id}")

        return self.get_provider(bank_type)

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        """Create a bank on the provider selected by ``config.type``.

        The new bank's id is recorded so that later calls for that bank are
        dispatched to the same provider.

        Raises:
            ValueError: If no provider is registered for ``config.type``.
        """
        provider = self.get_provider(config.type)
        bank = await provider.create_memory_bank(name, config, url)
        self.bank_id_to_type[bank.bank_id] = config.type
        return bank

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        """Fetch a bank from the provider that owns it.

        Note that an id unknown to this router raises instead of returning
        ``None``; whatever the owning provider returns is passed through.

        Raises:
            ValueError: If ``bank_id`` is unknown to this router.
        """
        provider = self._provider_for_bank(bank_id)
        return await provider.get_memory_bank(bank_id)

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        """Forward a document insert to the provider that owns ``bank_id``.

        Raises:
            ValueError: If ``bank_id`` is unknown to this router.
        """
        provider = self._provider_for_bank(bank_id)
        return await provider.insert_documents(bank_id, documents, ttl_seconds)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        """Forward a query to the provider that owns ``bank_id``.

        Raises:
            ValueError: If ``bank_id`` is unknown to this router.
        """
        provider = self._provider_for_bank(bank_id)
        return await provider.query_documents(bank_id, query, params)