From e8b699797c336cf8729f08660a949285997568ff Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Thu, 14 Nov 2024 16:08:24 -0800
Subject: [PATCH] add support for provider update and unregister for memory
 banks

---
 llama_stack/apis/memory_banks/memory_banks.py |   4 +-
 llama_stack/apis/models/client.py             |   2 +-
 llama_stack/apis/models/models.py             |   4 +-
 .../distribution/routers/routing_tables.py    |  34 ++++--
 llama_stack/providers/datatypes.py            |  46 +++++--
 .../providers/inline/memory/faiss/faiss.py    |  73 ++++++++++-
 .../providers/remote/memory/chroma/chroma.py  |  11 ++
 .../remote/memory/pgvector/pgvector.py        |  11 ++
 .../inference/test_model_registration.py      |   2 +-
 .../providers/tests/memory/test_memory.py     | 114 ++++++++++++------
 .../providers/utils/memory/vector_store.py    |   4 +
 11 files changed, 240 insertions(+), 65 deletions(-)

diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py
index a754a0818..520bdcbae 100644
--- a/llama_stack/apis/memory_banks/memory_banks.py
+++ b/llama_stack/apis/memory_banks/memory_banks.py
@@ -154,5 +154,5 @@ class MemoryBanks(Protocol):
         provider_memory_bank_id: Optional[str] = None,
     ) -> MemoryBank: ...

-    @webmethod(route="/memory_banks/delete", method="POST")
-    async def delete_memory_bank(self, memory_bank_id: str) -> None: ...
+    @webmethod(route="/memory_banks/unregister", method="POST")
+    async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
diff --git a/llama_stack/apis/models/client.py b/llama_stack/apis/models/client.py
index aa63ca541..3f4f683b3 100644
--- a/llama_stack/apis/models/client.py
+++ b/llama_stack/apis/models/client.py
@@ -82,7 +82,7 @@ class ModelsClient(Models):
             response.raise_for_status()
             return Model(**response.json())

-    async def delete_model(self, model_id: str) -> None:
+    async def unregister_model(self, model_id: str) -> None:
         async with httpx.AsyncClient() as client:
             response = await client.delete(
                 f"{self.base_url}/models/delete",
diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index 5ffcde52f..19dae12a0 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -64,5 +64,5 @@ class Models(Protocol):
         metadata: Optional[Dict[str, Any]] = None,
     ) -> Model: ...

-    @webmethod(route="/models/delete", method="POST")
-    async def delete_model(self, model_id: str) -> None: ...
+    @webmethod(route="/models/unregister", method="POST")
+    async def unregister_model(self, model_id: str) -> None: ...
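Taken together, the three API changes above rename the public delete endpoints to unregister. A minimal sketch of what a caller might look like against the renamed memory-bank route, assuming a stack server on localhost:5000 and a JSON body keyed by parameter name (both are assumptions for illustration, not details guaranteed by this patch):

    import asyncio

    import httpx

    BASE_URL = "http://localhost:5000"  # assumed local stack endpoint


    async def unregister_memory_bank(memory_bank_id: str) -> None:
        # POST to the renamed route; /memory_banks/delete no longer exists.
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{BASE_URL}/memory_banks/unregister",
                json={"memory_bank_id": memory_bank_id},
            )
            response.raise_for_status()


    asyncio.run(unregister_memory_bank("my_bank"))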
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index b0082db18..d0d588a91 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -51,6 +51,24 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
         raise ValueError(f"Unknown API {api} for registering object with provider")


+async def update_object_with_provider(
+    obj: RoutableObject, p: Any
+) -> Optional[RoutableObject]:
+    api = get_impl_api(p)
+    if api == Api.memory:
+        return await p.update_memory_bank(obj)
+    else:
+        raise ValueError(f"Update not supported for {api}")
+
+
+async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
+    api = get_impl_api(p)
+    if api == Api.memory:
+        return await p.unregister_memory_bank(obj.identifier)
+    else:
+        raise ValueError(f"Unregister not supported for {api}")
+
+
 Registry = Dict[str, List[RoutableObjectWithProvider]]


@@ -148,14 +166,16 @@ class CommonRoutingTableImpl(RoutingTable):

         return obj

-    async def delete_object(self, obj: RoutableObjectWithProvider) -> None:
+    async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
         await self.dist_registry.delete(obj.type, obj.identifier)
-        # TODO: delete from provider
+        await unregister_object_from_provider(
+            obj, self.impls_by_provider_id[obj.provider_id]
+        )

     async def update_object(
         self, obj: RoutableObjectWithProvider
     ) -> RoutableObjectWithProvider:
-        registered_obj = await register_object_with_provider(
+        registered_obj = await update_object_with_provider(
             obj, self.impls_by_provider_id[obj.provider_id]
         )
         return await self.dist_registry.update(registered_obj or obj)
@@ -253,11 +273,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         registered_model = await self.update_object(updated_model)
         return registered_model

-    async def delete_model(self, model_id: str) -> None:
+    async def unregister_model(self, model_id: str) -> None:
         existing_model = await self.get_model(model_id)
         if existing_model is None:
             raise ValueError(f"Model {model_id} not found")
-        await self.delete_object(existing_model)
+        await self.unregister_object(existing_model)


 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
@@ -358,11 +378,11 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
         registered_bank = await self.update_object(updated_bank)
         return registered_bank

-    async def delete_memory_bank(self, memory_bank_id: str) -> None:
+    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         existing_bank = await self.get_memory_bank(memory_bank_id)
         if existing_bank is None:
             raise ValueError(f"Memory bank {memory_bank_id} not found")
-        await self.delete_object(existing_bank)
+        await self.unregister_object(existing_bank)


 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
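The two new helpers mirror register_object_with_provider: the routing table resolves the provider implementation from impls_by_provider_id and branches on the provider's API, raising for APIs that do not support the operation rather than silently succeeding. A self-contained sketch of that dispatch shape (the classes below are illustrative stand-ins, not the real llama_stack types):

    import asyncio
    from dataclasses import dataclass
    from enum import Enum
    from typing import Any, Dict


    class Api(Enum):
        memory = "memory"
        inference = "inference"


    @dataclass
    class RoutableObject:
        identifier: str
        provider_id: str
        api: Api


    async def unregister_object_from_provider(obj: RoutableObject, provider: Any) -> None:
        # Only memory providers support unregistration in this patch; anything
        # else is an explicit error rather than a silent no-op.
        if obj.api == Api.memory:
            await provider.unregister_memory_bank(obj.identifier)
        else:
            raise ValueError(f"Unregister not supported for {obj.api}")


    class FakeMemoryProvider:
        async def unregister_memory_bank(self, memory_bank_id: str) -> None:
            print(f"unregistered {memory_bank_id}")


    impls_by_provider_id: Dict[str, Any] = {"faiss-0": FakeMemoryProvider()}
    obj = RoutableObject("bank-1", "faiss-0", Api.memory)
    asyncio.run(
        unregister_object_from_provider(obj, impls_by_provider_id[obj.provider_id])
    )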
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 51ff163ab..05fc3a33a 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -55,6 +55,10 @@ class MemoryBanksProtocolPrivate(Protocol):

     async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...

+    async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
+
+    async def update_memory_bank(self, memory_bank: MemoryBank) -> None: ...
+

 class DatasetsProtocolPrivate(Protocol):
     async def register_dataset(self, dataset: Dataset) -> None: ...
@@ -99,7 +103,6 @@ class RoutingTable(Protocol):
     def get_provider_impl(self, routing_key: str) -> Any: ...


-# TODO: this can now be inlined into RemoteProviderSpec
 @json_schema_type
 class AdapterSpec(BaseModel):
     adapter_type: str = Field(
@@ -172,10 +175,12 @@ class RemoteProviderConfig(BaseModel):

 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    adapter: AdapterSpec = Field(
+    adapter: Optional[AdapterSpec] = Field(
+        default=None,
         description="""
 If some code is needed to convert the remote responses into Llama Stack compatible
-API responses, specify the adapter here.
+API responses, specify the adapter here. If not specified, it indicates the remote
+as being "Llama Stack compatible"
 """,
     )

@@ -185,21 +190,38 @@ API responses, specify the adapter here.

     @property
     def module(self) -> str:
-        return self.adapter.module
+        if self.adapter:
+            return self.adapter.module
+        return "llama_stack.distribution.client"

     @property
     def pip_packages(self) -> List[str]:
-        return self.adapter.pip_packages
+        if self.adapter:
+            return self.adapter.pip_packages
+        return []

     @property
     def provider_data_validator(self) -> Optional[str]:
-        return self.adapter.provider_data_validator
+        if self.adapter:
+            return self.adapter.provider_data_validator
+        return None


-def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
-    return RemoteProviderSpec(
-        api=api,
-        provider_type=f"remote::{adapter.adapter_type}",
-        config_class=adapter.config_class,
-        adapter=adapter,
+def is_passthrough(spec: ProviderSpec) -> bool:
+    return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
+
+
+# Can avoid this by using Pydantic computed_field
+def remote_provider_spec(
+    api: Api, adapter: Optional[AdapterSpec] = None
+) -> RemoteProviderSpec:
+    config_class = (
+        adapter.config_class
+        if adapter and adapter.config_class
+        else "llama_stack.distribution.datatypes.RemoteProviderConfig"
+    )
+    provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
+
+    return RemoteProviderSpec(
+        api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
     )
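The datatypes change makes adapter optional: a remote spec without an adapter is treated as a passthrough whose endpoint already speaks the Llama Stack API, so the generic client module and config class are used. A simplified dataclass restatement of that branching (the real classes are Pydantic models; this sketch only mirrors the logic):

    from dataclasses import dataclass, field
    from typing import List, Optional


    @dataclass
    class AdapterSpec:
        adapter_type: str
        module: str
        pip_packages: List[str] = field(default_factory=list)
        config_class: Optional[str] = None


    @dataclass
    class RemoteProviderSpec:
        api: str
        provider_type: str
        config_class: str
        adapter: Optional[AdapterSpec] = None

        @property
        def module(self) -> str:
            # No adapter: fall back to the generic Llama Stack client.
            if self.adapter:
                return self.adapter.module
            return "llama_stack.distribution.client"


    def remote_provider_spec(api: str, adapter: Optional[AdapterSpec] = None) -> RemoteProviderSpec:
        config_class = (
            adapter.config_class
            if adapter and adapter.config_class
            else "llama_stack.distribution.datatypes.RemoteProviderConfig"
        )
        provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
        return RemoteProviderSpec(api, provider_type, config_class, adapter)


    # With an adapter: the remote speaks its own protocol and needs translation.
    chroma = remote_provider_spec(
        "memory", AdapterSpec("chroma", "llama_stack.providers.remote.memory.chroma")
    )
    assert chroma.provider_type == "remote::chroma"

    # Without an adapter: a passthrough, "Llama Stack compatible" remote.
    passthrough = remote_provider_spec("memory")
    assert passthrough.provider_type == "remote"
    assert passthrough.module == "llama_stack.distribution.client"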
diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py
index 0790eb67d..9813f25ce 100644
--- a/llama_stack/providers/inline/memory/faiss/faiss.py
+++ b/llama_stack/providers/inline/memory/faiss/faiss.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import base64
+import json
 import logging

 from typing import Any, Dict, List, Optional
@@ -37,10 +39,57 @@ class FaissIndex(EmbeddingIndex):
     id_by_index: Dict[int, str]
     chunk_by_index: Dict[int, str]

-    def __init__(self, dimension: int):
+    def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
         self.index = faiss.IndexFlatL2(dimension)
         self.id_by_index = {}
         self.chunk_by_index = {}
+        self.kvstore = kvstore
+        self.bank_id = bank_id
+        self.initialize()
+
+    async def initialize(self) -> None:
+        if not self.kvstore or not self.bank_id:
+            return
+
+        # Load existing index data from kvstore
+        index_key = f"faiss_index:v1::{self.bank_id}"
+        stored_data = await self.kvstore.get(index_key)
+
+        if stored_data:
+            data = json.loads(stored_data)
+            self.id_by_index = {int(k): v for k, v in data["id_by_index"].items()}
+            self.chunk_by_index = {
+                int(k): Chunk.model_validate_json(v)
+                for k, v in data["chunk_by_index"].items()
+            }
+
+            # Load FAISS index
+            index_bytes = base64.b64decode(data["faiss_index"])
+            self.index = faiss.deserialize_index(index_bytes)
+
+    async def _save_index(self):
+        if not self.kvstore or not self.bank_id:
+            return
+
+        # Serialize FAISS index
+        index_bytes = faiss.serialize_index(self.index)
+
+        # Prepare data for storage
+        data = {
+            "id_by_index": self.id_by_index,
+            "chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
+            "faiss_index": base64.b64encode(index_bytes).decode(),
+        }
+
+        # Store in kvstore
+        index_key = f"faiss_index:v1::{self.bank_id}"
+        await self.kvstore.set(key=index_key, value=json.dumps(data))
+
+    async def delete(self):
+        if not self.kvstore or not self.bank_id:
+            return
+
+        await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")

     @tracing.span(name="add_chunks")
     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
@@ -51,6 +100,9 @@ class FaissIndex(EmbeddingIndex):

         self.index.add(np.array(embeddings).astype(np.float32))

+        # Save updated index
+        await self._save_index()
+
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
     ) -> QueryDocumentsResponse:
@@ -85,7 +137,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         for bank_data in stored_banks:
             bank = VectorMemoryBank.model_validate_json(bank_data)
             index = BankWithIndex(
-                bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
+                bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore)
             )
             self.cache[bank.identifier] = index
@@ -110,13 +162,28 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):

         # Store in cache
         index = BankWithIndex(
-            bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
+            bank=memory_bank,
+            index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
         )
         self.cache[memory_bank.identifier] = index

     async def list_memory_banks(self) -> List[MemoryBank]:
         return [i.bank for i in self.cache.values()]

+    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
+        await self.cache[memory_bank_id].index.delete()
+        del self.cache[memory_bank_id]
+        await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")
+
+    async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
+        # Not possible to update the index in place, so we delete and recreate
+        await self.cache[memory_bank.identifier].index.delete()
+
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            bank=memory_bank,
+            index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
+        )
+
     async def insert_documents(
         self,
         bank_id: str,
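The FaissIndex persistence above relies on FAISS being able to snapshot an index into a byte buffer, which base64 makes safe to embed in the JSON value written to the kvstore. A standalone round-trip sketch of that mechanism (requires faiss-cpu and numpy; the dimension and vectors are arbitrary, and note that deserialize_index expects a uint8 numpy buffer):

    import base64
    import json

    import faiss  # pip install faiss-cpu
    import numpy as np

    dimension = 8
    index = faiss.IndexFlatL2(dimension)
    index.add(np.random.rand(4, dimension).astype(np.float32))

    # Serialize: index -> uint8 array -> base64 text inside a JSON document.
    buf = faiss.serialize_index(index)
    payload = json.dumps({"faiss_index": base64.b64encode(buf.tobytes()).decode()})

    # Deserialize: JSON -> base64 -> bytes -> uint8 array -> working index.
    raw = base64.b64decode(json.loads(payload)["faiss_index"])
    restored = faiss.deserialize_index(np.frombuffer(raw, dtype=np.uint8))
    assert restored.ntotal == index.ntotal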
diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py
index 0611d9aa2..b74945f03 100644
--- a/llama_stack/providers/remote/memory/chroma/chroma.py
+++ b/llama_stack/providers/remote/memory/chroma/chroma.py
@@ -67,6 +67,9 @@ class ChromaIndex(EmbeddingIndex):

         return QueryDocumentsResponse(chunks=chunks, scores=scores)

+    async def delete(self):
+        await self.client.delete_collection(self.collection.name)
+

 class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, url: str) -> None:
@@ -134,6 +137,14 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):

         return [i.bank for i in self.cache.values()]

+    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
+        await self.cache[memory_bank_id].index.delete()
+        del self.cache[memory_bank_id]
+
+    async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
+        await self.unregister_memory_bank(memory_bank.identifier)
+        await self.register_memory_bank(memory_bank)
+
     async def insert_documents(
         self,
         bank_id: str,
diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py
index 9acfef2dc..8395e77b0 100644
--- a/llama_stack/providers/remote/memory/pgvector/pgvector.py
+++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py
@@ -112,6 +112,9 @@ class PGVectorIndex(EmbeddingIndex):

         return QueryDocumentsResponse(chunks=chunks, scores=scores)

+    async def delete(self):
+        self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
+

 class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, config: PGVectorConfig) -> None:
@@ -177,6 +180,14 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         )
         self.cache[memory_bank.identifier] = index

+    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
+        await self.cache[memory_bank_id].index.delete()
+        del self.cache[memory_bank_id]
+
+    async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
+        await self.unregister_memory_bank(memory_bank.identifier)
+        await self.register_memory_bank(memory_bank)
+
     async def list_memory_banks(self) -> List[MemoryBank]:
         banks = load_models(self.cursor, VectorMemoryBank)
         for bank in banks:
diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py
index 97f0ac576..0f07badfa 100644
--- a/llama_stack/providers/tests/inference/test_model_registration.py
+++ b/llama_stack/providers/tests/inference/test_model_registration.py
@@ -54,4 +54,4 @@ class TestModelRegistration:
         assert updated_model.provider_resource_id != old_model.provider_resource_id

         # Cleanup
-        await models_impl.delete_model(model_id=model_id)
+        await models_impl.unregister_model(model_id=model_id)
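Both remote adapters above implement update_memory_bank the same way: tear the bank down and re-register it, because a Chroma collection or pgvector table cannot be reshaped in place once created. A self-contained sketch of that delete-and-recreate pattern (the adapter and index classes are toy stand-ins):

    import asyncio
    from typing import Dict


    class FakeIndex:
        def __init__(self, name: str) -> None:
            self.name = name

        async def delete(self) -> None:
            # A real adapter drops the Chroma collection or pgvector table here.
            print(f"dropped backing store for {self.name}")


    class Adapter:
        def __init__(self) -> None:
            self.cache: Dict[str, FakeIndex] = {}

        async def register_memory_bank(self, bank_id: str) -> None:
            self.cache[bank_id] = FakeIndex(bank_id)

        async def unregister_memory_bank(self, bank_id: str) -> None:
            await self.cache[bank_id].delete()
            del self.cache[bank_id]

        async def update_memory_bank(self, bank_id: str) -> None:
            # No in-place reshape: tear down, then recreate from the new config.
            await self.unregister_memory_bank(bank_id)
            await self.register_memory_bank(bank_id)


    async def main() -> None:
        adapter = Adapter()
        await adapter.register_memory_bank("docs")
        await adapter.update_memory_bank("docs")


    asyncio.run(main())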
diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py
index 24cef8a24..b6e2e0a76 100644
--- a/llama_stack/providers/tests/memory/test_memory.py
+++ b/llama_stack/providers/tests/memory/test_memory.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import uuid
+
 import pytest

 from llama_stack.apis.memory import *  # noqa: F403
@@ -43,9 +45,10 @@ def sample_documents():
     ]


-async def register_memory_bank(banks_impl: MemoryBanks):
+async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
+    bank_id = f"test_bank_{uuid.uuid4().hex}"
     return await banks_impl.register_memory_bank(
-        memory_bank_id="test_bank",
+        memory_bank_id=bank_id,
         params=VectorMemoryBankParams(
             embedding_model="all-MiniLM-L6-v2",
             chunk_size_in_tokens=512,
@@ -57,43 +60,70 @@ class TestMemory:
     @pytest.mark.asyncio
     async def test_banks_list(self, memory_stack):
-        # NOTE: this needs you to ensure that you are starting from a clean state
-        # but so far we don't have an unregister API unfortunately, so be careful
         _, banks_impl = memory_stack
+
+        # Register a test bank
+        registered_bank = await register_memory_bank(banks_impl)
+
+        try:
+            # Verify our bank shows up in list
+            response = await banks_impl.list_memory_banks()
+            assert isinstance(response, list)
+            assert any(
+                bank.memory_bank_id == registered_bank.memory_bank_id
+                for bank in response
+            )
+        finally:
+            # Clean up
+            await banks_impl.unregister_memory_bank(registered_bank.memory_bank_id)
+
+        # Verify our bank was removed
         response = await banks_impl.list_memory_banks()
-        assert isinstance(response, list)
-        assert len(response) == 0
+        assert all(
+            bank.memory_bank_id != registered_bank.memory_bank_id for bank in response
+        )

     @pytest.mark.asyncio
     async def test_banks_register(self, memory_stack):
-        # NOTE: this needs you to ensure that you are starting from a clean state
-        # but so far we don't have an unregister API unfortunately, so be careful
         _, banks_impl = memory_stack
-        await banks_impl.register_memory_bank(
-            memory_bank_id="test_bank_no_provider",
-            params=VectorMemoryBankParams(
-                embedding_model="all-MiniLM-L6-v2",
-                chunk_size_in_tokens=512,
-                overlap_size_in_tokens=64,
-            ),
-        )
-        response = await banks_impl.list_memory_banks()
-        assert isinstance(response, list)
-        assert len(response) == 1
+        bank_id = f"test_bank_{uuid.uuid4().hex}"

-        # register same memory bank with same id again will fail
-        await banks_impl.register_memory_bank(
-            memory_bank_id="test_bank_no_provider",
-            params=VectorMemoryBankParams(
-                embedding_model="all-MiniLM-L6-v2",
-                chunk_size_in_tokens=512,
-                overlap_size_in_tokens=64,
-            ),
-        )
-        response = await banks_impl.list_memory_banks()
-        assert isinstance(response, list)
-        assert len(response) == 1
+        try:
+            # Register initial bank
+            await banks_impl.register_memory_bank(
+                memory_bank_id=bank_id,
+                params=VectorMemoryBankParams(
+                    embedding_model="all-MiniLM-L6-v2",
+                    chunk_size_in_tokens=512,
+                    overlap_size_in_tokens=64,
+                ),
+            )
+
+            # Verify our bank exists
+            response = await banks_impl.list_memory_banks()
+            assert isinstance(response, list)
+            assert any(bank.memory_bank_id == bank_id for bank in response)
+
+            # Try registering same bank again
+            await banks_impl.register_memory_bank(
+                memory_bank_id=bank_id,
+                params=VectorMemoryBankParams(
+                    embedding_model="all-MiniLM-L6-v2",
+                    chunk_size_in_tokens=512,
+                    overlap_size_in_tokens=64,
+                ),
+            )
+
+            # Verify still only one instance of our bank
+            response = await banks_impl.list_memory_banks()
+            assert isinstance(response, list)
+            assert (
+                len([bank for bank in response if bank.memory_bank_id == bank_id]) == 1
+            )
+        finally:
+            # Clean up
+            await banks_impl.unregister_memory_bank(bank_id)

     @pytest.mark.asyncio
     async def test_query_documents(self, memory_stack, sample_documents):
@@ -102,17 +132,23 @@ class TestMemory:
         with pytest.raises(ValueError):
             await memory_impl.insert_documents("test_bank", sample_documents)

-        await register_memory_bank(banks_impl)
-        await memory_impl.insert_documents("test_bank", sample_documents)
+        registered_bank = await register_memory_bank(banks_impl)
+        await memory_impl.insert_documents(
+            registered_bank.memory_bank_id, sample_documents
+        )

         query1 = "programming language"
-        response1 = await memory_impl.query_documents("test_bank", query1)
+        response1 = await memory_impl.query_documents(
+            registered_bank.memory_bank_id, query1
+        )
         assert_valid_response(response1)
         assert any("Python" in chunk.content for chunk in response1.chunks)

         # Test case 3: Query with semantic similarity
         query3 = "AI and brain-inspired computing"
-        response3 = await memory_impl.query_documents("test_bank", query3)
+        response3 = await memory_impl.query_documents(
+            registered_bank.memory_bank_id, query3
+        )
         assert_valid_response(response3)
         assert any(
             "neural networks" in chunk.content.lower() for chunk in response3.chunks
@@ -121,14 +157,18 @@ class TestMemory:
         # Test case 4: Query with limit on number of results
         query4 = "computer"
         params4 = {"max_chunks": 2}
-        response4 = await memory_impl.query_documents("test_bank", query4, params4)
+        response4 = await memory_impl.query_documents(
+            registered_bank.memory_bank_id, query4, params4
+        )
         assert_valid_response(response4)
         assert len(response4.chunks) <= 2

         # Test case 5: Query with threshold on similarity score
         query5 = "quantum computing"  # Not directly related to any document
         params5 = {"score_threshold": 0.2}
-        response5 = await memory_impl.query_documents("test_bank", query5, params5)
+        response5 = await memory_impl.query_documents(
+            registered_bank.memory_bank_id, query5, params5
+        )
         assert_valid_response(response5)
         print("The scores are:", response5.scores)
         assert all(score >= 0.2 for score in response5.scores)
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index ba7ed231e..2bbf6cdd2 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -145,6 +145,10 @@ class EmbeddingIndex(ABC):
     ) -> QueryDocumentsResponse:
         raise NotImplementedError()

+    @abstractmethod
+    async def delete(self):
+        raise NotImplementedError()
+

 @dataclass
 class BankWithIndex:
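Finally, making delete an @abstractmethod on EmbeddingIndex means an index class that forgets to implement it fails at instantiation time rather than at the first unregister call. A minimal illustration of that contract (both classes are toy stand-ins, not the real indexes):

    from abc import ABC, abstractmethod


    class EmbeddingIndex(ABC):
        @abstractmethod
        async def delete(self):
            raise NotImplementedError()


    class IncompleteIndex(EmbeddingIndex):
        pass


    class CompleteIndex(EmbeddingIndex):
        async def delete(self):
            print("backing store dropped")


    try:
        IncompleteIndex()  # TypeError: abstract method 'delete' not implemented
    except TypeError as err:
        print(err)

    CompleteIndex()  # fine: the contract is satisfied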