use annotated union to parse

This commit is contained in:
Dinesh Yeduguru 2024-11-11 16:22:11 -08:00
parent cf416cb6b8
commit 8436d386d3
3 changed files with 22 additions and 66 deletions

View file

@ -90,6 +90,16 @@ class GraphMemoryBankParams(BaseModel):
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
AnyMemoryBank = Annotated[
Union[
VectorMemoryBank,
KeyValueMemoryBank,
KeywordMemoryBank,
GraphMemoryBank,
],
Field(discriminator="memory_bank_type"),
]
BankParams = Annotated[
Union[
VectorMemoryBankParams,

View file

@ -6,6 +6,8 @@
from typing import Any, Dict, List, Optional
from pydantic import parse_obj_as
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
@ -17,7 +19,6 @@ from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.utils.memory_bank_utils import build_memory_bank
def get_impl_api(p: Any) -> Api:
@ -286,11 +287,16 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
memory_bank = build_memory_bank(
memory_bank_id,
params,
provider_id,
provider_memorybank_id,
memory_bank = parse_obj_as(
AnyMemoryBank,
{
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id,
"provider_resource_id": provider_memorybank_id,
"memory_bank_type": params.memory_bank_type,
**params.model_dump(exclude={"memory_bank_type"}),
},
)
await self.register_object(memory_bank)
return memory_bank

View file

@ -1,60 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.memory_banks.memory_banks import (
BankParams,
GraphMemoryBank,
KeyValueMemoryBank,
KeywordMemoryBank,
MemoryBank,
MemoryBankType,
VectorMemoryBank,
VectorMemoryBankParams,
)
def build_memory_bank(
memory_bank_id: str,
params: BankParams,
provider_id: str,
provider_memorybank_id: str,
) -> MemoryBank:
if params.memory_bank_type == MemoryBankType.vector.value:
assert isinstance(params, VectorMemoryBankParams)
memory_bank = VectorMemoryBank(
identifier=memory_bank_id,
provider_id=provider_id,
provider_resource_id=provider_memorybank_id,
memory_bank_type=params.memory_bank_type,
embedding_model=params.embedding_model,
chunk_size_in_tokens=params.chunk_size_in_tokens,
overlap_size_in_tokens=params.overlap_size_in_tokens,
)
elif params.memory_bank_type == MemoryBankType.keyvalue.value:
memory_bank = KeyValueMemoryBank(
identifier=memory_bank_id,
provider_id=provider_id,
provider_resource_id=provider_memorybank_id,
memory_bank_type=params.memory_bank_type,
)
elif params.memory_bank_type == MemoryBankType.keyword.value:
memory_bank = KeywordMemoryBank(
identifier=memory_bank_id,
provider_id=provider_id,
provider_resource_id=provider_memorybank_id,
memory_bank_type=params.memory_bank_type,
)
elif params.memory_bank_type == MemoryBankType.graph.value:
memory_bank = GraphMemoryBank(
identifier=memory_bank_id,
provider_id=provider_id,
provider_resource_id=provider_memorybank_id,
memory_bank_type=params.memory_bank_type,
)
else:
raise ValueError(f"Unknown memory bank type: {params.memory_bank_type}")
return memory_bank