From 8436d386d3d468004304bc7b12c7123f670bcd32 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 11 Nov 2024 16:22:11 -0800 Subject: [PATCH] use annotated union to parse --- llama_stack/apis/memory_banks/memory_banks.py | 10 ++++ .../distribution/routers/routing_tables.py | 18 ++++-- .../distribution/utils/memory_bank_utils.py | 60 ------------------- 3 files changed, 22 insertions(+), 66 deletions(-) delete mode 100644 llama_stack/distribution/utils/memory_bank_utils.py diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index 48064af86..a5e985a25 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -90,6 +90,16 @@ class GraphMemoryBankParams(BaseModel): memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value +AnyMemoryBank = Annotated[ + Union[ + VectorMemoryBank, + KeyValueMemoryBank, + KeywordMemoryBank, + GraphMemoryBank, + ], + Field(discriminator="memory_bank_type"), +] + BankParams = Annotated[ Union[ VectorMemoryBankParams, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 7174addcd..388a131b7 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -6,6 +6,8 @@ from typing import Any, Dict, List, Optional +from pydantic import parse_obj_as + from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.models import * # noqa: F403 @@ -17,7 +19,6 @@ from llama_stack.apis.eval_tasks import * # noqa: F403 from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.datatypes import * # noqa: F403 -from llama_stack.distribution.utils.memory_bank_utils import build_memory_bank def get_impl_api(p: Any) -> Api: @@ -286,11 +287,16 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): raise ValueError( "No provider specified and multiple providers available. Please specify a provider_id." ) - memory_bank = build_memory_bank( - memory_bank_id, - params, - provider_id, - provider_memorybank_id, + memory_bank = parse_obj_as( + AnyMemoryBank, + { + "identifier": memory_bank_id, + "type": ResourceType.memory_bank.value, + "provider_id": provider_id, + "provider_resource_id": provider_memorybank_id, + "memory_bank_type": params.memory_bank_type, + **params.model_dump(exclude={"memory_bank_type"}), + }, ) await self.register_object(memory_bank) return memory_bank diff --git a/llama_stack/distribution/utils/memory_bank_utils.py b/llama_stack/distribution/utils/memory_bank_utils.py deleted file mode 100644 index aad0b6cf7..000000000 --- a/llama_stack/distribution/utils/memory_bank_utils.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - - -from llama_stack.apis.memory_banks.memory_banks import ( - BankParams, - GraphMemoryBank, - KeyValueMemoryBank, - KeywordMemoryBank, - MemoryBank, - MemoryBankType, - VectorMemoryBank, - VectorMemoryBankParams, -) - - -def build_memory_bank( - memory_bank_id: str, - params: BankParams, - provider_id: str, - provider_memorybank_id: str, -) -> MemoryBank: - if params.memory_bank_type == MemoryBankType.vector.value: - assert isinstance(params, VectorMemoryBankParams) - memory_bank = VectorMemoryBank( - identifier=memory_bank_id, - provider_id=provider_id, - provider_resource_id=provider_memorybank_id, - memory_bank_type=params.memory_bank_type, - embedding_model=params.embedding_model, - chunk_size_in_tokens=params.chunk_size_in_tokens, - overlap_size_in_tokens=params.overlap_size_in_tokens, - ) - elif params.memory_bank_type == MemoryBankType.keyvalue.value: - memory_bank = KeyValueMemoryBank( - identifier=memory_bank_id, - provider_id=provider_id, - provider_resource_id=provider_memorybank_id, - memory_bank_type=params.memory_bank_type, - ) - elif params.memory_bank_type == MemoryBankType.keyword.value: - memory_bank = KeywordMemoryBank( - identifier=memory_bank_id, - provider_id=provider_id, - provider_resource_id=provider_memorybank_id, - memory_bank_type=params.memory_bank_type, - ) - elif params.memory_bank_type == MemoryBankType.graph.value: - memory_bank = GraphMemoryBank( - identifier=memory_bank_id, - provider_id=provider_id, - provider_resource_id=provider_memorybank_id, - memory_bank_type=params.memory_bank_type, - ) - else: - raise ValueError(f"Unknown memory bank type: {params.memory_bank_type}") - return memory_bank