use annotated union to parse

This commit is contained in:
Dinesh Yeduguru 2024-11-11 16:22:11 -08:00
parent cf416cb6b8
commit 8436d386d3
3 changed files with 22 additions and 66 deletions

View file

@ -6,6 +6,8 @@
from typing import Any, Dict, List, Optional
from pydantic import parse_obj_as
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
@ -17,7 +19,6 @@ from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.utils.memory_bank_utils import build_memory_bank
def get_impl_api(p: Any) -> Api:
@ -286,11 +287,16 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
memory_bank = build_memory_bank(
memory_bank_id,
params,
provider_id,
provider_memorybank_id,
memory_bank = parse_obj_as(
AnyMemoryBank,
{
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id,
"provider_resource_id": provider_memorybank_id,
"memory_bank_type": params.memory_bank_type,
**params.model_dump(exclude={"memory_bank_type"}),
},
)
await self.register_object(memory_bank)
return memory_bank