unregister for memory banks and remove update API (#458)

The semantics of an Update on resources is very tricky to reason about
especially for memory banks and models. The best way to go forward here
is for the user to unregister and register a new resource. We don't have
a compelling reason to support update APIs.


Tests:
pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m
"chroma" --env CHROMA_HOST=localhost --env CHROMA_PORT=8000

pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m
"pgvector" --env PGVECTOR_DB=postgres --env PGVECTOR_USER=postgres --env
PGVECTOR_PASSWORD=mysecretpassword --env PGVECTOR_HOST=0.0.0.0

$CONDA_PREFIX/bin/pytest -v -s -m "ollama"
llama_stack/providers/tests/inference/test_model_registration.py

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-14 17:12:11 -08:00 committed by GitHub
parent 2eab3b7ed9
commit 0850ad656a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 286 additions and 250 deletions

View file

@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-14 12:51:12.176325"
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-14 17:04:24.301559"
},
"servers": [
{
@ -429,39 +429,6 @@
}
}
},
"/models/delete": {
"post": {
"responses": {
"200": {
"description": "OK"
}
},
"tags": [
"Models"
],
"parameters": [
{
"name": "X-LlamaStack-ProviderData",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/DeleteModelRequest"
}
}
},
"required": true
}
}
},
"/inference/embeddings": {
"post": {
"responses": {
@ -2259,18 +2226,44 @@
}
}
},
"/models/update": {
"/memory_banks/unregister": {
"post": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Model"
}
"description": "OK"
}
},
"tags": [
"MemoryBanks"
],
"parameters": [
{
"name": "X-LlamaStack-ProviderData",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/UnregisterMemoryBankRequest"
}
}
},
"required": true
}
}
},
"/models/unregister": {
"post": {
"responses": {
"200": {
"description": "OK"
}
},
"tags": [
@ -2291,7 +2284,7 @@
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/UpdateModelRequest"
"$ref": "#/components/schemas/UnregisterModelRequest"
}
}
},
@ -4622,18 +4615,6 @@
"session_id"
]
},
"DeleteModelRequest": {
"type": "object",
"properties": {
"model_id": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"model_id"
]
},
"EmbeddingsRequest": {
"type": "object",
"properties": {
@ -7912,42 +7893,23 @@
],
"title": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."
},
"UpdateModelRequest": {
"UnregisterMemoryBankRequest": {
"type": "object",
"properties": {
"memory_bank_id": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"memory_bank_id"
]
},
"UnregisterModelRequest": {
"type": "object",
"properties": {
"model_id": {
"type": "string"
},
"provider_model_id": {
"type": "string"
},
"provider_id": {
"type": "string"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
@ -8132,10 +8094,6 @@
"name": "DeleteAgentsSessionRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/DeleteAgentsSessionRequest\" />"
},
{
"name": "DeleteModelRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/DeleteModelRequest\" />"
},
{
"name": "DoraFinetuningConfig",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/DoraFinetuningConfig\" />"
@ -8563,12 +8521,16 @@
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/URL\" />"
},
{
"name": "UnstructuredLogEvent",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />"
"name": "UnregisterMemoryBankRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterMemoryBankRequest\" />"
},
{
"name": "UpdateModelRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UpdateModelRequest\" />"
"name": "UnregisterModelRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterModelRequest\" />"
},
{
"name": "UnstructuredLogEvent",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />"
},
{
"name": "UserMessage",
@ -8657,7 +8619,6 @@
"Dataset",
"DeleteAgentsRequest",
"DeleteAgentsSessionRequest",
"DeleteModelRequest",
"DoraFinetuningConfig",
"EmbeddingsRequest",
"EmbeddingsResponse",
@ -8754,8 +8715,9 @@
"TrainingConfig",
"Turn",
"URL",
"UnregisterMemoryBankRequest",
"UnregisterModelRequest",
"UnstructuredLogEvent",
"UpdateModelRequest",
"UserMessage",
"VectorMemoryBank",
"VectorMemoryBankParams",

View file

@ -867,14 +867,6 @@ components:
- agent_id
- session_id
type: object
DeleteModelRequest:
additionalProperties: false
properties:
model_id:
type: string
required:
- model_id
type: object
DoraFinetuningConfig:
additionalProperties: false
properties:
@ -3244,6 +3236,22 @@ components:
format: uri
pattern: ^(https?://|file://|data:)
type: string
UnregisterMemoryBankRequest:
additionalProperties: false
properties:
memory_bank_id:
type: string
required:
- memory_bank_id
type: object
UnregisterModelRequest:
additionalProperties: false
properties:
model_id:
type: string
required:
- model_id
type: object
UnstructuredLogEvent:
additionalProperties: false
properties:
@ -3280,28 +3288,6 @@ components:
- message
- severity
type: object
UpdateModelRequest:
additionalProperties: false
properties:
metadata:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
model_id:
type: string
provider_id:
type: string
provider_model_id:
type: string
required:
- model_id
type: object
UserMessage:
additionalProperties: false
properties:
@ -3414,7 +3400,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-11-14 12:51:12.176325"
\ draft and subject to change.\n Generated at 2024-11-14 17:04:24.301559"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@ -4216,7 +4202,7 @@ paths:
responses: {}
tags:
- MemoryBanks
/models/delete:
/memory_banks/unregister:
post:
parameters:
- description: JSON-encoded provider data which will be made available to the
@ -4230,13 +4216,13 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/DeleteModelRequest'
$ref: '#/components/schemas/UnregisterMemoryBankRequest'
required: true
responses:
'200':
description: OK
tags:
- Models
- MemoryBanks
/models/get:
get:
parameters:
@ -4307,7 +4293,7 @@ paths:
description: OK
tags:
- Models
/models/update:
/models/unregister:
post:
parameters:
- description: JSON-encoded provider data which will be made available to the
@ -4321,14 +4307,10 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/UpdateModelRequest'
$ref: '#/components/schemas/UnregisterModelRequest'
required: true
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/Model'
description: OK
tags:
- Models
@ -4960,9 +4942,6 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteAgentsSessionRequest"
/>
name: DeleteAgentsSessionRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteModelRequest"
/>
name: DeleteModelRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/DoraFinetuningConfig"
/>
name: DoraFinetuningConfig
@ -5257,12 +5236,15 @@ tags:
name: Turn
- description: <SchemaDefinition schemaRef="#/components/schemas/URL" />
name: URL
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterMemoryBankRequest"
/>
name: UnregisterMemoryBankRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterModelRequest"
/>
name: UnregisterModelRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent"
/>
name: UnstructuredLogEvent
- description: <SchemaDefinition schemaRef="#/components/schemas/UpdateModelRequest"
/>
name: UpdateModelRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/UserMessage" />
name: UserMessage
- description: <SchemaDefinition schemaRef="#/components/schemas/VectorMemoryBank"
@ -5338,7 +5320,6 @@ x-tagGroups:
- Dataset
- DeleteAgentsRequest
- DeleteAgentsSessionRequest
- DeleteModelRequest
- DoraFinetuningConfig
- EmbeddingsRequest
- EmbeddingsResponse
@ -5435,8 +5416,9 @@ x-tagGroups:
- TrainingConfig
- Turn
- URL
- UnregisterMemoryBankRequest
- UnregisterModelRequest
- UnstructuredLogEvent
- UpdateModelRequest
- UserMessage
- VectorMemoryBank
- VectorMemoryBankParams

View file

@ -144,3 +144,6 @@ class MemoryBanks(Protocol):
provider_id: Optional[str] = None,
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank: ...
@webmethod(route="/memory_banks/unregister", method="POST")
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...

View file

@ -7,7 +7,7 @@
import asyncio
import json
from typing import Any, Dict, List, Optional
from typing import List, Optional
import fire
import httpx
@ -61,28 +61,7 @@ class ModelsClient(Models):
return None
return Model(**j)
async def update_model(
self,
model_id: str,
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Model:
async with httpx.AsyncClient() as client:
response = await client.put(
f"{self.base_url}/models/update",
json={
"model_id": model_id,
"provider_model_id": provider_model_id,
"provider_id": provider_id,
"metadata": metadata,
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return Model(**response.json())
async def delete_model(self, model_id: str) -> None:
async def unregister_model(self, model_id: str) -> None:
async with httpx.AsyncClient() as client:
response = await client.delete(
f"{self.base_url}/models/delete",

View file

@ -55,14 +55,5 @@ class Models(Protocol):
metadata: Optional[Dict[str, Any]] = None,
) -> Model: ...
@webmethod(route="/models/update", method="POST")
async def update_model(
self,
model_id: str,
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Model: ...
@webmethod(route="/models/delete", method="POST")
async def delete_model(self, model_id: str) -> None: ...
@webmethod(route="/models/unregister", method="POST")
async def unregister_model(self, model_id: str) -> None: ...

View file

@ -51,6 +51,16 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
raise ValueError(f"Unknown API {api} for registering object with provider")
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.memory:
return await p.unregister_memory_bank(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
Registry = Dict[str, List[RoutableObjectWithProvider]]
@ -148,17 +158,11 @@ class CommonRoutingTableImpl(RoutingTable):
return obj
async def delete_object(self, obj: RoutableObjectWithProvider) -> None:
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
await self.dist_registry.delete(obj.type, obj.identifier)
# TODO: delete from provider
async def update_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
registered_obj = await register_object_with_provider(
await unregister_object_from_provider(
obj, self.impls_by_provider_id[obj.provider_id]
)
return await self.dist_registry.update(registered_obj)
async def register_object(
self, obj: RoutableObjectWithProvider
@ -232,32 +236,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
registered_model = await self.register_object(model)
return registered_model
async def update_model(
self,
model_id: str,
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Model:
async def unregister_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
updated_model = Model(
identifier=model_id,
provider_resource_id=provider_model_id
or existing_model.provider_resource_id,
provider_id=provider_id or existing_model.provider_id,
metadata=metadata or existing_model.metadata,
)
registered_model = await self.update_object(updated_model)
return registered_model
async def delete_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
await self.delete_object(existing_model)
await self.unregister_object(existing_model)
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
@ -333,6 +316,12 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
await self.register_object(memory_bank)
return memory_bank
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
existing_bank = await self.get_memory_bank(memory_bank_id)
if existing_bank is None:
raise ValueError(f"Memory bank {memory_bank_id} not found")
await self.unregister_object(existing_bank)
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> List[Dataset]:

View file

@ -45,6 +45,8 @@ class Api(Enum):
class ModelsProtocolPrivate(Protocol):
async def register_model(self, model: Model) -> None: ...
async def unregister_model(self, model_id: str) -> None: ...
class ShieldsProtocolPrivate(Protocol):
async def register_shield(self, shield: Shield) -> None: ...
@ -55,6 +57,8 @@ class MemoryBanksProtocolPrivate(Protocol):
async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
class DatasetsProtocolPrivate(Protocol):
async def register_dataset(self, dataset: Dataset) -> None: ...

View file

@ -71,6 +71,9 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model_id: str,

View file

@ -108,6 +108,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
return VLLMSamplingParams(**kwargs)
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model_id: str,

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import json
import logging
from typing import Any, Dict, List, Optional
@ -37,10 +39,52 @@ class FaissIndex(EmbeddingIndex):
id_by_index: Dict[int, str]
chunk_by_index: Dict[int, str]
def __init__(self, dimension: int):
def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
self.index = faiss.IndexFlatL2(dimension)
self.id_by_index = {}
self.chunk_by_index = {}
self.kvstore = kvstore
self.bank_id = bank_id
self.initialize()
async def initialize(self) -> None:
if not self.kvstore:
return
index_key = f"faiss_index:v1::{self.bank_id}"
stored_data = await self.kvstore.get(index_key)
if stored_data:
data = json.loads(stored_data)
self.id_by_index = {int(k): v for k, v in data["id_by_index"].items()}
self.chunk_by_index = {
int(k): Chunk.model_validate_json(v)
for k, v in data["chunk_by_index"].items()
}
index_bytes = base64.b64decode(data["faiss_index"])
self.index = faiss.deserialize_index(index_bytes)
async def _save_index(self):
if not self.kvstore or not self.bank_id:
return
index_bytes = faiss.serialize_index(self.index)
data = {
"id_by_index": self.id_by_index,
"chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
"faiss_index": base64.b64encode(index_bytes).decode(),
}
index_key = f"faiss_index:v1::{self.bank_id}"
await self.kvstore.set(key=index_key, value=json.dumps(data))
async def delete(self):
if not self.kvstore or not self.bank_id:
return
await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
@tracing.span(name="add_chunks")
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
@ -51,6 +95,9 @@ class FaissIndex(EmbeddingIndex):
self.index.add(np.array(embeddings).astype(np.float32))
# Save updated index
await self._save_index()
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
@ -85,7 +132,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
for bank_data in stored_banks:
bank = VectorMemoryBank.model_validate_json(bank_data)
index = BankWithIndex(
bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore)
)
self.cache[bank.identifier] = index
@ -110,13 +157,19 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
# Store in cache
index = BankWithIndex(
bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
bank=memory_bank,
index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
)
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBank]:
return [i.bank for i in self.cache.values()]
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")
async def insert_documents(
self,
bank_id: str,

View file

@ -93,6 +93,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model_id: str,

View file

@ -69,6 +69,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model: str,

View file

@ -58,6 +58,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model_id: str,

View file

@ -67,6 +67,9 @@ class ChromaIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
async def delete(self):
await self.client.delete_collection(self.collection.name)
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, url: str) -> None:
@ -134,6 +137,10 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
return [i.bank for i in self.cache.values()]
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
async def insert_documents(
self,
bank_id: str,

View file

@ -112,6 +112,9 @@ class PGVectorIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
async def delete(self):
self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: PGVectorConfig) -> None:
@ -177,6 +180,10 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
)
self.cache[memory_bank.identifier] = index
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
async def list_memory_banks(self) -> List[MemoryBank]:
banks = load_models(self.cursor, VectorMemoryBank)
for bank in banks:

View file

@ -54,4 +54,4 @@ class TestModelRegistration:
assert updated_model.provider_resource_id != old_model.provider_resource_id
# Cleanup
await models_impl.delete_model(model_id=model_id)
await models_impl.unregister_model(model_id=model_id)

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
import pytest
from llama_stack.apis.memory import * # noqa: F403
@ -43,9 +45,10 @@ def sample_documents():
]
async def register_memory_bank(banks_impl: MemoryBanks):
async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
bank_id = f"test_bank_{uuid.uuid4().hex}"
return await banks_impl.register_memory_bank(
memory_bank_id="test_bank",
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
@ -57,43 +60,70 @@ async def register_memory_bank(banks_impl: MemoryBanks):
class TestMemory:
@pytest.mark.asyncio
async def test_banks_list(self, memory_stack):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
_, banks_impl = memory_stack
# Register a test bank
registered_bank = await register_memory_bank(banks_impl)
try:
# Verify our bank shows up in list
response = await banks_impl.list_memory_banks()
assert isinstance(response, list)
assert any(
bank.memory_bank_id == registered_bank.memory_bank_id
for bank in response
)
finally:
# Clean up
await banks_impl.unregister_memory_bank(registered_bank.memory_bank_id)
# Verify our bank was removed
response = await banks_impl.list_memory_banks()
assert isinstance(response, list)
assert len(response) == 0
assert all(
bank.memory_bank_id != registered_bank.memory_bank_id for bank in response
)
@pytest.mark.asyncio
async def test_banks_register(self, memory_stack):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
_, banks_impl = memory_stack
await banks_impl.register_memory_bank(
memory_bank_id="test_bank_no_provider",
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
)
response = await banks_impl.list_memory_banks()
assert isinstance(response, list)
assert len(response) == 1
bank_id = f"test_bank_{uuid.uuid4().hex}"
# register same memory bank with same id again will fail
await banks_impl.register_memory_bank(
memory_bank_id="test_bank_no_provider",
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
)
response = await banks_impl.list_memory_banks()
assert isinstance(response, list)
assert len(response) == 1
try:
# Register initial bank
await banks_impl.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
)
# Verify our bank exists
response = await banks_impl.list_memory_banks()
assert isinstance(response, list)
assert any(bank.memory_bank_id == bank_id for bank in response)
# Try registering same bank again
await banks_impl.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
)
# Verify still only one instance of our bank
response = await banks_impl.list_memory_banks()
assert isinstance(response, list)
assert (
len([bank for bank in response if bank.memory_bank_id == bank_id]) == 1
)
finally:
# Clean up
await banks_impl.unregister_memory_bank(bank_id)
@pytest.mark.asyncio
async def test_query_documents(self, memory_stack, sample_documents):
@ -102,17 +132,23 @@ class TestMemory:
with pytest.raises(ValueError):
await memory_impl.insert_documents("test_bank", sample_documents)
await register_memory_bank(banks_impl)
await memory_impl.insert_documents("test_bank", sample_documents)
registered_bank = await register_memory_bank(banks_impl)
await memory_impl.insert_documents(
registered_bank.memory_bank_id, sample_documents
)
query1 = "programming language"
response1 = await memory_impl.query_documents("test_bank", query1)
response1 = await memory_impl.query_documents(
registered_bank.memory_bank_id, query1
)
assert_valid_response(response1)
assert any("Python" in chunk.content for chunk in response1.chunks)
# Test case 3: Query with semantic similarity
query3 = "AI and brain-inspired computing"
response3 = await memory_impl.query_documents("test_bank", query3)
response3 = await memory_impl.query_documents(
registered_bank.memory_bank_id, query3
)
assert_valid_response(response3)
assert any(
"neural networks" in chunk.content.lower() for chunk in response3.chunks
@ -121,14 +157,18 @@ class TestMemory:
# Test case 4: Query with limit on number of results
query4 = "computer"
params4 = {"max_chunks": 2}
response4 = await memory_impl.query_documents("test_bank", query4, params4)
response4 = await memory_impl.query_documents(
registered_bank.memory_bank_id, query4, params4
)
assert_valid_response(response4)
assert len(response4.chunks) <= 2
# Test case 5: Query with threshold on similarity score
query5 = "quantum computing" # Not directly related to any document
params5 = {"score_threshold": 0.2}
response5 = await memory_impl.query_documents("test_bank", query5, params5)
response5 = await memory_impl.query_documents(
registered_bank.memory_bank_id, query5, params5
)
assert_valid_response(response5)
print("The scores are:", response5.scores)
assert all(score >= 0.2 for score in response5.scores)

View file

@ -145,6 +145,10 @@ class EmbeddingIndex(ABC):
) -> QueryDocumentsResponse:
raise NotImplementedError()
@abstractmethod
async def delete(self):
raise NotImplementedError()
@dataclass
class BankWithIndex: