migrate memory banks to Resource and new registration

This commit is contained in:
Dinesh Yeduguru 2024-11-08 15:45:26 -08:00
parent b4416b72fd
commit c82f13bf9e
16 changed files with 178 additions and 104 deletions

View file

@ -33,7 +33,7 @@ RoutingKey = Union[str, List[str]]
RoutableObject = Union[
Model,
Shield,
MemoryBankDef,
MemoryBank,
DatasetDef,
ScoringFnDef,
]
@ -43,7 +43,7 @@ RoutableObjectWithProvider = Annotated[
Union[
Model,
Shield,
MemoryBankDefWithProvider,
MemoryBank,
DatasetDefWithProvider,
ScoringFnDefWithProvider,
],

View file

@ -32,8 +32,11 @@ class MemoryRouter(Memory):
async def shutdown(self) -> None:
pass
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
await self.routing_table.register_memory_bank(memory_bank)
async def register_memory_bank(
self,
request: RegistrationRequest,
) -> None:
await self.routing_table.register_memory_bank(request)
async def insert_documents(
self,

View file

@ -188,12 +188,6 @@ class CommonRoutingTableImpl(RoutingTable):
objs = await self.dist_registry.get_all()
return [obj for obj in objs if obj.type == type]
async def get_all_with_types(
self, types: List[str]
) -> List[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
return [obj for obj in objs if obj.type in types]
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[Model]:
@ -233,7 +227,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[Shield]:
return await self.get_all_with_type("shield")
return await self.get_all_with_type(ResourceType.shield.value)
async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier(identifier)
@ -270,25 +264,29 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
return await self.get_all_with_types(
[
MemoryBankType.vector.value,
MemoryBankType.keyvalue.value,
MemoryBankType.keyword.value,
MemoryBankType.graph.value,
]
)
async def list_memory_banks(self) -> List[MemoryBank]:
return await self.get_all_with_type(ResourceType.memory_bank.value)
async def get_memory_bank(
self, identifier: str
) -> Optional[MemoryBankDefWithProvider]:
return await self.get_object_by_identifier(identifier)
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
return await self.get_object_by_identifier(memory_bank_id)
async def register_memory_bank(
self, memory_bank: MemoryBankDefWithProvider
) -> None:
self,
request: RegistrationRequest,
) -> MemoryBank:
if request.provider_resource_id is None:
request.provider_resource_id = request.memory_bank_id
if request.provider_id is None:
# If provider_id not specified, use the only provider if it supports this shield type
if len(self.impls_by_provider_id) == 1:
request.provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
memory_bank = registration_request_to_memory_bank(request)
await self.register_object(memory_bank)
return memory_bank
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

View file

@ -10,7 +10,7 @@ import pytest
import pytest_asyncio
from llama_stack.distribution.store import * # noqa F403
from llama_stack.apis.inference import Model
from llama_stack.apis.memory_banks import VectorMemoryBankDef
from llama_stack.apis.memory_banks import VectorMemoryBank
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from llama_stack.distribution.datatypes import * # noqa F403
@ -39,7 +39,7 @@ async def cached_registry(config):
@pytest.fixture
def sample_bank():
return VectorMemoryBankDef(
return VectorMemoryBank(
identifier="test_bank",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
@ -113,7 +113,7 @@ async def test_cached_registry_updates(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
new_bank = VectorMemoryBankDef(
new_bank = VectorMemoryBank(
identifier="test_bank_2",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=256,
@ -144,7 +144,7 @@ async def test_duplicate_provider_registration(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
original_bank = VectorMemoryBankDef(
original_bank = VectorMemoryBank(
identifier="test_bank_2",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=256,
@ -153,7 +153,7 @@ async def test_duplicate_provider_registration(config):
)
await cached_registry.register(original_bank)
duplicate_bank = VectorMemoryBankDef(
duplicate_bank = VectorMemoryBank(
identifier="test_bank_2",
embedding_model="different-model",
chunk_size_in_tokens=128,