Another round of simplification and clarity for models/shields/memory_banks stuff

This commit is contained in:
Ashwin Bharambe 2024-10-09 19:19:26 -07:00
parent 73a0a34e39
commit b55034c0de
27 changed files with 454 additions and 444 deletions

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import List, Literal, Optional, Protocol, Union
from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
@ -22,7 +22,8 @@ class MemoryBankType(Enum):
class CommonDef(BaseModel):
identifier: str
provider_id: Optional[str] = None
# Hack: move this out later
provider_id: str = ""
@json_schema_type
@ -58,13 +59,20 @@ MemoryBankDef = Annotated[
Field(discriminator="type"),
]
MemoryBankDefWithProvider = MemoryBankDef
@runtime_checkable
class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: ...
@webmethod(route="/memory_banks/get", method="GET")
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
async def get_memory_bank(
self, identifier: str
) -> Optional[MemoryBankDefWithProvider]: ...
@webmethod(route="/memory_banks/register", method="POST")
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
async def register_memory_bank(
self, memory_bank: MemoryBankDefWithProvider
) -> None: ...