mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
parent
ee77431b64
commit
665ab1f812
4 changed files with 128 additions and 3 deletions
|
@ -365,8 +365,20 @@ async def resolve_impls(
|
||||||
)
|
)
|
||||||
specs[api] = providers[item.provider_id]
|
specs[api] = providers[item.provider_id]
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
assert isinstance(item, list)
|
||||||
f"Please define routing table in provider_routing_table of run config"
|
inner_specs = []
|
||||||
|
for rt_entry in item:
|
||||||
|
if rt_entry.provider_id not in providers:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
||||||
|
)
|
||||||
|
inner_specs.append(providers[rt_entry.provider_id])
|
||||||
|
|
||||||
|
specs[api] = RouterProviderSpec(
|
||||||
|
api=api,
|
||||||
|
module=f"llama_stack.providers.routers.{api.value.lower()}",
|
||||||
|
api_dependencies=[],
|
||||||
|
inner_specs=inner_specs,
|
||||||
)
|
)
|
||||||
|
|
||||||
sorted_specs = topological_sort(specs.values())
|
sorted_specs = topological_sort(specs.values())
|
||||||
|
@ -393,7 +405,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||||
if config.provider_routing_table is not None:
|
if config.provider_routing_table is not None:
|
||||||
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
||||||
else:
|
else:
|
||||||
# keeping this for backwards compatibility
|
# keeping this for backwards compatibility,could
|
||||||
impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
||||||
|
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
|
|
5
llama_stack/providers/routers/__init__.py
Normal file
5
llama_stack/providers/routers/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
17
llama_stack/providers/routers/memory/__init__.py
Normal file
17
llama_stack/providers/routers/memory/__init__.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
|
||||||
|
|
||||||
|
async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
|
||||||
|
from .memory import MemoryRouterImpl
|
||||||
|
|
||||||
|
impl = MemoryRouterImpl(inner_impls, deps)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
91
llama_stack/providers/routers/memory/memory.py
Normal file
91
llama_stack/providers/routers/memory/memory.py
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRouterImpl(Memory):
|
||||||
|
"""Routes to an provider based on the memory bank type"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inner_impls: List[Tuple[str, Any]],
|
||||||
|
deps: List[Api],
|
||||||
|
) -> None:
|
||||||
|
self.deps = deps
|
||||||
|
|
||||||
|
bank_types = [v.value for v in MemoryBankType]
|
||||||
|
|
||||||
|
self.providers = {}
|
||||||
|
for routing_key, provider_impl in inner_impls:
|
||||||
|
if routing_key not in bank_types:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown routing key `{routing_key}` for memory bank type"
|
||||||
|
)
|
||||||
|
self.providers[routing_key] = provider_impl
|
||||||
|
|
||||||
|
self.bank_id_to_type = {}
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
for p in self.providers.values():
|
||||||
|
await p.shutdown()
|
||||||
|
|
||||||
|
def get_provider(self, bank_type):
|
||||||
|
if bank_type not in self.providers:
|
||||||
|
raise ValueError(f"Memory bank type {bank_type} not supported")
|
||||||
|
|
||||||
|
return self.providers[bank_type]
|
||||||
|
|
||||||
|
async def create_memory_bank(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
config: MemoryBankConfig,
|
||||||
|
url: Optional[URL] = None,
|
||||||
|
) -> MemoryBank:
|
||||||
|
provider = self.get_provider(config.type)
|
||||||
|
bank = await provider.create_memory_bank(name, config, url)
|
||||||
|
self.bank_id_to_type[bank.bank_id] = config.type
|
||||||
|
return bank
|
||||||
|
|
||||||
|
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||||
|
bank_type = self.bank_id_to_type.get(bank_id)
|
||||||
|
if not bank_type:
|
||||||
|
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||||
|
|
||||||
|
provider = self.get_provider(bank_type)
|
||||||
|
return await provider.get_memory_bank(bank_id)
|
||||||
|
|
||||||
|
async def insert_documents(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
documents: List[MemoryBankDocument],
|
||||||
|
ttl_seconds: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
bank_type = self.bank_id_to_type.get(bank_id)
|
||||||
|
if not bank_type:
|
||||||
|
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||||
|
|
||||||
|
provider = self.get_provider(bank_type)
|
||||||
|
return await provider.insert_documents(bank_id, documents, ttl_seconds)
|
||||||
|
|
||||||
|
async def query_documents(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
query: InterleavedTextMedia,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
|
bank_type = self.bank_id_to_type.get(bank_id)
|
||||||
|
if not bank_type:
|
||||||
|
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||||
|
|
||||||
|
provider = self.get_provider(bank_type)
|
||||||
|
return await provider.query_documents(bank_id, query, params)
|
Loading…
Add table
Add a link
Reference in a new issue