donot use global state

This commit is contained in:
Dinesh Yeduguru 2024-11-01 14:19:54 -07:00 committed by Dinesh Yeduguru
parent 4b6367838f
commit 26a14c1d92
5 changed files with 19 additions and 11 deletions

View file

@ -7,6 +7,8 @@
from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
from .routing_tables import (
DatasetsRoutingTable,
MemoryBanksRoutingTable,
@ -20,6 +22,7 @@ async def get_routing_table_impl(
api: Api,
impls_by_provider_id: Dict[str, RoutedProtocol],
_deps,
dist_registry: distribution_store.Registry,
) -> Any:
api_to_tables = {
"memory_banks": MemoryBanksRoutingTable,
@ -32,7 +35,7 @@ async def get_routing_table_impl(
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](impls_by_provider_id)
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
await impl.initialize()
return impl

View file

@ -53,8 +53,10 @@ class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
impls_by_provider_id: Dict[str, RoutedProtocol],
dist_registry: distribution_store.Registry,
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
async def initialize(self) -> None:
self.registry: Registry = {}
@ -171,7 +173,7 @@ class CommonRoutingTableImpl(RoutingTable):
if obj.identifier not in self.registry:
self.registry[obj.identifier] = []
self.registry[obj.identifier].append(obj)
await distribution_store.REGISTRY.register(obj)
await self.dist_registry.register(obj)
class ModelsRoutingTable(CommonRoutingTableImpl, Models):