agentic loop has a RAG implementation

This commit is contained in:
Ashwin Bharambe 2024-08-23 15:20:40 -07:00
parent 77d6055d9f
commit 14637bea66
4 changed files with 245 additions and 111 deletions

View file

@ -13,6 +13,7 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.common.deployment_types import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.safety.api import * # noqa: F403
@ -21,7 +22,7 @@ from llama_toolchain.memory.api import * # noqa: F403
# Attachment: BaseModel subclass decorated with @json_schema_type that pairs a
# payload with its MIME type string.
# NOTE(review): this view is a rendered diff with +/- markers and indentation
# stripped; `url` and `content` look like the removed/added versions of a
# single field — confirm against the real file before assuming both exist.
@json_schema_type
class Attachment(BaseModel):
url: URL
content: InterleavedTextMedia | URL
mime_type: str
@ -81,10 +82,45 @@ class FunctionCallToolDefinition(ToolDefinitionCommon):
remote_execution: Optional[RestAPIExecutionConfig] = None
# Private base model for the AgenticSystem*MemoryBankConfig classes; it holds
# the one field they all inherit.
class _MemoryBankConfigCommon(BaseModel):
bank_id: str  # presumably the ID of the memory bank being configured — confirm
# Config variant for a vector memory bank. It adds no fields beyond the base;
# `type` is pinned to MemoryBankType.vector.value and serves as the tag that
# the MemoryBankConfig union discriminates on.
class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
# Config variant for a key-value memory bank. `type` is pinned to
# MemoryBankType.keyvalue.value (the union's discriminator tag); `keys` is
# required because it has no default.
class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
keys: List[str] # what keys to focus on
# Config variant for a keyword memory bank. It adds no fields beyond the base;
# `type` is pinned to MemoryBankType.keyword.value (the union's discriminator
# tag).
class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
# Config variant for a graph memory bank. `type` is pinned to
# MemoryBankType.graph.value (the union's discriminator tag); `entities` is
# required because it has no default.
class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
entities: List[str] # what entities to focus on
# Discriminated union over the four AgenticSystem*MemoryBankConfig models.
# Field(discriminator="type") makes pydantic choose the concrete class from
# the value of each payload's `type` field.
# NOTE(review): the name `MemoryBankConfig` is assigned a second time further
# down this diff; that looks like the old/new sides of a move — confirm only
# one definition survives in the real file.
MemoryBankConfig = Annotated[
Union[
AgenticSystemVectorMemoryBankConfig,
AgenticSystemKeyValueMemoryBankConfig,
AgenticSystemKeywordMemoryBankConfig,
AgenticSystemGraphMemoryBankConfig,
],
Field(discriminator="type"),
]
# Tool definition for the memory tool: `type` is pinned to
# AgenticSystemTool.memory.value, and the list fields default to empty lists
# via default_factory.
# NOTE(review): +/- markers are stripped in this view; `memory_banks` and
# `memory_bank_configs` may be the before/after forms of one field — confirm
# against the real file.
@json_schema_type
class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
memory_banks: List[MemoryBank] = Field(default_factory=list)
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
# presumably a token budget for retrieved context — confirm against the consumer
max_tokens_in_context: int = 4096
# presumably a cap on the number of retrieved chunks — confirm against the consumer
max_chunks: int = 10
AgenticSystemToolDefinition = Annotated[
@ -141,8 +177,7 @@ class MemoryRetrievalStep(StepCommon):
StepType.memory_retrieval.value
)
memory_bank_ids: List[str]
documents: List[MemoryBankDocument]
scores: List[float]
inserted_context: InterleavedTextMedia
Step = Annotated[
@ -185,38 +220,7 @@ class Session(BaseModel):
turns: List[Turn]
started_at: datetime
# Base model holding the `bank_id` field shared by the *MemoryBankConfig
# classes in this hunk.
# NOTE(review): the hunk header (`-185,38 +220,7`) shows this region shrinking
# from 38 to 7 lines, so this class is probably on the removed side of the
# diff (superseded by _MemoryBankConfigCommon) — confirm.
class MemoryBankConfigCommon(BaseModel):
bank_id: str
# Vector memory bank config; `type` is pinned to MemoryBankType.vector.value.
# NOTE(review): likely the removed (pre-rename) counterpart of
# AgenticSystemVectorMemoryBankConfig — confirm.
class VectorMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
# Key-value memory bank config; `type` is pinned to
# MemoryBankType.keyvalue.value and `keys` is required (no default).
# NOTE(review): likely the removed (pre-rename) counterpart of
# AgenticSystemKeyValueMemoryBankConfig — confirm.
class KeyValueMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
keys: List[str] # what keys to focus on
# Keyword memory bank config; `type` is pinned to MemoryBankType.keyword.value.
# NOTE(review): likely the removed (pre-rename) counterpart of
# AgenticSystemKeywordMemoryBankConfig — confirm.
class KeywordMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
# Graph memory bank config; `type` is pinned to MemoryBankType.graph.value and
# `entities` is required (no default).
# NOTE(review): likely the removed (pre-rename) counterpart of
# AgenticSystemGraphMemoryBankConfig — confirm.
class GraphMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
entities: List[str] # what entities to focus on
# Discriminated union over the four *MemoryBankConfig models defined just
# above; Field(discriminator="type") selects the concrete class from the
# `type` field.
# NOTE(review): the same alias name is also assigned earlier in this diff over
# the AgenticSystem*-prefixed classes; this copy is probably the removed side
# — confirm.
MemoryBankConfig = Annotated[
Union[
VectorMemoryBankConfig,
KeyValueMemoryBankConfig,
KeywordMemoryBankConfig,
GraphMemoryBankConfig,
],
Field(discriminator="type"),
]
memory_bank: Optional[MemoryBank] = None
class AgentConfigCommon(BaseModel):
@ -236,7 +240,6 @@ class AgentConfigCommon(BaseModel):
# Agent configuration: adds a required `model` string and `instructions`
# string on top of AgentConfigCommon.
# NOTE(review): the hunk header (`-236,7 +240,6`) shows a net loss of one line
# here; `memory_bank_configs` is the likely removed field — confirm
# before depending on it.
class AgentConfig(AgentConfigCommon):
model: str
instructions: str
# Optional list that nevertheless defaults to an empty list (default_factory=list), not None.
memory_bank_configs: Optional[List[MemoryBankConfig]] = Field(default_factory=list)
class AgentConfigOverridablePerTurn(AgentConfigCommon):