migrate memory banks to Resource and new registration (#411)

* migrate memory banks to Resource and new registration

* address feedback

* address feedback

* fix tests

* pgvector fix

* pgvector fix v2

* remove auto discovery

* change register signature to make params required

* update client

* client fix

* use annotated union to parse

* remove base MemoryBank inheritence

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-11 17:10:44 -08:00 committed by GitHub
parent 6b9850e11b
commit 38cce97597
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 240 additions and 129 deletions

View file

@ -33,7 +33,7 @@ RoutingKey = Union[str, List[str]]
RoutableObject = Union[
Model,
Shield,
MemoryBankDef,
MemoryBank,
DatasetDef,
ScoringFnDef,
]
@ -43,7 +43,7 @@ RoutableObjectWithProvider = Annotated[
Union[
Model,
Shield,
MemoryBankDefWithProvider,
MemoryBank,
DatasetDefWithProvider,
ScoringFnDefWithProvider,
],

View file

@ -7,8 +7,8 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.datasetio.datasetio import DatasetIO
from llama_stack.apis.memory_banks.memory_banks import BankParams
from llama_stack.distribution.datatypes import RoutingTable
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
@ -32,8 +32,19 @@ class MemoryRouter(Memory):
async def shutdown(self) -> None:
pass
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
await self.routing_table.register_memory_bank(memory_bank)
async def register_memory_bank(
self,
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
provider_memorybank_id: Optional[str] = None,
) -> None:
await self.routing_table.register_memory_bank(
memory_bank_id,
params,
provider_id,
provider_memorybank_id,
)
async def insert_documents(
self,

View file

@ -6,6 +6,8 @@
from typing import Any, Dict, List, Optional
from pydantic import parse_obj_as
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
@ -89,8 +91,6 @@ class CommonRoutingTableImpl(RoutingTable):
elif api == Api.memory:
p.memory_bank_store = self
memory_banks = await p.list_memory_banks()
await add_objects(memory_banks, pid, None)
elif api == Api.datasetio:
p.dataset_store = self
@ -188,12 +188,6 @@ class CommonRoutingTableImpl(RoutingTable):
objs = await self.dist_registry.get_all()
return [obj for obj in objs if obj.type == type]
async def get_all_with_types(
self, types: List[str]
) -> List[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
return [obj for obj in objs if obj.type in types]
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[Model]:
@ -233,7 +227,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[Shield]:
return await self.get_all_with_type("shield")
return await self.get_all_with_type(ResourceType.shield.value)
async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier(identifier)
@ -270,25 +264,41 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
return await self.get_all_with_types(
[
MemoryBankType.vector.value,
MemoryBankType.keyvalue.value,
MemoryBankType.keyword.value,
MemoryBankType.graph.value,
]
)
async def list_memory_banks(self) -> List[MemoryBank]:
return await self.get_all_with_type(ResourceType.memory_bank.value)
async def get_memory_bank(
self, identifier: str
) -> Optional[MemoryBankDefWithProvider]:
return await self.get_object_by_identifier(identifier)
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
return await self.get_object_by_identifier(memory_bank_id)
async def register_memory_bank(
self, memory_bank: MemoryBankDefWithProvider
) -> None:
self,
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
provider_memorybank_id: Optional[str] = None,
) -> MemoryBank:
if provider_memorybank_id is None:
provider_memorybank_id = memory_bank_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this shield type
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
memory_bank = parse_obj_as(
MemoryBank,
{
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id,
"provider_resource_id": provider_memorybank_id,
**params.model_dump(),
},
)
await self.register_object(memory_bank)
return memory_bank
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

View file

@ -10,7 +10,7 @@ import pytest
import pytest_asyncio
from llama_stack.distribution.store import * # noqa F403
from llama_stack.apis.inference import Model
from llama_stack.apis.memory_banks import VectorMemoryBankDef
from llama_stack.apis.memory_banks import VectorMemoryBank
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from llama_stack.distribution.datatypes import * # noqa F403
@ -39,7 +39,7 @@ async def cached_registry(config):
@pytest.fixture
def sample_bank():
return VectorMemoryBankDef(
return VectorMemoryBank(
identifier="test_bank",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
@ -113,7 +113,7 @@ async def test_cached_registry_updates(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
new_bank = VectorMemoryBankDef(
new_bank = VectorMemoryBank(
identifier="test_bank_2",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=256,
@ -144,7 +144,7 @@ async def test_duplicate_provider_registration(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
original_bank = VectorMemoryBankDef(
original_bank = VectorMemoryBank(
identifier="test_bank_2",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=256,
@ -153,7 +153,7 @@ async def test_duplicate_provider_registration(config):
)
await cached_registry.register(original_bank)
duplicate_bank = VectorMemoryBankDef(
duplicate_bank = VectorMemoryBank(
identifier="test_bank_2",
embedding_model="different-model",
chunk_size_in_tokens=128,