Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-01 16:24:44 +00:00)

add support for provider update and unregister for memory banks

parent 9b75e92852
commit e8b699797c

11 changed files with 240 additions and 65 deletions
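The hunks below rename the delete routes to unregister (/memory_banks/unregister, /models/unregister) and thread the operation through to the providers. As a rough sketch of how a client might exercise the new memory-bank route over HTTP, assuming a stack server at base_url and the httpx dependency the existing clients already use (the helper itself is hypothetical, not part of this commit):

import asyncio

import httpx


async def unregister_memory_bank(base_url: str, memory_bank_id: str) -> None:
    # POST to the route introduced below (replaces /memory_banks/delete).
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{base_url}/memory_banks/unregister",
            json={"memory_bank_id": memory_bank_id},
        )
        response.raise_for_status()


asyncio.run(unregister_memory_bank("http://localhost:5000", "test_bank"))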
@@ -154,5 +154,5 @@ class MemoryBanks(Protocol):
         provider_memory_bank_id: Optional[str] = None,
     ) -> MemoryBank: ...

-    @webmethod(route="/memory_banks/delete", method="POST")
-    async def delete_memory_bank(self, memory_bank_id: str) -> None: ...
+    @webmethod(route="/memory_banks/unregister", method="POST")
+    async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
@@ -82,7 +82,7 @@ class ModelsClient(Models):
             response.raise_for_status()
             return Model(**response.json())

-    async def delete_model(self, model_id: str) -> None:
+    async def unregister_model(self, model_id: str) -> None:
         async with httpx.AsyncClient() as client:
             response = await client.delete(
                 f"{self.base_url}/models/delete",
@@ -64,5 +64,5 @@ class Models(Protocol):
         metadata: Optional[Dict[str, Any]] = None,
     ) -> Model: ...

-    @webmethod(route="/models/delete", method="POST")
-    async def delete_model(self, model_id: str) -> None: ...
+    @webmethod(route="/models/unregister", method="POST")
+    async def unregister_model(self, model_id: str) -> None: ...
@@ -51,6 +51,24 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
         raise ValueError(f"Unknown API {api} for registering object with provider")


+async def update_object_with_provider(
+    obj: RoutableObject, p: Any
+) -> Optional[RoutableObject]:
+    api = get_impl_api(p)
+    if api == Api.memory:
+        return await p.update_memory_bank(obj)
+    else:
+        raise ValueError(f"Update not supported for {api}")
+
+
+async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
+    api = get_impl_api(p)
+    if api == Api.memory:
+        return await p.unregister_memory_bank(obj.identifier)
+    else:
+        raise ValueError(f"Unregister not supported for {api}")
+
+
 Registry = Dict[str, List[RoutableObjectWithProvider]]

@@ -148,14 +166,16 @@ class CommonRoutingTableImpl(RoutingTable):

         return obj

-    async def delete_object(self, obj: RoutableObjectWithProvider) -> None:
+    async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
         await self.dist_registry.delete(obj.type, obj.identifier)
-        # TODO: delete from provider
+        await unregister_object_from_provider(
+            obj, self.impls_by_provider_id[obj.provider_id]
+        )

     async def update_object(
         self, obj: RoutableObjectWithProvider
     ) -> RoutableObjectWithProvider:
-        registered_obj = await register_object_with_provider(
+        registered_obj = await update_object_with_provider(
             obj, self.impls_by_provider_id[obj.provider_id]
         )
         return await self.dist_registry.update(registered_obj or obj)
@@ -253,11 +273,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         registered_model = await self.update_object(updated_model)
         return registered_model

-    async def delete_model(self, model_id: str) -> None:
+    async def unregister_model(self, model_id: str) -> None:
         existing_model = await self.get_model(model_id)
         if existing_model is None:
             raise ValueError(f"Model {model_id} not found")
-        await self.delete_object(existing_model)
+        await self.unregister_object(existing_model)


 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
@@ -358,11 +378,11 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
         registered_bank = await self.update_object(updated_bank)
         return registered_bank

-    async def delete_memory_bank(self, memory_bank_id: str) -> None:
+    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         existing_bank = await self.get_memory_bank(memory_bank_id)
         if existing_bank is None:
             raise ValueError(f"Memory bank {memory_bank_id} not found")
-        await self.delete_object(existing_bank)
+        await self.unregister_object(existing_bank)


 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
@@ -55,6 +55,10 @@ class MemoryBanksProtocolPrivate(Protocol):

     async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...

+    async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
+
+    async def update_memory_bank(self, memory_bank: MemoryBank) -> None: ...
+

 class DatasetsProtocolPrivate(Protocol):
     async def register_dataset(self, dataset: Dataset) -> None: ...
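MemoryBanksProtocolPrivate now obliges providers to implement unregister_memory_bank and update_memory_bank alongside register_memory_bank. A minimal in-memory sketch of a conforming provider, purely for illustration (real providers also implement the Memory API and manage index state, as the Faiss/Chroma/PGVector hunks below show):

from typing import Dict


class InMemoryBanksProvider:
    # Hypothetical provider covering only the registration lifecycle.
    def __init__(self) -> None:
        self.banks: Dict[str, "MemoryBank"] = {}

    async def register_memory_bank(self, memory_bank: "MemoryBank") -> None:
        self.banks[memory_bank.identifier] = memory_bank

    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
        # Real providers also drop their index state here.
        self.banks.pop(memory_bank_id, None)

    async def update_memory_bank(self, memory_bank: "MemoryBank") -> None:
        # Same pattern the Chroma and PGVector adapters use: drop, re-register.
        await self.unregister_memory_bank(memory_bank.identifier)
        await self.register_memory_bank(memory_bank)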
@@ -99,7 +103,6 @@ class RoutingTable(Protocol):
     def get_provider_impl(self, routing_key: str) -> Any: ...


-# TODO: this can now be inlined into RemoteProviderSpec
 @json_schema_type
 class AdapterSpec(BaseModel):
     adapter_type: str = Field(
@@ -172,10 +175,12 @@ class RemoteProviderConfig(BaseModel):

 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    adapter: AdapterSpec = Field(
+    adapter: Optional[AdapterSpec] = Field(
+        default=None,
         description="""
 If some code is needed to convert the remote responses into Llama Stack compatible
-API responses, specify the adapter here.
+API responses, specify the adapter here. If not specified, it indicates the remote
+as being "Llama Stack compatible"
 """,
     )

@@ -185,21 +190,38 @@ API responses, specify the adapter here.

     @property
     def module(self) -> str:
-        return self.adapter.module
+        if self.adapter:
+            return self.adapter.module
+        return "llama_stack.distribution.client"

     @property
     def pip_packages(self) -> List[str]:
-        return self.adapter.pip_packages
+        if self.adapter:
+            return self.adapter.pip_packages
+        return []

     @property
     def provider_data_validator(self) -> Optional[str]:
-        return self.adapter.provider_data_validator
+        if self.adapter:
+            return self.adapter.provider_data_validator
+        return None


-def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
-    return RemoteProviderSpec(
-        api=api,
-        provider_type=f"remote::{adapter.adapter_type}",
-        config_class=adapter.config_class,
-        adapter=adapter,
+def is_passthrough(spec: ProviderSpec) -> bool:
+    return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
+
+
+# Can avoid this by using Pydantic computed_field
+def remote_provider_spec(
+    api: Api, adapter: Optional[AdapterSpec] = None
+) -> RemoteProviderSpec:
+    config_class = (
+        adapter.config_class
+        if adapter and adapter.config_class
+        else "llama_stack.distribution.datatypes.RemoteProviderConfig"
+    )
+    provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
+
+    return RemoteProviderSpec(
+        api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
     )
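Because adapter is now optional, remote_provider_spec can describe a passthrough remote that already speaks the Llama Stack API. A quick illustration of both call patterns (Api and AdapterSpec are the types from this module; the chromadb field values are illustrative, not taken from the commit):

passthrough = remote_provider_spec(Api.memory)
assert passthrough.provider_type == "remote"
assert is_passthrough(passthrough)  # no adapter: remote is Llama Stack compatible

adapted = remote_provider_spec(
    Api.memory,
    adapter=AdapterSpec(
        adapter_type="chromadb",
        pip_packages=["chromadb-client"],
        module="llama_stack.providers.adapters.memory.chroma",
    ),
)
assert adapted.provider_type == "remote::chromadb"
assert not is_passthrough(adapted)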
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import base64
+import json
 import logging

 from typing import Any, Dict, List, Optional
@@ -37,10 +39,57 @@ class FaissIndex(EmbeddingIndex):
     id_by_index: Dict[int, str]
     chunk_by_index: Dict[int, str]

-    def __init__(self, dimension: int):
+    def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
         self.index = faiss.IndexFlatL2(dimension)
         self.id_by_index = {}
         self.chunk_by_index = {}
+        self.kvstore = kvstore
+        self.bank_id = bank_id
+        self.initialize()
+
+    async def initialize(self) -> None:
+        if not self.kvstore or not self.bank_id:
+            return
+
+        # Load existing index data from kvstore
+        index_key = f"faiss_index:v1::{self.bank_id}"
+        stored_data = await self.kvstore.get(index_key)
+
+        if stored_data:
+            data = json.loads(stored_data)
+            self.id_by_index = {int(k): v for k, v in data["id_by_index"].items()}
+            self.chunk_by_index = {
+                int(k): Chunk.model_validate_json(v)
+                for k, v in data["chunk_by_index"].items()
+            }
+
+            # Load FAISS index
+            index_bytes = base64.b64decode(data["faiss_index"])
+            self.index = faiss.deserialize_index(index_bytes)
+
+    async def _save_index(self):
+        if not self.kvstore or not self.bank_id:
+            return
+
+        # Serialize FAISS index
+        index_bytes = faiss.serialize_index(self.index)
+
+        # Prepare data for storage
+        data = {
+            "id_by_index": self.id_by_index,
+            "chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
+            "faiss_index": base64.b64encode(index_bytes).decode(),
+        }
+
+        # Store in kvstore
+        index_key = f"faiss_index:v1::{self.bank_id}"
+        await self.kvstore.set(key=index_key, value=json.dumps(data))
+
+    async def delete(self):
+        if not self.kvstore or not self.bank_id:
+            return
+
+        await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")

     @tracing.span(name="add_chunks")
     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
@@ -51,6 +100,9 @@ class FaissIndex(EmbeddingIndex):

         self.index.add(np.array(embeddings).astype(np.float32))

+        # Save updated index
+        await self._save_index()
+
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
     ) -> QueryDocumentsResponse:
@@ -85,7 +137,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         for bank_data in stored_banks:
             bank = VectorMemoryBank.model_validate_json(bank_data)
             index = BankWithIndex(
-                bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
+                bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore)
             )
             self.cache[bank.identifier] = index

@@ -110,13 +162,28 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):

         # Store in cache
         index = BankWithIndex(
-            bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
+            bank=memory_bank,
+            index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
         )
         self.cache[memory_bank.identifier] = index

     async def list_memory_banks(self) -> List[MemoryBank]:
         return [i.bank for i in self.cache.values()]

+    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
+        await self.cache[memory_bank_id].index.delete()
+        del self.cache[memory_bank_id]
+        await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")
+
+    async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
+        # Not possible to update the index in place, so we delete and recreate
+        await self.cache[memory_bank.identifier].index.delete()
+
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            bank=memory_bank,
+            index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
+        )
+
     async def insert_documents(
         self,
         bank_id: str,
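_save_index() above persists the FAISS index and the chunk maps as one JSON value, base64-encoding the serialized index, and initialize() reverses it. A standalone sketch of that round trip with a toy 8-dimensional index (independent of the kvstore; note that faiss.deserialize_index expects a uint8 array, so the decoded bytes are wrapped with np.frombuffer here):

import base64

import faiss
import numpy as np

index = faiss.IndexFlatL2(8)
index.add(np.random.rand(4, 8).astype(np.float32))

# serialize_index returns a uint8 numpy array; b64encode reads its buffer.
blob = base64.b64encode(faiss.serialize_index(index)).decode()

restored = faiss.deserialize_index(
    np.frombuffer(base64.b64decode(blob), dtype=np.uint8)
)
assert restored.ntotal == index.ntotal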
@@ -67,6 +67,9 @@ class ChromaIndex(EmbeddingIndex):

         return QueryDocumentsResponse(chunks=chunks, scores=scores)

+    async def delete(self):
+        await self.client.delete_collection(self.collection.name)
+

 class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, url: str) -> None:
@@ -134,6 +137,14 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):

         return [i.bank for i in self.cache.values()]

+    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
+        await self.cache[memory_bank_id].index.delete()
+        del self.cache[memory_bank_id]
+
+    async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
+        await self.unregister_memory_bank(memory_bank.identifier)
+        await self.register_memory_bank(memory_bank)
+
     async def insert_documents(
         self,
         bank_id: str,
@@ -112,6 +112,9 @@ class PGVectorIndex(EmbeddingIndex):

         return QueryDocumentsResponse(chunks=chunks, scores=scores)

+    async def delete(self):
+        self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
+

 class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, config: PGVectorConfig) -> None:
@@ -177,6 +180,14 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         )
         self.cache[memory_bank.identifier] = index

+    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
+        await self.cache[memory_bank_id].index.delete()
+        del self.cache[memory_bank_id]
+
+    async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
+        await self.unregister_memory_bank(memory_bank.identifier)
+        await self.register_memory_bank(memory_bank)
+
     async def list_memory_banks(self) -> List[MemoryBank]:
         banks = load_models(self.cursor, VectorMemoryBank)
         for bank in banks:
@@ -54,4 +54,4 @@ class TestModelRegistration:
         assert updated_model.provider_resource_id != old_model.provider_resource_id

         # Cleanup
-        await models_impl.delete_model(model_id=model_id)
+        await models_impl.unregister_model(model_id=model_id)
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import uuid
+
 import pytest

 from llama_stack.apis.memory import *  # noqa: F403
@@ -43,9 +45,10 @@ def sample_documents():
     ]


-async def register_memory_bank(banks_impl: MemoryBanks):
+async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
+    bank_id = f"test_bank_{uuid.uuid4().hex}"
     return await banks_impl.register_memory_bank(
-        memory_bank_id="test_bank",
+        memory_bank_id=bank_id,
         params=VectorMemoryBankParams(
             embedding_model="all-MiniLM-L6-v2",
             chunk_size_in_tokens=512,
@@ -57,43 +60,70 @@ async def register_memory_bank(banks_impl: MemoryBanks):
 class TestMemory:
     @pytest.mark.asyncio
     async def test_banks_list(self, memory_stack):
-        # NOTE: this needs you to ensure that you are starting from a clean state
-        # but so far we don't have an unregister API unfortunately, so be careful
         _, banks_impl = memory_stack

+        # Register a test bank
+        registered_bank = await register_memory_bank(banks_impl)
+
+        try:
+            # Verify our bank shows up in list
+            response = await banks_impl.list_memory_banks()
+            assert isinstance(response, list)
+            assert any(
+                bank.memory_bank_id == registered_bank.memory_bank_id
+                for bank in response
+            )
+        finally:
+            # Clean up
+            await banks_impl.unregister_memory_bank(registered_bank.memory_bank_id)
+
+        # Verify our bank was removed
         response = await banks_impl.list_memory_banks()
-        assert isinstance(response, list)
-        assert len(response) == 0
+        assert all(
+            bank.memory_bank_id != registered_bank.memory_bank_id for bank in response
+        )

     @pytest.mark.asyncio
     async def test_banks_register(self, memory_stack):
-        # NOTE: this needs you to ensure that you are starting from a clean state
-        # but so far we don't have an unregister API unfortunately, so be careful
         _, banks_impl = memory_stack

-        await banks_impl.register_memory_bank(
-            memory_bank_id="test_bank_no_provider",
-            params=VectorMemoryBankParams(
-                embedding_model="all-MiniLM-L6-v2",
-                chunk_size_in_tokens=512,
-                overlap_size_in_tokens=64,
-            ),
-        )
-        response = await banks_impl.list_memory_banks()
-        assert isinstance(response, list)
-        assert len(response) == 1
-
-        # register same memory bank with same id again will fail
-        await banks_impl.register_memory_bank(
-            memory_bank_id="test_bank_no_provider",
-            params=VectorMemoryBankParams(
-                embedding_model="all-MiniLM-L6-v2",
-                chunk_size_in_tokens=512,
-                overlap_size_in_tokens=64,
-            ),
-        )
-        response = await banks_impl.list_memory_banks()
-        assert isinstance(response, list)
-        assert len(response) == 1
+        bank_id = f"test_bank_{uuid.uuid4().hex}"
+        try:
+            # Register initial bank
+            await banks_impl.register_memory_bank(
+                memory_bank_id=bank_id,
+                params=VectorMemoryBankParams(
+                    embedding_model="all-MiniLM-L6-v2",
+                    chunk_size_in_tokens=512,
+                    overlap_size_in_tokens=64,
+                ),
+            )
+
+            # Verify our bank exists
+            response = await banks_impl.list_memory_banks()
+            assert isinstance(response, list)
+            assert any(bank.memory_bank_id == bank_id for bank in response)
+
+            # Try registering same bank again
+            await banks_impl.register_memory_bank(
+                memory_bank_id=bank_id,
+                params=VectorMemoryBankParams(
+                    embedding_model="all-MiniLM-L6-v2",
+                    chunk_size_in_tokens=512,
+                    overlap_size_in_tokens=64,
+                ),
+            )
+
+            # Verify still only one instance of our bank
+            response = await banks_impl.list_memory_banks()
+            assert isinstance(response, list)
+            assert (
+                len([bank for bank in response if bank.memory_bank_id == bank_id]) == 1
+            )
+        finally:
+            # Clean up
+            await banks_impl.unregister_memory_bank(bank_id)

     @pytest.mark.asyncio
     async def test_query_documents(self, memory_stack, sample_documents):
|
@ -102,17 +132,23 @@ class TestMemory:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||||
|
|
||||||
await register_memory_bank(banks_impl)
|
registered_bank = await register_memory_bank(banks_impl)
|
||||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
await memory_impl.insert_documents(
|
||||||
|
registered_bank.memory_bank_id, sample_documents
|
||||||
|
)
|
||||||
|
|
||||||
query1 = "programming language"
|
query1 = "programming language"
|
||||||
response1 = await memory_impl.query_documents("test_bank", query1)
|
response1 = await memory_impl.query_documents(
|
||||||
|
registered_bank.memory_bank_id, query1
|
||||||
|
)
|
||||||
assert_valid_response(response1)
|
assert_valid_response(response1)
|
||||||
assert any("Python" in chunk.content for chunk in response1.chunks)
|
assert any("Python" in chunk.content for chunk in response1.chunks)
|
||||||
|
|
||||||
# Test case 3: Query with semantic similarity
|
# Test case 3: Query with semantic similarity
|
||||||
query3 = "AI and brain-inspired computing"
|
query3 = "AI and brain-inspired computing"
|
||||||
response3 = await memory_impl.query_documents("test_bank", query3)
|
response3 = await memory_impl.query_documents(
|
||||||
|
registered_bank.memory_bank_id, query3
|
||||||
|
)
|
||||||
assert_valid_response(response3)
|
assert_valid_response(response3)
|
||||||
assert any(
|
assert any(
|
||||||
"neural networks" in chunk.content.lower() for chunk in response3.chunks
|
"neural networks" in chunk.content.lower() for chunk in response3.chunks
|
||||||
|
@ -121,14 +157,18 @@ class TestMemory:
|
||||||
# Test case 4: Query with limit on number of results
|
# Test case 4: Query with limit on number of results
|
||||||
query4 = "computer"
|
query4 = "computer"
|
||||||
params4 = {"max_chunks": 2}
|
params4 = {"max_chunks": 2}
|
||||||
response4 = await memory_impl.query_documents("test_bank", query4, params4)
|
response4 = await memory_impl.query_documents(
|
||||||
|
registered_bank.memory_bank_id, query4, params4
|
||||||
|
)
|
||||||
assert_valid_response(response4)
|
assert_valid_response(response4)
|
||||||
assert len(response4.chunks) <= 2
|
assert len(response4.chunks) <= 2
|
||||||
|
|
||||||
# Test case 5: Query with threshold on similarity score
|
# Test case 5: Query with threshold on similarity score
|
||||||
query5 = "quantum computing" # Not directly related to any document
|
query5 = "quantum computing" # Not directly related to any document
|
||||||
params5 = {"score_threshold": 0.2}
|
params5 = {"score_threshold": 0.2}
|
||||||
response5 = await memory_impl.query_documents("test_bank", query5, params5)
|
response5 = await memory_impl.query_documents(
|
||||||
|
registered_bank.memory_bank_id, query5, params5
|
||||||
|
)
|
||||||
assert_valid_response(response5)
|
assert_valid_response(response5)
|
||||||
print("The scores are:", response5.scores)
|
print("The scores are:", response5.scores)
|
||||||
assert all(score >= 0.2 for score in response5.scores)
|
assert all(score >= 0.2 for score in response5.scores)
|
||||||
|
|
|
@@ -145,6 +145,10 @@ class EmbeddingIndex(ABC):
     ) -> QueryDocumentsResponse:
         raise NotImplementedError()

+    @abstractmethod
+    async def delete(self):
+        raise NotImplementedError()
+

 @dataclass
 class BankWithIndex: