From 24b914f0fe811b6fe4717364bedea462399e45f4 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 8 Nov 2024 16:51:05 -0800 Subject: [PATCH] address feedback --- llama_stack/apis/memory_banks/memory_banks.py | 19 ++++++++++++++----- llama_stack/providers/datatypes.py | 1 + 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index a4895c4c9..c85e0fc25 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -5,11 +5,19 @@ # the root directory of this source tree. from enum import Enum -from typing import List, Literal, Optional, Protocol, runtime_checkable +from typing import ( + Annotated, + List, + Literal, + Optional, + Protocol, + runtime_checkable, + Union, +) from llama_models.schema_utils import json_schema_type, webmethod -from pydantic import BaseModel +from pydantic import BaseModel, Field from llama_stack.apis.resource import Resource, ResourceType @@ -63,9 +71,10 @@ class VectorMemoryBankParams(BaseModel): overlap_size_in_tokens: Optional[int] = None -BankParams = VectorMemoryBankParams # For now, since we only have one type of params -# If you need to add more types later, you can do: -# BankParams = Union[VectorMemoryBankParams, KeyValueMemoryBankParams, KeywordMemoryBankParams, GraphMemoryBankParams] +BankParams = Annotated[ + Union[VectorMemoryBankParams], + Field(discriminator="type"), +] @runtime_checkable diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index f9b2774ba..ed2033494 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -13,6 +13,7 @@ from pydantic import BaseModel, Field from llama_stack.apis.datasets import DatasetDef from llama_stack.apis.eval_tasks import EvalTaskDef +from llama_stack.apis.memory_banks.memory_banks import MemoryBank from llama_stack.apis.models import Model from llama_stack.apis.scoring_functions import ScoringFnDef from llama_stack.apis.shields import Shield