mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-08 19:10:56 +00:00
Introduce model_store, shield_store, memory_bank_store
This commit is contained in:
parent
e45a417543
commit
91e0063593
19 changed files with 172 additions and 297 deletions
|
|
@ -15,6 +15,20 @@ from llama_stack.apis.memory_banks import * # noqa: F403
|
|||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def get_impl_api(p: Any) -> Api:
|
||||
return p.__provider_spec__.api
|
||||
|
||||
|
||||
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
|
||||
api = get_impl_api(p)
|
||||
if api == Api.inference:
|
||||
await p.register_model(obj)
|
||||
elif api == Api.safety:
|
||||
await p.register_shield(obj)
|
||||
elif api == Api.memory:
|
||||
await p.register_memory_bank(obj)
|
||||
|
||||
|
||||
# TODO: this routing table maintains state in memory purely. We need to
|
||||
# add persistence to it when we add dynamic registration of objects.
|
||||
class CommonRoutingTableImpl(RoutingTable):
|
||||
|
|
@ -32,6 +46,15 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
self.impls_by_provider_id = impls_by_provider_id
|
||||
self.registry = registry
|
||||
|
||||
for p in self.impls_by_provider_id.values():
|
||||
api = get_impl_api(p)
|
||||
if api == Api.inference:
|
||||
p.model_store = self
|
||||
elif api == Api.safety:
|
||||
p.shield_store = self
|
||||
elif api == Api.memory:
|
||||
p.memory_bank_store = self
|
||||
|
||||
self.routing_key_to_object = {}
|
||||
for obj in self.registry:
|
||||
self.routing_key_to_object[obj.identifier] = obj
|
||||
|
|
@ -39,7 +62,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
async def initialize(self) -> None:
|
||||
for obj in self.registry:
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
await self.register_object(obj, p)
|
||||
await register_object_with_provider(obj, p)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for p in self.impls_by_provider_id.values():
|
||||
|
|
@ -57,7 +80,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
return obj
|
||||
return None
|
||||
|
||||
async def register_object_common(self, obj: RoutableObject) -> None:
|
||||
async def register_object(self, obj: RoutableObject) -> Any:
|
||||
if obj.identifier in self.routing_key_to_object:
|
||||
raise ValueError(f"Object `{obj.identifier}` already registered")
|
||||
|
||||
|
|
@ -65,16 +88,13 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
await p.register_object(obj)
|
||||
await register_object_with_provider(obj, p)
|
||||
|
||||
self.routing_key_to_object[obj.identifier] = obj
|
||||
self.registry.append(obj)
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def register_object(self, obj: ModelDef, p: Inference) -> None:
|
||||
await p.register_model(obj)
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
return self.registry
|
||||
|
||||
|
|
@ -82,13 +102,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
return self.get_object_by_identifier(identifier)
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
await self.register_object_common(model)
|
||||
await self.register_object(model)
|
||||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
async def register_object(self, obj: ShieldDef, p: Safety) -> None:
|
||||
await p.register_shield(obj)
|
||||
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return self.registry
|
||||
|
||||
|
|
@ -96,13 +113,10 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
return self.get_object_by_identifier(shield_type)
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
await self.register_object_common(shield)
|
||||
await self.register_object(shield)
|
||||
|
||||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
async def register_object(self, obj: MemoryBankDef, p: Memory) -> None:
|
||||
await p.register_memory_bank(obj)
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
return self.registry
|
||||
|
||||
|
|
@ -110,4 +124,4 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
|||
return self.get_object_by_identifier(identifier)
|
||||
|
||||
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
|
||||
await self.register_object_common(bank)
|
||||
await self.register_object(bank)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue