From 4369c2d0b6875a4627488dd9a82a1f1b5d6665b5 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 11 Nov 2024 11:49:07 -0800 Subject: [PATCH] change register signature to make params required --- llama_stack/apis/memory_banks/memory_banks.py | 33 ++++++++++++++++--- llama_stack/distribution/routers/routers.py | 8 ++--- .../distribution/routers/routing_tables.py | 6 ++-- .../distribution/utils/memory_bank_utils.py | 22 ++++++------- .../providers/tests/memory/test_memory.py | 3 -- 5 files changed, 43 insertions(+), 29 deletions(-) diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index c85e0fc25..48064af86 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -65,15 +65,39 @@ class GraphMemoryBank(MemoryBank): @json_schema_type class VectorMemoryBankParams(BaseModel): - type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value + memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value embedding_model: str chunk_size_in_tokens: int overlap_size_in_tokens: Optional[int] = None +@json_schema_type +class KeyValueMemoryBankParams(BaseModel): + memory_bank_type: Literal[MemoryBankType.keyvalue.value] = ( + MemoryBankType.keyvalue.value + ) + + +@json_schema_type +class KeywordMemoryBankParams(BaseModel): + memory_bank_type: Literal[MemoryBankType.keyword.value] = ( + MemoryBankType.keyword.value + ) + + +@json_schema_type +class GraphMemoryBankParams(BaseModel): + memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value + + BankParams = Annotated[ - Union[VectorMemoryBankParams], - Field(discriminator="type"), + Union[ + VectorMemoryBankParams, + KeyValueMemoryBankParams, + KeywordMemoryBankParams, + GraphMemoryBankParams, + ], + Field(discriminator="memory_bank_type"), ] @@ -89,8 +113,7 @@ class MemoryBanks(Protocol): async def register_memory_bank( self, memory_bank_id: str, - memory_bank_type: MemoryBankType, + params: BankParams, provider_id: Optional[str] = None, provider_memorybank_id: Optional[str] = None, - params: Optional[BankParams] = None, ) -> MemoryBank: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 81249f9f3..5f6395e0d 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -7,7 +7,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional from llama_stack.apis.datasetio.datasetio import DatasetIO -from llama_stack.apis.memory_banks.memory_banks import BankParams, MemoryBankType +from llama_stack.apis.memory_banks.memory_banks import BankParams from llama_stack.distribution.datatypes import RoutingTable from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 @@ -35,17 +35,15 @@ class MemoryRouter(Memory): async def register_memory_bank( self, memory_bank_id: str, - memory_bank_type: MemoryBankType, + params: BankParams, provider_id: Optional[str] = None, provider_memorybank_id: Optional[str] = None, - params: Optional[BankParams] = None, ) -> None: await self.routing_table.register_memory_bank( memory_bank_id, - memory_bank_type, + params, provider_id, provider_memorybank_id, - params, ) async def insert_documents( diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index a23051c6d..7174addcd 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -272,10 +272,9 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): async def register_memory_bank( self, memory_bank_id: str, - memory_bank_type: MemoryBankType, + params: BankParams, provider_id: Optional[str] = None, provider_memorybank_id: Optional[str] = None, - params: Optional[BankParams] = None, ) -> MemoryBank: if provider_memorybank_id is None: provider_memorybank_id = memory_bank_id @@ -289,10 +288,9 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): ) memory_bank = build_memory_bank( memory_bank_id, - memory_bank_type, + params, provider_id, provider_memorybank_id, - params, ) await self.register_object(memory_bank) return memory_bank diff --git a/llama_stack/distribution/utils/memory_bank_utils.py b/llama_stack/distribution/utils/memory_bank_utils.py index e55977e28..aad0b6cf7 100644 --- a/llama_stack/distribution/utils/memory_bank_utils.py +++ b/llama_stack/distribution/utils/memory_bank_utils.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional from llama_stack.apis.memory_banks.memory_banks import ( BankParams, @@ -20,43 +19,42 @@ from llama_stack.apis.memory_banks.memory_banks import ( def build_memory_bank( memory_bank_id: str, - memory_bank_type: MemoryBankType, + params: BankParams, provider_id: str, provider_memorybank_id: str, - params: Optional[BankParams] = None, ) -> MemoryBank: - if memory_bank_type == MemoryBankType.vector.value: + if params.memory_bank_type == MemoryBankType.vector.value: assert isinstance(params, VectorMemoryBankParams) memory_bank = VectorMemoryBank( identifier=memory_bank_id, provider_id=provider_id, provider_resource_id=provider_memorybank_id, - memory_bank_type=memory_bank_type, + memory_bank_type=params.memory_bank_type, embedding_model=params.embedding_model, chunk_size_in_tokens=params.chunk_size_in_tokens, overlap_size_in_tokens=params.overlap_size_in_tokens, ) - elif memory_bank_type == MemoryBankType.keyvalue.value: + elif params.memory_bank_type == MemoryBankType.keyvalue.value: memory_bank = KeyValueMemoryBank( identifier=memory_bank_id, provider_id=provider_id, provider_resource_id=provider_memorybank_id, - memory_bank_type=memory_bank_type, + memory_bank_type=params.memory_bank_type, ) - elif memory_bank_type == MemoryBankType.keyword.value: + elif params.memory_bank_type == MemoryBankType.keyword.value: memory_bank = KeywordMemoryBank( identifier=memory_bank_id, provider_id=provider_id, provider_resource_id=provider_memorybank_id, - memory_bank_type=memory_bank_type, + memory_bank_type=params.memory_bank_type, ) - elif memory_bank_type == MemoryBankType.graph.value: + elif params.memory_bank_type == MemoryBankType.graph.value: memory_bank = GraphMemoryBank( identifier=memory_bank_id, provider_id=provider_id, provider_resource_id=provider_memorybank_id, - memory_bank_type=memory_bank_type, + memory_bank_type=params.memory_bank_type, ) else: - raise ValueError(f"Unknown memory bank type: {memory_bank_type}") + raise ValueError(f"Unknown memory bank type: {params.memory_bank_type}") return memory_bank diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index 1cefd1d4a..a1befa6b0 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -47,7 +47,6 @@ async def register_memory_bank(banks_impl: MemoryBanks): return await banks_impl.register_memory_bank( memory_bank_id="test_bank", - memory_bank_type="vector", params=VectorMemoryBankParams( embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, @@ -74,7 +73,6 @@ class TestMemory: bank = await banks_impl.register_memory_bank( memory_bank_id="test_bank_no_provider", - memory_bank_type="vector", params=VectorMemoryBankParams( embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, @@ -88,7 +86,6 @@ class TestMemory: # register same memory bank with same id again will fail await banks_impl.register_memory_bank( memory_bank_id="test_bank_no_provider", - memory_bank_type="vector", params=VectorMemoryBankParams( embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512,