memory banks

This commit is contained in:
Ashwin Bharambe 2024-07-10 23:27:17 -07:00
parent 6fb69efbe5
commit ee86f2c75f
4 changed files with 619 additions and 32 deletions

View file

@ -7,7 +7,8 @@ import yaml
from agentic_system_types import (
AgenticSystemTurn,
ExecutionStepType,
IndexedMemoryDocument,
MemoryBank,
MemoryBankDocument,
SafetyViolation,
)
@ -172,6 +173,8 @@ class BatchInference(Protocol):
@dataclass
class AgenticSystemCreateRequest:
uuid: str
instructions: str
model: InstructModel
@ -182,6 +185,8 @@ class AgenticSystemCreateRequest:
# execute themselves.
executable_tools: Set[str] = field(default_factory=set)
memory_bank_uuids: List[str] = field(default_factory=list)
input_shields: List[ShieldConfig] = field(default_factory=list)
output_shields: List[ShieldConfig] = field(default_factory=list)
@ -189,13 +194,13 @@ class AgenticSystemCreateRequest:
@json_schema_type
@dataclass
class AgenticSystemCreateResponse:
agent_id: str
agent_uuid: str
@json_schema_type
@dataclass
class AgenticSystemExecuteRequest:
agent_id: str
agent_uuid: str
messages: List[Message]
turn_history: List[AgenticSystemTurn] = None
stream: bool = False
@ -227,11 +232,12 @@ class AgenticSystemExecuteResponseStreamChunk:
step_uuid: str
step_type: ExecutionStepType
# TODO(ashwin): maybe add more structure here and do this as a proper tagged union
violation: Optional[SafetyViolation] = None
tool_call: Optional[ToolCall] = None
tool_response_delta: Optional[ToolResponse] = None
response_text_delta: Optional[str] = None
retrieved_document: Optional[IndexedMemoryDocument] = None
retrieved_document: Optional[MemoryBankDocument] = None
stop_reason: Optional[StopReason] = None
@ -259,6 +265,41 @@ class AgenticSystem(Protocol):
) -> None: ...
class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/create")
def create_memory_bank(
self,
bank_uuid: str,
bank_name: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_banks/get")
def get_memory_banks(
self,
) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/insert")
def post_insert_memory_documents(
self,
bank_uuid: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_banks/delete")
def post_delete_memory_documents(
self,
bank_uuid: str,
document_uuids: List[str],
) -> None: ...
@webmethod(route="/memory_banks/drop")
def remove_memory_bank(
self,
bank_uuid: str,
) -> None: ...
@dataclass
class KPromptGenerations:
prompt: Message
@ -456,6 +497,7 @@ class LlamaStackEndpoints(
SyntheticDataGeneration,
Datasets,
Finetuning,
MemoryBanks,
): ...