persist registered objects with distribution (#354)

* persist registered objects with distribution

* linter fixes

* comment

* use annotate and field discriminator

* workign tests

* donot use global state

* precommit failures fixed

* add back Any

* fix imports

* remove unnecessary changes in ollama

* precommit failures fixed

* make kvstore configurable for dist and rename registry

* add comment about registry list return

* fix linter errors

* use registry to hydrate

* remove debug print

* linter fixes

* remove kvstore.db

* rename distribution_registry_store

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-04 17:25:06 -08:00 committed by GitHub
parent c9bf1d7d0b
commit 663883cc29
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 401 additions and 46 deletions

View file

@ -13,6 +13,7 @@ from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.datatypes import * # noqa: F403
@ -46,25 +47,23 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
Registry = Dict[str, List[RoutableObjectWithProvider]]
# TODO: this routing table maintains state in memory purely. We need to
# add persistence to it when we add dynamic registration of objects.
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
impls_by_provider_id: Dict[str, RoutedProtocol],
dist_registry: DistributionRegistry,
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
async def initialize(self) -> None:
self.registry: Registry = {}
# Initialize the registry if not already done
await self.dist_registry.initialize()
def add_objects(
async def add_objects(
objs: List[RoutableObjectWithProvider], provider_id: str, cls
) -> None:
for obj in objs:
if obj.identifier not in self.registry:
self.registry[obj.identifier] = []
if cls is None:
obj.provider_id = provider_id
else:
@ -74,34 +73,35 @@ class CommonRoutingTableImpl(RoutingTable):
obj.provider_id = provider_id
else:
obj = cls(**obj.model_dump(), provider_id=provider_id)
self.registry[obj.identifier].append(obj)
await self.dist_registry.register(obj)
# Register all objects from providers
for pid, p in self.impls_by_provider_id.items():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
models = await p.list_models()
add_objects(models, pid, ModelDefWithProvider)
await add_objects(models, pid, ModelDefWithProvider)
elif api == Api.safety:
p.shield_store = self
shields = await p.list_shields()
add_objects(shields, pid, ShieldDefWithProvider)
await add_objects(shields, pid, ShieldDefWithProvider)
elif api == Api.memory:
p.memory_bank_store = self
memory_banks = await p.list_memory_banks()
add_objects(memory_banks, pid, None)
await add_objects(memory_banks, pid, None)
elif api == Api.datasetio:
p.dataset_store = self
datasets = await p.list_datasets()
add_objects(datasets, pid, DatasetDefWithProvider)
await add_objects(datasets, pid, DatasetDefWithProvider)
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
await add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
@ -124,39 +124,44 @@ class CommonRoutingTableImpl(RoutingTable):
else:
raise ValueError("Unknown routing table type")
if routing_key not in self.registry:
# Get objects from disk registry
objects = self.dist_registry.get_cached(routing_key)
if not objects:
apiname, objname = apiname_object()
raise ValueError(
f"`{routing_key}` not registered. Make sure there is an {apiname} provider serving this {objname}."
)
objs = self.registry[routing_key]
for obj in objs:
for obj in objects:
if not provider_id or provider_id == obj.provider_id:
return self.impls_by_provider_id[obj.provider_id]
raise ValueError(f"Provider not found for `{routing_key}`")
def get_object_by_identifier(
async def get_object_by_identifier(
self, identifier: str
) -> Optional[RoutableObjectWithProvider]:
objs = self.registry.get(identifier, [])
if not objs:
# Get from disk registry
objects = await self.dist_registry.get(identifier)
if not objects:
return None
# kind of ill-defined behavior here, but we'll just return the first one
return objs[0]
return objects[0]
async def register_object(self, obj: RoutableObjectWithProvider):
entries = self.registry.get(obj.identifier, [])
for entry in entries:
if entry.provider_id == obj.provider_id or not obj.provider_id:
# Get existing objects from registry
existing_objects = await self.dist_registry.get(obj.identifier)
# Check for existing registration
for existing_obj in existing_objects:
if existing_obj.provider_id == obj.provider_id or not obj.provider_id:
print(
f"`{obj.identifier}` already registered with `{entry.provider_id}`"
f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`"
)
return
# if provider_id is not specified, we'll pick an arbitrary one from existing entries
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@ -166,12 +171,7 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id]
await register_object_with_provider(obj, p)
if obj.identifier not in self.registry:
self.registry[obj.identifier] = []
self.registry[obj.identifier].append(obj)
# TODO: persist this to a store
await self.dist_registry.register(obj)
class ModelsRoutingTable(CommonRoutingTableImpl, Models):