forked from phoenix-oss/llama-stack-mirror
[memory refactor][3/n] Introduce RAGToolRuntime as a specialized sub-protocol (#832)
See https://github.com/meta-llama/llama-stack/issues/827 for the broader design. Third part: - we need to make `tool_runtime.rag_tool.query_context()` and `tool_runtime.rag_tool.insert_documents()` methods work smoothly with complete type safety. To that end, we introduce a sub-resource path `tool-runtime/rag-tool/` and make changes to the resolver to make things work. - the PR updates the agents implementation to directly call these typed APIs for memory accesses rather than going through the complex, untyped "invoke_tool" API. the code looks much nicer and simpler (expectedly.) - there are a number of hacks in the server resolver implementation still, we will live with some and fix some Note that we must make sure the client SDKs are able to handle this subresource complexity also. Stainless has support for subresources, so this should be possible but beware. ## Test Plan Our RAG test is sad (doesn't actually test for actual RAG output) but I verified that the implementation works. I will work on fixing the RAG test afterwards. ```bash pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B ```
This commit is contained in:
parent
78a481bb22
commit
1a7490470a
33 changed files with 1648 additions and 1345 deletions
|
@ -59,13 +59,18 @@ from llama_stack.apis.inference import (
|
|||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.memory import Memory, MemoryBankDocument
|
||||
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.tools import (
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
ToolGroups,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
from .persistence import AgentPersistence
|
||||
from .safety import SafetyException, ShieldRunnerMixin
|
||||
|
||||
|
@ -79,7 +84,7 @@ def make_random_string(length: int = 8):
|
|||
|
||||
|
||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||
MEMORY_QUERY_TOOL = "query_memory"
|
||||
MEMORY_QUERY_TOOL = "rag_tool.query_context"
|
||||
WEB_SEARCH_TOOL = "web_search"
|
||||
MEMORY_GROUP = "builtin::memory"
|
||||
|
||||
|
@ -91,20 +96,18 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
agent_config: AgentConfig,
|
||||
tempdir: str,
|
||||
inference_api: Inference,
|
||||
memory_api: Memory,
|
||||
memory_banks_api: MemoryBanks,
|
||||
safety_api: Safety,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
tool_groups_api: ToolGroups,
|
||||
vector_io_api: VectorIO,
|
||||
persistence_store: KVStore,
|
||||
):
|
||||
self.agent_id = agent_id
|
||||
self.agent_config = agent_config
|
||||
self.tempdir = tempdir
|
||||
self.inference_api = inference_api
|
||||
self.memory_api = memory_api
|
||||
self.memory_banks_api = memory_banks_api
|
||||
self.safety_api = safety_api
|
||||
self.vector_io_api = vector_io_api
|
||||
self.storage = AgentPersistence(agent_id, persistence_store)
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
|
@ -370,24 +373,30 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
documents: Optional[List[Document]] = None,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# TODO: simplify all of this code, it can be simpler
|
||||
toolgroup_args = {}
|
||||
toolgroups = set()
|
||||
for toolgroup in self.agent_config.toolgroups:
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
toolgroups.add(toolgroup.name)
|
||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
||||
else:
|
||||
toolgroups.add(toolgroup)
|
||||
if toolgroups_for_turn:
|
||||
for toolgroup in toolgroups_for_turn:
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
toolgroups.add(toolgroup.name)
|
||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
||||
else:
|
||||
toolgroups.add(toolgroup)
|
||||
|
||||
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
||||
if documents:
|
||||
await self.handle_documents(
|
||||
session_id, documents, input_messages, tool_defs
|
||||
)
|
||||
if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
|
||||
memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
|
||||
if memory_tool_group is None:
|
||||
raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
|
||||
|
||||
if MEMORY_GROUP in toolgroups and len(input_messages) > 0:
|
||||
with tracing.span(MEMORY_QUERY_TOOL) as span:
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
@ -398,17 +407,15 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
query_args = {
|
||||
"messages": [msg.content for msg in input_messages],
|
||||
**toolgroup_args.get(memory_tool_group, {}),
|
||||
}
|
||||
|
||||
args = toolgroup_args.get(MEMORY_GROUP, {})
|
||||
vector_db_ids = args.get("vector_db_ids", [])
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
if session_info.memory_bank_id:
|
||||
if "memory_bank_ids" not in query_args:
|
||||
query_args["memory_bank_ids"] = []
|
||||
query_args["memory_bank_ids"].append(session_info.memory_bank_id)
|
||||
vector_db_ids.append(session_info.memory_bank_id)
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
|
@ -425,10 +432,18 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=MEMORY_QUERY_TOOL,
|
||||
args=query_args,
|
||||
result = await self.tool_runtime_api.rag_tool.query_context(
|
||||
content=concat_interleaved_content(
|
||||
[msg.content for msg in input_messages]
|
||||
),
|
||||
query_config=RAGQueryConfig(
|
||||
query_generator_config=DefaultRAGQueryGeneratorConfig(),
|
||||
max_tokens_in_context=4096,
|
||||
max_chunks=5,
|
||||
),
|
||||
vector_db_ids=vector_db_ids,
|
||||
)
|
||||
retrieved_context = result.content
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
@ -449,7 +464,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
ToolResponse(
|
||||
call_id="",
|
||||
tool_name=MEMORY_QUERY_TOOL,
|
||||
content=result.content,
|
||||
content=retrieved_context or [],
|
||||
)
|
||||
],
|
||||
),
|
||||
|
@ -459,13 +474,11 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
span.set_attribute(
|
||||
"input", [m.model_dump_json() for m in input_messages]
|
||||
)
|
||||
span.set_attribute("output", result.content)
|
||||
span.set_attribute("error_code", result.error_code)
|
||||
span.set_attribute("error_message", result.error_message)
|
||||
span.set_attribute("output", retrieved_context)
|
||||
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
|
||||
if result.error_code == 0:
|
||||
if retrieved_context:
|
||||
last_message = input_messages[-1]
|
||||
last_message.context = result.content
|
||||
last_message.context = retrieved_context
|
||||
|
||||
output_attachments = []
|
||||
|
||||
|
@ -842,12 +855,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
if session_info.memory_bank_id is None:
|
||||
bank_id = f"memory_bank_{session_id}"
|
||||
await self.memory_banks_api.register_memory_bank(
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
),
|
||||
|
||||
# TODO: the semantic for registration is definitely not "creation"
|
||||
# so we need to fix it if we expect the agent to create a new vector db
|
||||
# for each session
|
||||
await self.vector_io_api.register_vector_db(
|
||||
vector_db_id=bank_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
)
|
||||
await self.storage.add_memory_bank_to_session(session_id, bank_id)
|
||||
else:
|
||||
|
@ -858,9 +872,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def add_to_session_memory_bank(
|
||||
self, session_id: str, data: List[Document]
|
||||
) -> None:
|
||||
bank_id = await self._ensure_memory_bank(session_id)
|
||||
vector_db_id = await self._ensure_memory_bank(session_id)
|
||||
documents = [
|
||||
MemoryBankDocument(
|
||||
RAGDocument(
|
||||
document_id=str(uuid.uuid4()),
|
||||
content=a.content,
|
||||
mime_type=a.mime_type,
|
||||
|
@ -868,9 +882,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
for a in data
|
||||
]
|
||||
await self.memory_api.insert_documents(
|
||||
bank_id=bank_id,
|
||||
await self.tool_runtime_api.rag_tool.insert_documents(
|
||||
documents=documents,
|
||||
vector_db_id=vector_db_id,
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
|
||||
|
||||
|
@ -955,7 +970,7 @@ async def execute_tool_call_maybe(
|
|||
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
args=dict(
|
||||
kwargs=dict(
|
||||
session_id=session_id,
|
||||
**tool_call_args,
|
||||
),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue