Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

agentic loop has a RAG implementation

commit 14637bea66, parent 77d6055d9f

4 changed files with 245 additions and 111 deletions
@@ -13,6 +13,7 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated

+from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_toolchain.common.deployment_types import *  # noqa: F403
 from llama_toolchain.inference.api import *  # noqa: F403
 from llama_toolchain.safety.api import *  # noqa: F403
@@ -21,7 +22,7 @@ from llama_toolchain.memory.api import *  # noqa: F403

 @json_schema_type
 class Attachment(BaseModel):
-    url: URL
+    content: InterleavedTextMedia | URL
     mime_type: str


@@ -81,10 +82,45 @@ class FunctionCallToolDefinition(ToolDefinitionCommon):
     remote_execution: Optional[RestAPIExecutionConfig] = None


+class _MemoryBankConfigCommon(BaseModel):
+    bank_id: str
+
+
+class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon):
+    type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
+
+
+class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
+    type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
+    keys: List[str]  # what keys to focus on
+
+
+class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon):
+    type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
+
+
+class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon):
+    type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
+    entities: List[str]  # what entities to focus on
+
+
+MemoryBankConfig = Annotated[
+    Union[
+        AgenticSystemVectorMemoryBankConfig,
+        AgenticSystemKeyValueMemoryBankConfig,
+        AgenticSystemKeywordMemoryBankConfig,
+        AgenticSystemGraphMemoryBankConfig,
+    ],
+    Field(discriminator="type"),
+]
+
+
 @json_schema_type
 class MemoryToolDefinition(ToolDefinitionCommon):
     type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
-    memory_banks: List[MemoryBank] = Field(default_factory=list)
+    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+    max_tokens_in_context: int = 4096
+    max_chunks: int = 10


 AgenticSystemToolDefinition = Annotated[
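For illustration, a minimal sketch of how a caller could describe memory banks to the new memory tool using the types added above. The import path is an assumption (the diff does not show module names); the class names, fields, and defaults come from this hunk.

# Sketch only: the import path below is assumed, not shown in this diff.
from llama_toolchain.agentic_system.api import (  # hypothetical module path
    AgenticSystemVectorMemoryBankConfig,
    MemoryToolDefinition,
)

memory_tool = MemoryToolDefinition(
    # point the agent at an existing vector bank by id
    memory_bank_configs=[
        AgenticSystemVectorMemoryBankConfig(bank_id="my-rag-bank"),
    ],
    # budget knobs introduced in this commit
    max_tokens_in_context=4096,
    max_chunks=10,
)

Because MemoryBankConfig is a discriminated union on `type`, each entry only needs the fields of its bank kind, e.g. `keys` for key-value banks or `entities` for graph banks.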
@@ -141,8 +177,7 @@ class MemoryRetrievalStep(StepCommon):
         StepType.memory_retrieval.value
     )
     memory_bank_ids: List[str]
-    documents: List[MemoryBankDocument]
-    scores: List[float]
+    inserted_context: InterleavedTextMedia


 Step = Annotated[
@@ -185,38 +220,7 @@ class Session(BaseModel):
     turns: List[Turn]
     started_at: datetime

-
-class MemoryBankConfigCommon(BaseModel):
-    bank_id: str
-
-
-class VectorMemoryBankConfig(MemoryBankConfigCommon):
-    type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
-
-
-class KeyValueMemoryBankConfig(MemoryBankConfigCommon):
-    type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
-    keys: List[str]  # what keys to focus on
-
-
-class KeywordMemoryBankConfig(MemoryBankConfigCommon):
-    type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
-
-
-class GraphMemoryBankConfig(MemoryBankConfigCommon):
-    type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
-    entities: List[str]  # what entities to focus on
-
-
-MemoryBankConfig = Annotated[
-    Union[
-        VectorMemoryBankConfig,
-        KeyValueMemoryBankConfig,
-        KeywordMemoryBankConfig,
-        GraphMemoryBankConfig,
-    ],
-    Field(discriminator="type"),
-]
+    memory_bank: Optional[MemoryBank] = None


 class AgentConfigCommon(BaseModel):
@@ -236,7 +240,6 @@ class AgentConfigCommon(BaseModel):
 class AgentConfig(AgentConfigCommon):
     model: str
     instructions: str
-    memory_bank_configs: Optional[List[MemoryBankConfig]] = Field(default_factory=list)


 class AgentConfigOverridablePerTurn(AgentConfigCommon):
@@ -119,6 +119,7 @@ class ChatAgent(ShieldRunnerMixin):
         steps = []
         output_message = None
         async for chunk in self.run(
+            session=session,
             turn_id=turn_id,
             input_messages=messages,
             attachments=request.attachments or [],
@@ -170,6 +171,7 @@ class ChatAgent(ShieldRunnerMixin):

     async def run(
         self,
+        session: Session,
         turn_id: str,
         input_messages: List[Message],
         attachments: List[Attachment],
@@ -190,7 +192,7 @@ class ChatAgent(ShieldRunnerMixin):
                 yield res

         async for res in self._run(
-            turn_id, input_messages, attachments, sampling_params, stream
+            turn_id, session, input_messages, attachments, sampling_params, stream
         ):
             if isinstance(res, bool):
                 return
@@ -275,32 +277,62 @@ class ChatAgent(ShieldRunnerMixin):
                 )
             )

-    async def _should_retrieve_context(
-        self, messages: List[Message], attachments: List[Attachment]
-    ) -> bool:
-        return self.agent_config.memory_configs or len(attachments) > 0
-
-    async def _retrieve_context(
-        self, messages: List[Message], attachments: List[Attachment]
-    ) -> List[Message]:
-        return []
-
     async def _run(
         self,
+        session: Session,
         turn_id: str,
         input_messages: List[Message],
         attachments: List[Attachment],
         sampling_params: SamplingParams,
         stream: bool = False,
     ) -> AsyncGenerator:
-        need_context = await self._should_retrieve_context(input_messages, attachments)
-        if need_context:
-            context_messages = await self._retrieve_context(input_messages)
-            # input_messages = preprocess_dialog(input_messages, self.prefix_messages)
-            # input_messages = input_messages + context
-        input_messages = preprocess_dialog(input_messages)
-
-        attachments = []
+        enabled_tools = set(t.type for t in self.agent_config.tools)
+        need_rag_context = await self._should_retrieve_context(
+            input_messages, attachments
+        )
+        if need_rag_context:
+            step_id = str(uuid.uuid4())
+            yield AgenticSystemTurnResponseStreamChunk(
+                event=AgenticSystemTurnResponseEvent(
+                    payload=AgenticSystemTurnResponseStepStartPayload(
+                        step_type=StepType.memory_retrieval.value,
+                        step_id=step_id,
+                    )
+                )
+            )
+
+            # TODO: find older context from the session and either replace it
+            # or append with a sliding window. this is really a very simplistic implementation
+            rag_context, bank_ids = await self._retrieve_context(input_messages)
+
+            step_id = str(uuid.uuid4())
+            yield AgenticSystemTurnResponseStreamChunk(
+                event=AgenticSystemTurnResponseEvent(
+                    payload=AgenticSystemTurnResponseStepCompletePayload(
+                        step_type=StepType.memory_retrieval.value,
+                        step_id=step_id,
+                        step_details=MemoryRetrievalStep(
+                            memory_bank_ids=bank_ids,
+                            inserted_context=rag_context,
+                        ),
+                    )
+                )
+            )
+
+            if rag_context:
+                system_message = next(m for m in input_messages if m.role == "system")
+                if system_message:
+                    system_message.content = system_message.content + "\n" + rag_context
+                else:
+                    input_messages = [
+                        Message(role="system", content=rag_context)
+                    ] + input_messages
+
+        elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
+            urls = [a.content for a in attachments if isinstance(a.content, URL)]
+            input_messages.append(attachment_message(urls))
+
+        output_attachments = []

         n_iter = 0
         while True:
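Condensed restatement of the context-injection branch above, as a standalone helper (the helper name and the toy Message class are mine, not part of this commit). Note that the sketch passes a default to next(); the diff as written would raise StopIteration when no system message is present.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Message:  # toy stand-in for the repo's message type
    role: str
    content: str


def inject_rag_context(messages: List[Message], rag_context: Optional[str]) -> List[Message]:
    # Mirrors the `if rag_context:` branch added in `_run` above.
    if not rag_context:
        return messages
    system = next((m for m in messages if m.role == "system"), None)
    if system is not None:
        # append retrieved context to the existing system prompt
        system.content = system.content + "\n" + rag_context
        return messages
    # otherwise prepend a new system message carrying the context
    return [Message(role="system", content=rag_context)] + messages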
@@ -414,7 +446,8 @@ class ChatAgent(ShieldRunnerMixin):

             if len(message.tool_calls) == 0:
                 if stop_reason == StopReason.end_of_turn:
-                    if len(attachments) > 0:
+                    # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
+                    if len(output_attachments) > 0:
                         if isinstance(message.content, list):
                             message.content += attachments
                         else:
@@ -526,58 +559,131 @@ class ChatAgent(ShieldRunnerMixin):
                 # NOTE: when we push this message back to the model, the model may ignore the
                 # attached file path etc. since the model is trained to only provide a user message
                 # with the summary. We keep all generated attachments and then attach them to final message
-                attachments.append(result_message.content)
+                output_attachments.append(result_message.content)
             elif isinstance(result_message.content, list) or isinstance(
                 result_message.content, tuple
             ):
                 for c in result_message.content:
                     if isinstance(c, Attachment):
-                        attachments.append(c)
+                        output_attachments.append(c)

             input_messages = input_messages + [message, result_message]

             n_iter += 1

+    async def _ensure_memory_bank(self, session: Session) -> MemoryBank:
+        if session.memory_bank is None:
+            session.memory_bank = await self.memory_api.create_memory_bank(
+                name=f"memory_bank_{session.session_id}",
+                config=VectorMemoryBankConfig(
+                    embedding_model="sentence-transformer/all-MiniLM-L6-v2",
+                ),
+            )
+
+        return session.memory_bank
+
+    async def _should_retrieve_context(
+        self, messages: List[Message], attachments: List[Attachment]
+    ) -> bool:
+        enabled_tools = set(t.type for t in self.agent_config.tools)
+        if attachments:
+            if (
+                AgenticSystemTool.code_interpreter.value in enabled_tools
+                and self.agent_config.tool_choice == ToolChoice.required
+            ):
+                return False
+        return attachments or AgenticSystemTool.memory.value in enabled_tools
+
+    def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
+        for t in self.agent_config.tools:
+            if t.type == AgenticSystemTool.memory.value:
+                return t
+
+        return None
+
+    async def _retrieve_context(
+        self, session: Session, messages: List[Message], attachments: List[Attachment]
+    ) -> Optional[InterleavedTextMedia]:
+        bank_ids = []
+
+        memory = self._memory_tool_definition()
+        assert memory is not None, "Memory tool not configured"
+        bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
+
+        if attachments:
+            bank = await self._ensure_memory_bank(session)
+            bank_ids.append(bank.bank_id)
+
+            documents = [
+                MemoryBankDocument(
+                    doc_id=str(uuid.uuid4()),
+                    content=a.content,
+                    mime_type=a.mime_type,
+                    metadata={},
+                )
+                for a in attachments
+            ]
+            await self.memory_api.insert_documents(bank_id, documents)
+
+        assert len(bank_ids) > 0, "No memory banks configured?"
+
+        query = " ".join(m.content for m in messages)
+        tasks = [
+            self.memory_api.query_documents(
+                bank_id=bank_id,
+                query=query,
+                params={
+                    "max_chunks": 5,
+                },
+            )
+            for bank_id in bank_ids
+        ]
+        results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
+        chunks = [c for r in results for c in r.chunks]
+        scores = [s for r in results for s in r.scores]
+
+        # sort by score
+        chunks, scores = zip(
+            *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
+        )
+        if not chunks:
+            return None
+
+        tokens = 0
+        picked = []
+        for c in chunks[: memory.max_chunks]:
+            tokens += c.token_count
+            if tokens > memory.max_tokens_in_context:
+                cprint(
+                    f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
+                    "red",
+                )
+                break
+            picked.append(c)
+
+        return [
+            "The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n",
+            *picked,
+            "\n=== END-RETRIEVED-CONTEXT ===\n",
+        ]
+
+
-def attachment_message(url: URL) -> ToolResponseMessage:
-    uri = url.uri
-    assert uri.startswith("file://")
-    filepath = uri[len("file://") :]
+def attachment_message(urls: List[URL]) -> ToolResponseMessage:
+    content = []
+
+    for url in urls:
+        uri = url.uri
+        assert uri.startswith("file://")
+        filepath = uri[len("file://") :]
+        content.append(f'# There is a file accessible to you at "{filepath}"\n')

     return ToolResponseMessage(
         call_id="",
         tool_name=BuiltinTool.code_interpreter,
-        content=f'# There is a file accessible to you at "{filepath}"',
+        content=content,
     )


-def preprocess_dialog(messages: List[Message]) -> List[Message]:
-    # remove system message since those are
-    """
-    Preprocesses the dialog by removing the system message and
-    adding the system message to the beginning of the dialog.
-    """
-    ret = []
-
-    for m in messages:
-        if m.role == Role.system.value:
-            continue
-
-        # NOTE: the ideal behavior is to use `file_path = ...` but that
-        # means we need to have stateful execution of code which we currently
-        # do not have.
-        if isinstance(m.content, Attachment):
-            ret.append(attachment_message(m.content.url))
-        elif isinstance(m.content, list):
-            for c in m.content:
-                if isinstance(c, Attachment):
-                    ret.append(attachment_message(c.url))
-
-        ret.append(m)
-
-    return ret
-
-
 async def execute_tool_call_maybe(
     tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
 ) -> List[ToolResponseMessage]:
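The new _retrieve_context merges results from every configured bank, sorts them by score, and then packs chunks under the max_chunks / max_tokens_in_context budgets. A self-contained sketch of that selection policy (toy Chunk stand-in; the memory API's Chunk carries the same token_count field, see the hunks below):

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Chunk:  # toy stand-in; the memory API's Chunk has content + token_count
    content: str
    token_count: int


def pick_chunks(
    scored: List[Tuple[Chunk, float]], max_chunks: int, max_tokens_in_context: int
) -> List[Chunk]:
    # Highest-scoring chunks first, stopping at the chunk and token budgets.
    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
    picked: List[Chunk] = []
    tokens = 0
    for chunk, _score in ranked[:max_chunks]:
        tokens += chunk.token_count
        if tokens > max_tokens_in_context:
            break
        picked.append(chunk)
    return picked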
@@ -101,6 +101,11 @@ class BatchChatCompletionResponse(BaseModel):
     completion_message_batch: List[CompletionMessage]


+@json_schema_type
+class EmbeddingsResponse(BaseModel):
+    embeddings: List[List[float]]
+
+
 class Inference(Protocol):
     @webmethod(route="/inference/completion")
     async def completion(
@@ -114,6 +119,13 @@ class Inference(Protocol):
         request: ChatCompletionRequest,
     ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...

+    @webmethod(route="/inference/embeddings")
+    async def embeddings(
+        self,
+        model: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse: ...
+
     @webmethod(route="/inference/batch_completion")
     async def batch_completion(
         self,
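A hedged caller sketch for the new embeddings webmethod. The inference handle and the model string are placeholders; only the method signature and the EmbeddingsResponse shape come from the hunks above.

from typing import List


async def embed_texts(inference, texts: List[str]) -> List[List[float]]:
    # `inference` is assumed to be any object implementing the Inference protocol above.
    response = await inference.embeddings(
        model="my-embedding-model",  # placeholder model identifier
        contents=texts,
    )
    # EmbeddingsResponse.embeddings: one vector of floats per input item.
    return response.embeddings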
@@ -23,17 +23,6 @@ class MemoryBankDocument(BaseModel):
     metadata: Dict[str, Any]


-class Chunk(BaseModel):
-    content: InterleavedTextMedia
-    token_count: int
-
-
-@json_schema_type
-class QueryDocumentsResponse(BaseModel):
-    chunks: List[Chunk]
-    scores: List[float]
-
-
 @json_schema_type
 class MemoryBankType(Enum):
     vector = "vector"
@@ -45,6 +34,7 @@ class MemoryBankType(Enum):
 class VectorMemoryBankConfig(BaseModel):
     type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
     embedding_model: str
+    chunk_size_in_tokens: int


 class KeyValueMemoryBankConfig(BaseModel):
@@ -70,18 +60,39 @@ MemoryBankConfig = Annotated[
 ]


+class Chunk(BaseModel):
+    content: InterleavedTextMedia
+    token_count: int
+
+
+@json_schema_type
+class QueryDocumentsResponse(BaseModel):
+    chunks: List[Chunk]
+    scores: List[float]
+
+
+@json_schema_type
+class QueryAPI(Protocol):
+    @webmethod(route="/query_documents")
+    def query_documents(
+        self,
+        query: InterleavedTextMedia,
+        params: Optional[Dict[str, Any]] = None,
+    ) -> QueryDocumentsResponse: ...
+
+
 @json_schema_type
 class MemoryBank(BaseModel):
     bank_id: str
     name: str
     config: MemoryBankConfig
-    # if there's a pre-existing store which obeys the MemoryBank REST interface
+    # if there's a pre-existing (reachable-from-distribution) store which supports QueryAPI
     url: Optional[URL] = None


 class Memory(Protocol):
     @webmethod(route="/memory_banks/create")
-    def create_memory_bank(
+    async def create_memory_bank(
         self,
         name: str,
         config: MemoryBankConfig,
@@ -89,33 +100,35 @@ class Memory(Protocol):
     ) -> MemoryBank: ...

     @webmethod(route="/memory_banks/list", method="GET")
-    def list_memory_banks(self) -> List[MemoryBank]: ...
+    async def list_memory_banks(self) -> List[MemoryBank]: ...

     @webmethod(route="/memory_banks/get")
-    def get_memory_bank(self, bank_id: str) -> MemoryBank: ...
+    async def get_memory_bank(self, bank_id: str) -> MemoryBank: ...

     @webmethod(route="/memory_banks/drop", method="DELETE")
-    def drop_memory_bank(
+    async def drop_memory_bank(
         self,
         bank_id: str,
     ) -> str: ...

+    # this will just block now until documents are inserted, but it should
+    # probably return a Job instance which can be polled for completion
     @webmethod(route="/memory_bank/insert")
-    def insert_documents(
+    async def insert_documents(
         self,
         bank_id: str,
         documents: List[MemoryBankDocument],
     ) -> None: ...

     @webmethod(route="/memory_bank/update")
-    def update_documents(
+    async def update_documents(
         self,
         bank_id: str,
         documents: List[MemoryBankDocument],
     ) -> None: ...

     @webmethod(route="/memory_bank/query")
-    def query_documents(
+    async def query_documents(
         self,
         bank_id: str,
         query: InterleavedTextMedia,
@@ -123,14 +136,14 @@ class Memory(Protocol):
     ) -> QueryDocumentsResponse: ...

     @webmethod(route="/memory_bank/documents/get")
-    def get_documents(
+    async def get_documents(
         self,
         bank_id: str,
         document_ids: List[str],
     ) -> List[MemoryBankDocument]: ...

     @webmethod(route="/memory_bank/documents/delete")
-    def delete_documents(
+    async def delete_documents(
         self,
         bank_id: str,
         document_ids: List[str],
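Since the Memory protocol methods are now declared async, providers and clients drive them with await. A minimal round-trip sketch under stated assumptions: the import path, the memory handle, the embedding model string, and the document contents are placeholders; the method names, routes, and field names are the ones in the hunks above.

# Sketch only: the import path is assumed, not shown in this diff.
import uuid

from llama_toolchain.memory.api import (  # hypothetical module path
    MemoryBankDocument,
    VectorMemoryBankConfig,
)


async def demo_roundtrip(memory):
    # `memory` is assumed to be any object implementing the Memory protocol above.
    bank = await memory.create_memory_bank(
        name="session-notes",
        config=VectorMemoryBankConfig(
            embedding_model="all-MiniLM-L6-v2",  # placeholder
            chunk_size_in_tokens=512,  # field added in this commit
        ),
    )
    await memory.insert_documents(
        bank_id=bank.bank_id,
        documents=[
            MemoryBankDocument(
                doc_id=str(uuid.uuid4()),
                content="Llama Stack agents can retrieve context from memory banks.",
                mime_type="text/plain",
                metadata={},
            )
        ],
    )
    response = await memory.query_documents(
        bank_id=bank.bank_id,
        query="what can agents retrieve?",
        params={"max_chunks": 5},
    )
    return list(zip(response.chunks, response.scores))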