[memory refactor][3/n] Introduce RAGToolRuntime as a specialized sub-protocol (#832)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

Third part:
- we need to make the `tool_runtime.rag_tool.query_context()` and
`tool_runtime.rag_tool.insert_documents()` methods work smoothly, with
complete type safety. To that end, we introduce a sub-resource path
`tool-runtime/rag-tool/` and make changes to the resolver to make things
work.
- the PR updates the agents implementation to call these typed APIs
directly for memory accesses rather than going through the complex, untyped
`invoke_tool` API. The code looks much nicer and simpler (expectedly); a
sketch of the new call shape follows after this list.
- there are still a number of hacks in the server resolver implementation;
we will live with some and fix the others.
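
For illustration, here is a rough sketch of what the agent-side memory access looks like against the typed sub-protocol (second bullet above). The method names and signatures follow the diff below; the surrounding wiring, function names, and argument values are assumptions, not an excerpt from this PR.

```python
# Sketch only: calling the typed RAGToolRuntime sub-protocol instead of the
# untyped invoke_tool("query_memory", {...}) path. Wiring is hypothetical.
from typing import List

from llama_stack.apis.tools import RAGDocument, RAGQueryConfig


async def retrieve_context(tool_runtime_api, query_config: RAGQueryConfig,
                           vector_db_ids: List[str], user_content):
    # Typed query: returns a RAGQueryResult whose `content` can be appended
    # to the agent's input messages directly.
    result = await tool_runtime_api.rag_tool.query_context(
        content=user_content,
        query_config=query_config,
        vector_db_ids=vector_db_ids,
    )
    return result.content


async def ingest(tool_runtime_api, vector_db_id: str, documents: List[RAGDocument]):
    # Typed ingestion: chunking happens inside the provider's insert_documents.
    await tool_runtime_api.rag_tool.insert_documents(
        documents=documents,
        vector_db_id=vector_db_id,
        chunk_size_in_tokens=512,
    )
```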

Note that we must make sure the client SDKs can handle this sub-resource
structure as well. Stainless has support for sub-resources, so this should be
possible, but beware.
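
On the client side, a Stainless-generated sub-resource would presumably surface as nested attribute access, roughly as below. This is a guess at that shape; the `llama_stack_client` call and its argument handling are assumptions, not generated-SDK output.

```python
# Hypothetical client-side shape once the SDKs support the sub-resource;
# the exact method signature generated by Stainless may differ.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")
result = client.tool_runtime.rag_tool.query_context(
    content="what grounding documents mention the scheduler?",
    query_config={},  # rely on server-side defaults
    vector_db_ids=["my-documents"],
)
```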

## Test Plan

Our RAG test is sad (it doesn't actually check the RAG output), but I
verified that the implementation works. I will fix the RAG test afterwards.

```bash
pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B
```
Commit 1a7490470a (parent 78a481bb22) by Ashwin Bharambe, 2025-01-22 10:04:16 -08:00, committed by GitHub.
33 changed files with 1648 additions and 1345 deletions.


```diff
@@ -60,9 +60,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         ]
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
-        script = args["code"]
+        script = kwargs["code"]
         req = CodeExecutionRequest(scripts=[script])
         res = self.code_executor.execute(req)
         pieces = [res["process_status"]]
```


```diff
@@ -13,8 +13,6 @@ from .memory import MemoryToolRuntimeImpl
 async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
-    impl = MemoryToolRuntimeImpl(
-        config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
-    )
+    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
     await impl.initialize()
     return impl
```


```diff
@@ -4,87 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from enum import Enum
-from typing import Annotated, List, Literal, Union
-from pydantic import BaseModel, Field
-class _MemoryBankConfigCommon(BaseModel):
-    bank_id: str
-class VectorMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["vector"] = "vector"
-class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyvalue"] = "keyvalue"
-    keys: List[str]  # what keys to focus on
-class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyword"] = "keyword"
-class GraphMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["graph"] = "graph"
-    entities: List[str]  # what entities to focus on
-MemoryBankConfig = Annotated[
-    Union[
-        VectorMemoryBankConfig,
-        KeyValueMemoryBankConfig,
-        KeywordMemoryBankConfig,
-        GraphMemoryBankConfig,
-    ],
-    Field(discriminator="type"),
-]
-class MemoryQueryGenerator(Enum):
-    default = "default"
-    llm = "llm"
-    custom = "custom"
-class DefaultMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.default.value] = (
-        MemoryQueryGenerator.default.value
-    )
-    sep: str = " "
-class LLMMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
-    model: str
-    template: str
-class CustomMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
-MemoryQueryGeneratorConfig = Annotated[
-    Union[
-        DefaultMemoryQueryGeneratorConfig,
-        LLMMemoryQueryGeneratorConfig,
-        CustomMemoryQueryGeneratorConfig,
-    ],
-    Field(discriminator="type"),
-]
-class MemoryToolConfig(BaseModel):
-    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+from pydantic import BaseModel
 class MemoryToolRuntimeConfig(BaseModel):
-    # This config defines how a query is generated using the messages
-    # for memory bank retrieval.
-    query_generator_config: MemoryQueryGeneratorConfig = Field(
-        default=DefaultMemoryQueryGeneratorConfig()
-    )
-    max_tokens_in_context: int = 4096
-    max_chunks: int = 5
+    pass
```

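The hunk above empties the provider-level config: the memory-bank plumbing is deleted, and the retrieval knobs (query generator, `max_chunks`, `max_tokens_in_context`) move to a per-request query config. A minimal sketch of what that request-scoped config plausibly looks like, inferred from how `query_config` and `DefaultRAGQueryGeneratorConfig` are used in the hunks below; the field names match the diff, but the defaults and exact class layout are assumptions.

```python
# Sketch of the per-request query config implied by the code below
# (query_config.query_generator_config, query_config.max_chunks,
# query_config.max_tokens_in_context). Defaults here are assumptions.
from pydantic import BaseModel, Field


class DefaultRAGQueryGeneratorConfig(BaseModel):
    type: str = "default"
    separator: str = " "


class RAGQueryConfig(BaseModel):
    query_generator_config: DefaultRAGQueryGeneratorConfig = Field(
        default_factory=DefaultRAGQueryGeneratorConfig
    )
    max_tokens_in_context: int = 4096
    max_chunks: int = 5
```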

```diff
@@ -5,68 +5,64 @@
 # the root directory of this source tree.
 from typing import List
 from jinja2 import Template
-from pydantic import BaseModel
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import UserMessage
+from llama_stack.apis.tools.rag_tool import (
+    DefaultRAGQueryGeneratorConfig,
+    LLMRAGQueryGeneratorConfig,
+    RAGQueryGenerator,
+    RAGQueryGeneratorConfig,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
-from .config import (
-    DefaultMemoryQueryGeneratorConfig,
-    LLMMemoryQueryGeneratorConfig,
-    MemoryQueryGenerator,
-    MemoryQueryGeneratorConfig,
-)
 async def generate_rag_query(
-    config: MemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: RAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
     """
     Generates a query that will be used for
     retrieving relevant information from the memory bank.
     """
-    if config.type == MemoryQueryGenerator.default.value:
-        query = await default_rag_query_generator(config, messages, **kwargs)
-    elif config.type == MemoryQueryGenerator.llm.value:
-        query = await llm_rag_query_generator(config, messages, **kwargs)
+    if config.type == RAGQueryGenerator.default.value:
+        query = await default_rag_query_generator(config, content, **kwargs)
+    elif config.type == RAGQueryGenerator.llm.value:
+        query = await llm_rag_query_generator(config, content, **kwargs)
     else:
         raise NotImplementedError(f"Unsupported memory query generator {config.type}")
     return query
 async def default_rag_query_generator(
-    config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: DefaultRAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
-    return config.sep.join(interleaved_content_as_str(m) for m in messages)
+    return interleaved_content_as_str(content, sep=config.separator)
 async def llm_rag_query_generator(
-    config: LLMMemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: LLMRAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
     assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
     inference_api = kwargs["inference_api"]
-    m_dict = {
-        "messages": [
-            message.model_dump() if isinstance(message, BaseModel) else message
-            for message in messages
-        ]
-    }
+    messages = []
+    if isinstance(content, list):
+        messages = [interleaved_content_as_str(m) for m in content]
+    else:
+        messages = [interleaved_content_as_str(content)]
     template = Template(config.template)
-    content = template.render(m_dict)
+    content = template.render({"messages": messages})
     model = config.model
     message = UserMessage(content=content)
```

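For the LLM-based generator above, `config.template` is a Jinja2 template rendered against the flattened message strings. A toy example of that rendering step; the template text here is made up, and only the `render({"messages": ...})` call mirrors the code above.

```python
# Toy illustration of the Template(config.template).render({"messages": messages})
# step above; the template string is invented for the example.
from jinja2 import Template

messages = ["What is the capital of France?", "And its population?"]
template = Template(
    "Rewrite the conversation below as a single search query:\n"
    "{% for m in messages %}- {{ m }}\n{% endfor %}"
)
print(template.render({"messages": messages}))
# Rewrite the conversation below as a single search query:
# - What is the capital of France?
# - And its population?
```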

```diff
@@ -10,20 +10,29 @@ import secrets
 import string
 from typing import Any, Dict, List, Optional
-from llama_stack.apis.common.content_types import URL
-from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.memory import Memory, QueryDocumentsResponse
-from llama_stack.apis.memory_banks import MemoryBanks
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    TextContentItem,
+    URL,
+)
+from llama_stack.apis.inference import Inference
 from llama_stack.apis.tools import (
+    RAGDocument,
+    RAGQueryConfig,
+    RAGQueryResult,
+    RAGToolRuntime,
     ToolDef,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
+from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
-from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
+from llama_stack.providers.utils.memory.vector_store import (
+    content_from_doc,
+    make_overlapped_chunks,
+)
-from .config import MemoryToolConfig, MemoryToolRuntimeConfig
+from .config import MemoryToolRuntimeConfig
 from .context_retriever import generate_rag_query
 log = logging.getLogger(__name__)
@@ -35,65 +44,79 @@ def make_random_string(length: int = 8):
     )
-class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
+class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     def __init__(
         self,
         config: MemoryToolRuntimeConfig,
-        memory_api: Memory,
-        memory_banks_api: MemoryBanks,
+        vector_io_api: VectorIO,
         inference_api: Inference,
     ):
         self.config = config
-        self.memory_api = memory_api
-        self.memory_banks_api = memory_banks_api
+        self.vector_io_api = vector_io_api
         self.inference_api = inference_api
     async def initialize(self):
         pass
-    async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
-    ) -> List[ToolDef]:
-        return [
-            ToolDef(
-                name="query_memory",
-                description="Retrieve context from memory",
-                parameters=[
-                    ToolParameter(
-                        name="messages",
-                        description="The input messages to search for",
-                        parameter_type="array",
-                    ),
-                ],
-            )
-        ]
     async def shutdown(self):
         pass
+    async def insert_documents(
+        self,
+        documents: List[RAGDocument],
+        vector_db_id: str,
+        chunk_size_in_tokens: int = 512,
+    ) -> None:
+        chunks = []
+        for doc in documents:
+            content = await content_from_doc(doc)
+            chunks.extend(
+                make_overlapped_chunks(
+                    doc.document_id,
+                    content,
+                    chunk_size_in_tokens,
+                    chunk_size_in_tokens // 4,
+                )
+            )
+        if not chunks:
+            return
+        await self.vector_io_api.insert_chunks(
+            chunks=chunks,
+            vector_db_id=vector_db_id,
+        )
+    async def query_context(
+        self,
+        content: InterleavedContent,
+        query_config: RAGQueryConfig,
+        vector_db_ids: List[str],
+    ) -> RAGQueryResult:
+        if not vector_db_ids:
+            return RAGQueryResult(content=None)
-    async def _retrieve_context(
-        self, input_messages: List[InterleavedContent], bank_ids: List[str]
-    ) -> Optional[List[InterleavedContent]]:
-        if not bank_ids:
-            return None
         query = await generate_rag_query(
-            self.config.query_generator_config,
-            input_messages,
+            query_config.query_generator_config,
+            content,
             inference_api=self.inference_api,
         )
         tasks = [
-            self.memory_api.query_documents(
-                bank_id=bank_id,
+            self.vector_io_api.query_chunks(
+                vector_db_id=vector_db_id,
                 query=query,
                 params={
-                    "max_chunks": self.config.max_chunks,
+                    "max_chunks": query_config.max_chunks,
                 },
             )
-            for bank_id in bank_ids
+            for vector_db_id in vector_db_ids
         ]
-        results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
+        results: List[QueryChunksResponse] = await asyncio.gather(*tasks)
         chunks = [c for r in results for c in r.chunks]
         scores = [s for r in results for s in r.scores]
         if not chunks:
-            return None
+            return RAGQueryResult(content=None)
         # sort by score
         chunks, scores = zip(
@@ -102,45 +125,52 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         tokens = 0
         picked = []
-        for c in chunks[: self.config.max_chunks]:
-            tokens += c.token_count
-            if tokens > self.config.max_tokens_in_context:
+        for c in chunks[: query_config.max_chunks]:
+            metadata = c.metadata
+            tokens += metadata["token_count"]
+            if tokens > query_config.max_tokens_in_context:
                 log.error(
                     f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                 )
                 break
-            picked.append(f"id:{c.document_id}; content:{c.content}")
+            picked.append(
+                TextContentItem(
+                    text=f"id:{metadata['document_id']}; content:{c.content}",
+                )
+            )
+        return RAGQueryResult(
+            content=[
+                TextContentItem(
+                    text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
+                ),
+                *picked,
+                TextContentItem(
+                    text="\n=== END-RETRIEVED-CONTEXT ===\n",
+                ),
+            ],
+        )
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        # Parameters are not listed since these methods are not yet invoked automatically
+        # by the LLM. The method is only implemented so things like /tools can list without
+        # encountering fatals.
         return [
-            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-            *picked,
-            "\n=== END-RETRIEVED-CONTEXT ===\n",
+            ToolDef(
+                name="rag_tool.query_context",
+                description="Retrieve context from memory",
+            ),
+            ToolDef(
+                name="rag_tool.insert_documents",
+                description="Insert documents into memory",
+            ),
         ]
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
-        tool = await self.tool_store.get_tool(tool_name)
-        tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
-        final_args = tool_group.args or {}
-        final_args.update(args)
-        config = MemoryToolConfig()
-        if tool.metadata and tool.metadata.get("config") is not None:
-            config = MemoryToolConfig(**tool.metadata["config"])
-        if "memory_bank_ids" in final_args:
-            bank_ids = final_args["memory_bank_ids"]
-        else:
-            bank_ids = [
-                bank_config.bank_id for bank_config in config.memory_bank_configs
-            ]
-        if "messages" not in final_args:
-            raise ValueError("messages are required")
-        context = await self._retrieve_context(
-            final_args["messages"],
-            bank_ids,
-        )
-        if context is None:
-            context = []
-        return ToolInvocationResult(
-            content=concat_interleaved_content(context), error_code=0
+        raise RuntimeError(
+            "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
         )
```
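
One detail worth noting from `insert_documents` above: chunks are built with `make_overlapped_chunks(doc.document_id, content, chunk_size_in_tokens, chunk_size_in_tokens // 4)`, where the last argument is presumably the overlap, so 512-token chunks share roughly 128 tokens with their neighbors. A rough stand-in for that windowing arithmetic (not the library's actual implementation):

```python
# Illustration of the overlap arithmetic only; not the real make_overlapped_chunks.
from typing import List


def overlapped_windows(tokens: List[str], window: int = 512) -> List[List[str]]:
    overlap = window // 4      # 128 tokens for the default 512-token window
    stride = window - overlap  # each chunk starts 384 tokens after the previous one
    return [tokens[i : i + window] for i in range(0, len(tokens), stride)]
```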