Redo the { models, shields, memory_banks } typeset

This commit is contained in:
Ashwin Bharambe 2024-10-05 08:41:36 -07:00 committed by Ashwin Bharambe
parent 6b094b72d3
commit f3923e3f0b
15 changed files with 588 additions and 454 deletions

View file

@ -4,23 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Tuple
from typing import Any, List
from llama_stack.distribution.datatypes import * # noqa: F403
from .routing_tables import (
MemoryBanksRoutingTable,
ModelsRoutingTable,
RoutableObject,
RoutedProtocol,
ShieldsRoutingTable,
)
async def get_routing_table_impl(
api: Api,
inner_impls: List[Tuple[str, Any]],
routing_table_config: Dict[str, List[RoutableProviderConfig]],
registry: List[RoutableObject],
impls_by_provider_id: Dict[str, RoutedProtocol],
_deps,
) -> Any:
from .routing_tables import (
MemoryBanksRoutingTable,
ModelsRoutingTable,
ShieldsRoutingTable,
)
api_to_tables = {
"memory_banks": MemoryBanksRoutingTable,
"models": ModelsRoutingTable,
@ -29,7 +30,7 @@ async def get_routing_table_impl(
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](inner_impls, routing_table_config)
impl = api_to_tables[api.value](registry, impls_by_provider_id)
await impl.initialize()
return impl