mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-10 05:24:39 +00:00
agentic loop has a RAG implementation
This commit is contained in:
parent
77d6055d9f
commit
14637bea66
4 changed files with 245 additions and 111 deletions
|
@ -13,6 +13,7 @@ from llama_models.schema_utils import json_schema_type
|
|||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_toolchain.common.deployment_types import * # noqa: F403
|
||||
from llama_toolchain.inference.api import * # noqa: F403
|
||||
from llama_toolchain.safety.api import * # noqa: F403
|
||||
|
@ -21,7 +22,7 @@ from llama_toolchain.memory.api import * # noqa: F403
|
|||
|
||||
@json_schema_type
|
||||
class Attachment(BaseModel):
|
||||
url: URL
|
||||
content: InterleavedTextMedia | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
|
@ -81,10 +82,45 @@ class FunctionCallToolDefinition(ToolDefinitionCommon):
|
|||
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||
|
||||
|
||||
class _MemoryBankConfigCommon(BaseModel):
|
||||
bank_id: str
|
||||
|
||||
|
||||
class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||
|
||||
|
||||
class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
|
||||
keys: List[str] # what keys to focus on
|
||||
|
||||
|
||||
class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
|
||||
|
||||
|
||||
class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
||||
entities: List[str] # what entities to focus on
|
||||
|
||||
|
||||
MemoryBankConfig = Annotated[
|
||||
Union[
|
||||
AgenticSystemVectorMemoryBankConfig,
|
||||
AgenticSystemKeyValueMemoryBankConfig,
|
||||
AgenticSystemKeywordMemoryBankConfig,
|
||||
AgenticSystemGraphMemoryBankConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryToolDefinition(ToolDefinitionCommon):
|
||||
type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
|
||||
memory_banks: List[MemoryBank] = Field(default_factory=list)
|
||||
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
|
||||
max_tokens_in_context: int = 4096
|
||||
max_chunks: int = 10
|
||||
|
||||
|
||||
AgenticSystemToolDefinition = Annotated[
|
||||
|
@ -141,8 +177,7 @@ class MemoryRetrievalStep(StepCommon):
|
|||
StepType.memory_retrieval.value
|
||||
)
|
||||
memory_bank_ids: List[str]
|
||||
documents: List[MemoryBankDocument]
|
||||
scores: List[float]
|
||||
inserted_context: InterleavedTextMedia
|
||||
|
||||
|
||||
Step = Annotated[
|
||||
|
@ -185,38 +220,7 @@ class Session(BaseModel):
|
|||
turns: List[Turn]
|
||||
started_at: datetime
|
||||
|
||||
|
||||
class MemoryBankConfigCommon(BaseModel):
|
||||
bank_id: str
|
||||
|
||||
|
||||
class VectorMemoryBankConfig(MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||
|
||||
|
||||
class KeyValueMemoryBankConfig(MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
|
||||
keys: List[str] # what keys to focus on
|
||||
|
||||
|
||||
class KeywordMemoryBankConfig(MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
|
||||
|
||||
|
||||
class GraphMemoryBankConfig(MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
||||
entities: List[str] # what entities to focus on
|
||||
|
||||
|
||||
MemoryBankConfig = Annotated[
|
||||
Union[
|
||||
VectorMemoryBankConfig,
|
||||
KeyValueMemoryBankConfig,
|
||||
KeywordMemoryBankConfig,
|
||||
GraphMemoryBankConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
memory_bank: Optional[MemoryBank] = None
|
||||
|
||||
|
||||
class AgentConfigCommon(BaseModel):
|
||||
|
@ -236,7 +240,6 @@ class AgentConfigCommon(BaseModel):
|
|||
class AgentConfig(AgentConfigCommon):
|
||||
model: str
|
||||
instructions: str
|
||||
memory_bank_configs: Optional[List[MemoryBankConfig]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue