flesh out memory banks API

This commit is contained in:
Ashwin Bharambe 2024-08-23 06:38:15 -07:00
parent 31289e3f47
commit 77d6055d9f
11 changed files with 1792 additions and 974 deletions

View file

@ -3,23 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class MemoryBank(BaseModel):
memory_bank_id: str
memory_bank_name: str
@json_schema_type
class MemoryBankDocument(BaseModel):
document_id: str
content: bytes
metadata: Dict[str, Any]
mime_type: str

View file

@ -6,76 +6,132 @@
from typing import List, Protocol
from llama_models.llama3.api.datatypes import InterleavedTextMedia
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.schema_utils import webmethod
from .datatypes import * # noqa: F403
@json_schema_type
class RetrieveMemoryDocumentsRequest(BaseModel):
query: InterleavedTextMedia
bank_ids: str
class MemoryBankDocument(BaseModel):
document_id: str
content: InterleavedTextMedia | URL
mime_type: str
metadata: Dict[str, Any]
class Chunk(BaseModel):
content: InterleavedTextMedia
token_count: int
@json_schema_type
class RetrieveMemoryDocumentsResponse(BaseModel):
documents: List[MemoryBankDocument]
class QueryDocumentsResponse(BaseModel):
chunks: List[Chunk]
scores: List[float]
@json_schema_type
class MemoryBankType(Enum):
vector = "vector"
keyvalue = "keyvalue"
keyword = "keyword"
graph = "graph"
class VectorMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
class KeyValueMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
class KeywordMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class GraphMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
MemoryBankConfig = Annotated[
Union[
VectorMemoryBankConfig,
KeyValueMemoryBankConfig,
KeywordMemoryBankConfig,
GraphMemoryBankConfig,
],
Field(discriminator="type"),
]
@json_schema_type
class MemoryBank(BaseModel):
bank_id: str
name: str
config: MemoryBankConfig
# if there's a pre-existing store which obeys the MemoryBank REST interface
url: Optional[URL] = None
class Memory(Protocol):
@webmethod(route="/memory_banks/create")
def create_memory_bank(
self,
bank_id: str,
bank_name: str,
embedding_model: str,
documents: List[MemoryBankDocument],
) -> None: ...
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank: ...
@webmethod(route="/memory_banks/list")
def get_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/list", method="GET")
def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/get")
def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ...
def get_memory_bank(self, bank_id: str) -> MemoryBank: ...
@webmethod(route="/memory_banks/drop")
def delete_memory_bank(
@webmethod(route="/memory_banks/drop", method="DELETE")
def drop_memory_bank(
self,
bank_id: str,
) -> str: ...
@webmethod(route="/memory_bank/insert")
def insert_memory_documents(
def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_bank/update")
def update_memory_documents(
def update_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_bank/get")
def retrieve_memory_documents(
self,
request: RetrieveMemoryDocumentsRequest,
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory_bank/get")
def get_memory_documents(
@webmethod(route="/memory_bank/query")
def query_documents(
self,
bank_id: str,
document_uuids: List[str],
) -> List[MemoryBankDocument]: ...
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@webmethod(route="/memory_bank/delete")
def delete_memory_documents(
@webmethod(route="/memory_bank/documents/get")
def get_documents(
self,
bank_id: str,
document_uuids: List[str],
) -> List[str]: ...
document_ids: List[str],
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory_bank/documents/delete")
def delete_documents(
self,
bank_id: str,
document_ids: List[str],
) -> None: ...