inference registry updates

This commit is contained in:
Ashwin Bharambe 2024-10-05 22:25:48 -07:00 committed by Ashwin Bharambe
parent 4215cc9331
commit 59302a86df
12 changed files with 570 additions and 535 deletions

View file

@ -15,6 +15,8 @@ from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
# TODO: this routing table keeps its state purely in memory. We need to
# add persistence to it when we add dynamic registration of objects.
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
@ -54,7 +56,7 @@ class CommonRoutingTableImpl(RoutingTable):
return obj
return None
def register_object(self, obj: RoutableObject) -> None:
async def register_object_common(self, obj: RoutableObject) -> None:
if obj.identifier in self.routing_key_to_object:
raise ValueError(f"Object `{obj.identifier}` already registered")
@ -79,7 +81,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return self.get_object_by_identifier(identifier)
async def register_model(self, model: ModelDef) -> None:
await self.register_object(model)
await self.register_object_common(model)
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
@ -93,7 +95,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
return self.get_object_by_identifier(shield_type)
async def register_shield(self, shield: ShieldDef) -> None:
await self.register_object(shield)
await self.register_object_common(shield)
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
@ -107,4 +109,4 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
return self.get_object_by_identifier(identifier)
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
await self.register_object(bank)
await self.register_object_common(bank)