diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index a5e985a25..303104f25 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -31,13 +31,8 @@ class MemoryBankType(Enum): @json_schema_type -class MemoryBank(Resource): +class VectorMemoryBank(Resource): type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value - memory_bank_type: MemoryBankType - - -@json_schema_type -class VectorMemoryBank(MemoryBank): memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value embedding_model: str chunk_size_in_tokens: int @@ -45,21 +40,24 @@ class VectorMemoryBank(MemoryBank): @json_schema_type -class KeyValueMemoryBank(MemoryBank): +class KeyValueMemoryBank(Resource): + type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value memory_bank_type: Literal[MemoryBankType.keyvalue.value] = ( MemoryBankType.keyvalue.value ) @json_schema_type -class KeywordMemoryBank(MemoryBank): +class KeywordMemoryBank(Resource): + type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value memory_bank_type: Literal[MemoryBankType.keyword.value] = ( MemoryBankType.keyword.value ) @json_schema_type -class GraphMemoryBank(MemoryBank): +class GraphMemoryBank(Resource): + type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value @@ -90,7 +88,7 @@ class GraphMemoryBankParams(BaseModel): memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value -AnyMemoryBank = Annotated[ +MemoryBank = Annotated[ Union[ VectorMemoryBank, KeyValueMemoryBank, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 388a131b7..aa61580b2 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -288,14 +288,13 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): "No provider specified and multiple providers available. Please specify a provider_id." ) memory_bank = parse_obj_as( - AnyMemoryBank, + MemoryBank, { "identifier": memory_bank_id, "type": ResourceType.memory_bank.value, "provider_id": provider_id, "provider_resource_id": provider_memorybank_id, - "memory_bank_type": params.memory_bank_type, - **params.model_dump(exclude={"memory_bank_type"}), + **params.model_dump(), }, ) await self.register_object(memory_bank)