flesh out memory banks API

This commit is contained in:
Ashwin Bharambe 2024-08-23 06:38:15 -07:00
parent 31289e3f47
commit 77d6055d9f
11 changed files with 1792 additions and 974 deletions

View file

@ -10,13 +10,13 @@ from typing import Any, Dict, List, Literal, Optional, Union
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, ConfigDict, Field, validator
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_toolchain.common.deployment_types import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.safety.api.datatypes import * # noqa: F403
from llama_toolchain.memory.api.datatypes import * # noqa: F403
from llama_toolchain.safety.api import * # noqa: F403
from llama_toolchain.memory.api import * # noqa: F403
@json_schema_type
@ -25,31 +25,81 @@ class Attachment(BaseModel):
mime_type: str
class AgenticSystemBuiltinTool(BuiltinTool):
class AgenticSystemTool(Enum):
brave_search = "brave_search"
wolfram_alpha = "wolfram_alpha"
photogen = "photogen"
code_interpreter = "code_interpreter"
function_call = "function_call"
memory = "memory"
@json_schema_type
class AgenticSystemToolDefinition(BaseModel):
tool_name: Union[AgenticSystemBuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
@validator("tool_name", pre=True)
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return AgenticSystemBuiltinTool(v)
except ValueError:
return v
return v
execution_config: Optional[RestAPIExecutionConfig] = None
class ToolDefinitionCommon(BaseModel):
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
@json_schema_type
class BraveSearchToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.brave_search.value] = (
AgenticSystemTool.brave_search.value
)
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class WolframAlphaToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.wolfram_alpha.value] = (
AgenticSystemTool.wolfram_alpha.value
)
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class PhotogenToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.photogen.value] = AgenticSystemTool.photogen.value
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class CodeInterpreterToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.code_interpreter.value] = (
AgenticSystemTool.code_interpreter.value
)
enable_inline_code_execution: bool = True
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class FunctionCallToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.function_call.value] = (
AgenticSystemTool.function_call.value
)
description: str
parameters: Dict[str, ToolParamDefinition]
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
memory_banks: List[MemoryBank] = Field(default_factory=list)
AgenticSystemToolDefinition = Annotated[
Union[
BraveSearchToolDefinition,
WolframAlphaToolDefinition,
PhotogenToolDefinition,
CodeInterpreterToolDefinition,
FunctionCallToolDefinition,
MemoryToolDefinition,
],
Field(discriminator="type"),
]
class StepCommon(BaseModel):
turn_id: str
step_id: str
@ -136,27 +186,45 @@ class Session(BaseModel):
started_at: datetime
@json_schema_type
class MemoryConfig(BaseModel):
memory_bank_id: str
class MemoryBankConfigCommon(BaseModel):
bank_id: str
# this configuration can hold other information we may want to configure
# how will the agent use the memory bank API?
#
#
class VectorMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
class KeyValueMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
keys: List[str] # what keys to focus on
class KeywordMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class GraphMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
VectorMemoryBankConfig,
KeyValueMemoryBankConfig,
KeywordMemoryBankConfig,
GraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
memory_configs: Optional[List[MemoryConfig]] = Field(default_factory=list)
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
# if you completely want to replace the messages prefixed by the system,
# this is debug only
debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
@ -168,6 +236,7 @@ class AgentConfigCommon(BaseModel):
class AgentConfig(AgentConfigCommon):
model: str
instructions: str
memory_bank_configs: Optional[List[MemoryBankConfig]] = Field(default_factory=list)
class AgentConfigOverridablePerTurn(AgentConfigCommon):

View file

@ -10,21 +10,9 @@ from typing import Protocol
from llama_models.schema_utils import json_schema_type, webmethod
@json_schema_type
class AgenticSystemCreateRequest(BaseModel):
agent_config: AgentConfig
@json_schema_type
class AgenticSystemCreateResponse(BaseModel):
# TODO: rename this to agent_id
system_id: str
@json_schema_type
class AgenticSystemSessionCreateRequest(BaseModel):
system_id: str
session_name: str
agent_id: str
@json_schema_type
@ -33,8 +21,8 @@ class AgenticSystemSessionCreateResponse(BaseModel):
@json_schema_type
class AgenticSystemTurnCreateRequest(BaseModel, AgentConfigOverridablePerTurn):
system_id: str
class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
agent_id: str
session_id: str
# TODO: figure out how we can simplify this and make why
@ -67,7 +55,7 @@ class AgenticSystem(Protocol):
@webmethod(route="/agentic_system/create")
async def create_agentic_system(
self,
request: AgenticSystemCreateRequest,
agent_config: AgentConfig,
) -> AgenticSystemCreateResponse: ...
@webmethod(route="/agentic_system/turn/create")
@ -91,7 +79,8 @@ class AgenticSystem(Protocol):
@webmethod(route="/agentic_system/session/create")
async def create_agentic_system_session(
self,
request: AgenticSystemSessionCreateRequest,
agent_id: str,
session_name: str,
) -> AgenticSystemSessionCreateResponse: ...
@webmethod(route="/agentic_system/session/get")