diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py
index b1bb44f1c..a4895c4c9 100644
--- a/llama_stack/apis/memory_banks/memory_banks.py
+++ b/llama_stack/apis/memory_banks/memory_banks.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
+from typing import List, Literal, Optional, Protocol, runtime_checkable
 
 from llama_models.schema_utils import json_schema_type, webmethod
 
@@ -56,76 +56,16 @@ class GraphMemoryBank(MemoryBank):
 
 
 @json_schema_type
-class BaseRegistration(BaseModel):
-    memory_bank_id: str
-    provider_resource_id: Optional[str] = None
-    provider_id: Optional[str] = None
-
-
-@json_schema_type
-class VectorRegistration(BaseRegistration):
+class VectorMemoryBankParams(BaseModel):
+    type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
     embedding_model: str
     chunk_size_in_tokens: int
     overlap_size_in_tokens: Optional[int] = None
 
 
-@json_schema_type
-class KeyValueRegistration(BaseRegistration):
-    pass
-
-
-@json_schema_type
-class KeywordRegistration(BaseRegistration):
-    pass
-
-
-@json_schema_type
-class GraphRegistration(BaseRegistration):
-    pass
-
-
-RegistrationRequest = Union[
-    VectorRegistration,
-    KeyValueRegistration,
-    KeywordRegistration,
-    GraphRegistration,
-]
-
-
-def registration_request_to_memory_bank(request: RegistrationRequest) -> MemoryBank:
-    """Convert registration request to memory bank object"""
-    if isinstance(request, VectorRegistration):
-        return VectorMemoryBank(
-            identifier=request.memory_bank_id,
-            provider_resource_id=request.provider_resource_id,
-            provider_id=request.provider_id,
-            embedding_model=request.embedding_model,
-            chunk_size_in_tokens=request.chunk_size_in_tokens,
-            overlap_size_in_tokens=request.overlap_size_in_tokens,
-        )
-    elif isinstance(request, KeyValueRegistration):
-        return KeyValueMemoryBank(
-            identifier=request.memory_bank_id,
-            provider_resource_id=request.provider_resource_id,
-            provider_id=request.provider_id,
-            memory_bank_type=MemoryBankType.keyvalue,
-        )
-    elif isinstance(request, KeywordRegistration):
-        return KeywordMemoryBank(
-            identifier=request.memory_bank_id,
-            provider_resource_id=request.provider_resource_id,
-            provider_id=request.provider_id,
-            memory_bank_type=MemoryBankType.keyword,
-        )
-    elif isinstance(request, GraphRegistration):
-        return GraphMemoryBank(
-            identifier=request.memory_bank_id,
-            provider_resource_id=request.provider_resource_id,
-            provider_id=request.provider_id,
-            memory_bank_type=MemoryBankType.graph,
-        )
-    else:
-        raise ValueError(f"Unknown registration type: {type(request)}")
+BankParams = VectorMemoryBankParams  # Plain alias for now: this is the only params model defined so far.
+# To accept more params models later, re-import Union from typing (this change drops it) and widen the alias:
+# BankParams = Union[VectorMemoryBankParams, KeyValueMemoryBankParams, KeywordMemoryBankParams, GraphMemoryBankParams]
 
 
 @runtime_checkable
@@ -138,5 +78,10 @@ class MemoryBanks(Protocol):
 
     @webmethod(route="/memory_banks/register", method="POST")
     async def register_memory_bank(
-        self, request: RegistrationRequest
+        self,
+        memory_bank_id: str,
+        memory_bank_type: MemoryBankType,
+        provider_id: Optional[str] = None,
+        provider_memorybank_id: Optional[str] = None,
+        params: Optional[BankParams] = None,
     ) -> MemoryBank: ...
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 7152dd6eb..81249f9f3 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -7,8 +7,8 @@
 from typing import Any, AsyncGenerator, Dict, List, Optional
 
 from llama_stack.apis.datasetio.datasetio import DatasetIO
+from llama_stack.apis.memory_banks.memory_banks import BankParams, MemoryBankType
 from llama_stack.distribution.datatypes import RoutingTable
-
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
@@ -34,9 +34,19 @@ class MemoryRouter(Memory):
 
     async def register_memory_bank(
         self,
-        request: RegistrationRequest,
+        memory_bank_id: str,
+        memory_bank_type: MemoryBankType,
+        provider_id: Optional[str] = None,
+        provider_memorybank_id: Optional[str] = None,
+        params: Optional[BankParams] = None,
     ) -> None:
-        await self.routing_table.register_memory_bank(request)
+        await self.routing_table.register_memory_bank(
+            memory_bank_id=memory_bank_id,
+            memory_bank_type=memory_bank_type,
+            provider_id=provider_id,
+            provider_memorybank_id=provider_memorybank_id,
+            params=params,
+        )
 
     async def insert_documents(
         self,
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 95f92a7db..676ce14f6 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -17,6 +17,7 @@
 from llama_stack.apis.eval_tasks import *  # noqa: F403
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.distribution.datatypes import *  # noqa: F403
+from llama_stack.distribution.utils.memory_bank_utils import build_memory_bank
 
 
 def get_impl_api(p: Any) -> Api:
@@ -272,19 +273,36 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
 
     async def register_memory_bank(
         self,
-        request: RegistrationRequest,
+        memory_bank_id: str,
+        memory_bank_type: MemoryBankType,
+        provider_id: Optional[str] = None,
+        provider_memorybank_id: Optional[str] = None,
+        params: Optional[BankParams] = None,
     ) -> MemoryBank:
-        if request.provider_resource_id is None:
-            request.provider_resource_id = request.memory_bank_id
-        if request.provider_id is None:
+        """Resolve provider defaults, build the memory bank, and register it.
+
+        The signature mirrors ``MemoryBanks.register_memory_bank`` so the
+        arguments forwarded by ``MemoryRouter`` bind to the right parameters.
+        """
+        if provider_memorybank_id is None:
+            provider_memorybank_id = memory_bank_id
+        if provider_id is None:
             # If provider_id not specified, use the only provider if it supports this shield type
             if len(self.impls_by_provider_id) == 1:
-                request.provider_id = list(self.impls_by_provider_id.keys())[0]
+                provider_id = list(self.impls_by_provider_id.keys())[0]
             else:
                 raise ValueError(
                     "No provider specified and multiple providers available. Please specify a provider_id."
                 )
-        memory_bank = registration_request_to_memory_bank(request)
+        # Use the explicit type rather than params.type: params may be None,
+        # and build_memory_bank compares against MemoryBankType members.
+        memory_bank = build_memory_bank(
+            memory_bank_id,
+            memory_bank_type,
+            provider_id,
+            provider_memorybank_id,
+            params,
+        )
         await self.register_object(memory_bank)
         return memory_bank
 
diff --git a/llama_stack/distribution/utils/memory_bank_utils.py b/llama_stack/distribution/utils/memory_bank_utils.py
new file mode 100644
index 000000000..b2d698257
--- /dev/null
+++ b/llama_stack/distribution/utils/memory_bank_utils.py
@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from llama_stack.apis.memory_banks.memory_banks import (
+    BankParams,
+    GraphMemoryBank,
+    KeyValueMemoryBank,
+    KeywordMemoryBank,
+    MemoryBank,
+    MemoryBankType,
+    VectorMemoryBank,
+    VectorMemoryBankParams,
+)
+
+
+def build_memory_bank(
+    memory_bank_id: str,
+    memory_bank_type: MemoryBankType,
+    provider_id: str,
+    provider_memorybank_id: str,
+    params: Optional[BankParams] = None,
+) -> MemoryBank:
+    """Build the MemoryBank object for ``memory_bank_type``.
+
+    Vector banks require ``params`` to be ``VectorMemoryBankParams``; other bank
+    types ignore it. Raises ``ValueError`` on an unknown type or bad params.
+    """
+    try:
+        # Accept the enum member or its raw value (e.g. ``params.type``).
+        bank_type = MemoryBankType(memory_bank_type)
+    except ValueError:
+        raise ValueError(f"Unknown memory bank type: {memory_bank_type}") from None
+
+    # Fields shared by every bank type; the caller's type value is passed as-is.
+    common = dict(
+        identifier=memory_bank_id,
+        provider_id=provider_id,
+        provider_resource_id=provider_memorybank_id,
+        memory_bank_type=memory_bank_type,
+    )
+    if bank_type == MemoryBankType.vector:
+        # Raise instead of ``assert``, which is stripped under ``python -O``.
+        if not isinstance(params, VectorMemoryBankParams):
+            raise ValueError("Vector memory banks require VectorMemoryBankParams")
+        return VectorMemoryBank(
+            embedding_model=params.embedding_model,
+            chunk_size_in_tokens=params.chunk_size_in_tokens,
+            overlap_size_in_tokens=params.overlap_size_in_tokens,
+            **common,
+        )
+    if bank_type == MemoryBankType.keyvalue:
+        return KeyValueMemoryBank(**common)
+    if bank_type == MemoryBankType.keyword:
+        return KeywordMemoryBank(**common)
+    if bank_type == MemoryBankType.graph:
+        return GraphMemoryBank(**common)
+    raise ValueError(f"Unknown memory bank type: {memory_bank_type}")