forked from phoenix-oss/llama-stack-mirror
[memory refactor][3/n] Introduce RAGToolRuntime as a specialized sub-protocol (#832)
See https://github.com/meta-llama/llama-stack/issues/827 for the broader design. Third part: - we need to make `tool_runtime.rag_tool.query_context()` and `tool_runtime.rag_tool.insert_documents()` methods work smoothly with complete type safety. To that end, we introduce a sub-resource path `tool-runtime/rag-tool/` and make changes to the resolver to make things work. - the PR updates the agents implementation to directly call these typed APIs for memory accesses rather than going through the complex, untyped "invoke_tool" API. the code looks much nicer and simpler (expectedly.) - there are a number of hacks in the server resolver implementation still, we will live with some and fix some Note that we must make sure the client SDKs are able to handle this subresource complexity also. Stainless has support for subresources, so this should be possible but beware. ## Test Plan Our RAG test is sad (doesn't actually test for actual RAG output) but I verified that the implementation works. I will work on fixing the RAG test afterwards. ```bash pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B ```
This commit is contained in:
parent
78a481bb22
commit
1a7490470a
33 changed files with 1648 additions and 1345 deletions
|
@ -60,9 +60,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
self, tool_name: str, kwargs: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
script = args["code"]
|
||||
script = kwargs["code"]
|
||||
req = CodeExecutionRequest(scripts=[script])
|
||||
res = self.code_executor.execute(req)
|
||||
pieces = [res["process_status"]]
|
||||
|
|
|
@ -13,8 +13,6 @@ from .memory import MemoryToolRuntimeImpl
|
|||
|
||||
|
||||
async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
|
||||
impl = MemoryToolRuntimeImpl(
|
||||
config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
|
||||
)
|
||||
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -4,87 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Annotated, List, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class _MemoryBankConfigCommon(BaseModel):
|
||||
bank_id: str
|
||||
|
||||
|
||||
class VectorMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["vector"] = "vector"
|
||||
|
||||
|
||||
class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["keyvalue"] = "keyvalue"
|
||||
keys: List[str] # what keys to focus on
|
||||
|
||||
|
||||
class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["keyword"] = "keyword"
|
||||
|
||||
|
||||
class GraphMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["graph"] = "graph"
|
||||
entities: List[str] # what entities to focus on
|
||||
|
||||
|
||||
MemoryBankConfig = Annotated[
|
||||
Union[
|
||||
VectorMemoryBankConfig,
|
||||
KeyValueMemoryBankConfig,
|
||||
KeywordMemoryBankConfig,
|
||||
GraphMemoryBankConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class MemoryQueryGenerator(Enum):
|
||||
default = "default"
|
||||
llm = "llm"
|
||||
custom = "custom"
|
||||
|
||||
|
||||
class DefaultMemoryQueryGeneratorConfig(BaseModel):
|
||||
type: Literal[MemoryQueryGenerator.default.value] = (
|
||||
MemoryQueryGenerator.default.value
|
||||
)
|
||||
sep: str = " "
|
||||
|
||||
|
||||
class LLMMemoryQueryGeneratorConfig(BaseModel):
|
||||
type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
|
||||
model: str
|
||||
template: str
|
||||
|
||||
|
||||
class CustomMemoryQueryGeneratorConfig(BaseModel):
|
||||
type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
|
||||
|
||||
|
||||
MemoryQueryGeneratorConfig = Annotated[
|
||||
Union[
|
||||
DefaultMemoryQueryGeneratorConfig,
|
||||
LLMMemoryQueryGeneratorConfig,
|
||||
CustomMemoryQueryGeneratorConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class MemoryToolConfig(BaseModel):
|
||||
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MemoryToolRuntimeConfig(BaseModel):
|
||||
# This config defines how a query is generated using the messages
|
||||
# for memory bank retrieval.
|
||||
query_generator_config: MemoryQueryGeneratorConfig = Field(
|
||||
default=DefaultMemoryQueryGeneratorConfig()
|
||||
)
|
||||
max_tokens_in_context: int = 4096
|
||||
max_chunks: int = 5
|
||||
pass
|
||||
|
|
|
@ -5,68 +5,64 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
from jinja2 import Template
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import UserMessage
|
||||
|
||||
from llama_stack.apis.tools.rag_tool import (
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
RAGQueryGenerator,
|
||||
RAGQueryGeneratorConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .config import (
|
||||
DefaultMemoryQueryGeneratorConfig,
|
||||
LLMMemoryQueryGeneratorConfig,
|
||||
MemoryQueryGenerator,
|
||||
MemoryQueryGeneratorConfig,
|
||||
)
|
||||
|
||||
|
||||
async def generate_rag_query(
|
||||
config: MemoryQueryGeneratorConfig,
|
||||
messages: List[InterleavedContent],
|
||||
config: RAGQueryGeneratorConfig,
|
||||
content: InterleavedContent,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Generates a query that will be used for
|
||||
retrieving relevant information from the memory bank.
|
||||
"""
|
||||
if config.type == MemoryQueryGenerator.default.value:
|
||||
query = await default_rag_query_generator(config, messages, **kwargs)
|
||||
elif config.type == MemoryQueryGenerator.llm.value:
|
||||
query = await llm_rag_query_generator(config, messages, **kwargs)
|
||||
if config.type == RAGQueryGenerator.default.value:
|
||||
query = await default_rag_query_generator(config, content, **kwargs)
|
||||
elif config.type == RAGQueryGenerator.llm.value:
|
||||
query = await llm_rag_query_generator(config, content, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
|
||||
return query
|
||||
|
||||
|
||||
async def default_rag_query_generator(
|
||||
config: DefaultMemoryQueryGeneratorConfig,
|
||||
messages: List[InterleavedContent],
|
||||
config: DefaultRAGQueryGeneratorConfig,
|
||||
content: InterleavedContent,
|
||||
**kwargs,
|
||||
):
|
||||
return config.sep.join(interleaved_content_as_str(m) for m in messages)
|
||||
return interleaved_content_as_str(content, sep=config.separator)
|
||||
|
||||
|
||||
async def llm_rag_query_generator(
|
||||
config: LLMMemoryQueryGeneratorConfig,
|
||||
messages: List[InterleavedContent],
|
||||
config: LLMRAGQueryGeneratorConfig,
|
||||
content: InterleavedContent,
|
||||
**kwargs,
|
||||
):
|
||||
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
|
||||
inference_api = kwargs["inference_api"]
|
||||
|
||||
m_dict = {
|
||||
"messages": [
|
||||
message.model_dump() if isinstance(message, BaseModel) else message
|
||||
for message in messages
|
||||
]
|
||||
}
|
||||
messages = []
|
||||
if isinstance(content, list):
|
||||
messages = [interleaved_content_as_str(m) for m in content]
|
||||
else:
|
||||
messages = [interleaved_content_as_str(content)]
|
||||
|
||||
template = Template(config.template)
|
||||
content = template.render(m_dict)
|
||||
content = template.render({"messages": messages})
|
||||
|
||||
model = config.model
|
||||
message = UserMessage(content=content)
|
||||
|
|
|
@ -10,20 +10,29 @@ import secrets
|
|||
import string
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
TextContentItem,
|
||||
URL,
|
||||
)
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.tools import (
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
RAGQueryResult,
|
||||
RAGToolRuntime,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
content_from_doc,
|
||||
make_overlapped_chunks,
|
||||
)
|
||||
|
||||
from .config import MemoryToolConfig, MemoryToolRuntimeConfig
|
||||
from .config import MemoryToolRuntimeConfig
|
||||
from .context_retriever import generate_rag_query
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -35,65 +44,79 @@ def make_random_string(length: int = 8):
|
|||
)
|
||||
|
||||
|
||||
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
config: MemoryToolRuntimeConfig,
|
||||
memory_api: Memory,
|
||||
memory_banks_api: MemoryBanks,
|
||||
vector_io_api: VectorIO,
|
||||
inference_api: Inference,
|
||||
):
|
||||
self.config = config
|
||||
self.memory_api = memory_api
|
||||
self.memory_banks_api = memory_banks_api
|
||||
self.vector_io_api = vector_io_api
|
||||
self.inference_api = inference_api
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="query_memory",
|
||||
description="Retrieve context from memory",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="messages",
|
||||
description="The input messages to search for",
|
||||
parameter_type="array",
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
documents: List[RAGDocument],
|
||||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
chunks = []
|
||||
for doc in documents:
|
||||
content = await content_from_doc(doc)
|
||||
chunks.extend(
|
||||
make_overlapped_chunks(
|
||||
doc.document_id,
|
||||
content,
|
||||
chunk_size_in_tokens,
|
||||
chunk_size_in_tokens // 4,
|
||||
)
|
||||
)
|
||||
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
await self.vector_io_api.insert_chunks(
|
||||
chunks=chunks,
|
||||
vector_db_id=vector_db_id,
|
||||
)
|
||||
|
||||
async def query_context(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
query_config: RAGQueryConfig,
|
||||
vector_db_ids: List[str],
|
||||
) -> RAGQueryResult:
|
||||
if not vector_db_ids:
|
||||
return RAGQueryResult(content=None)
|
||||
|
||||
async def _retrieve_context(
|
||||
self, input_messages: List[InterleavedContent], bank_ids: List[str]
|
||||
) -> Optional[List[InterleavedContent]]:
|
||||
if not bank_ids:
|
||||
return None
|
||||
query = await generate_rag_query(
|
||||
self.config.query_generator_config,
|
||||
input_messages,
|
||||
query_config.query_generator_config,
|
||||
content,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
tasks = [
|
||||
self.memory_api.query_documents(
|
||||
bank_id=bank_id,
|
||||
self.vector_io_api.query_chunks(
|
||||
vector_db_id=vector_db_id,
|
||||
query=query,
|
||||
params={
|
||||
"max_chunks": self.config.max_chunks,
|
||||
"max_chunks": query_config.max_chunks,
|
||||
},
|
||||
)
|
||||
for bank_id in bank_ids
|
||||
for vector_db_id in vector_db_ids
|
||||
]
|
||||
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
|
||||
results: List[QueryChunksResponse] = await asyncio.gather(*tasks)
|
||||
chunks = [c for r in results for c in r.chunks]
|
||||
scores = [s for r in results for s in r.scores]
|
||||
|
||||
if not chunks:
|
||||
return None
|
||||
return RAGQueryResult(content=None)
|
||||
|
||||
# sort by score
|
||||
chunks, scores = zip(
|
||||
|
@ -102,45 +125,52 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
|
||||
tokens = 0
|
||||
picked = []
|
||||
for c in chunks[: self.config.max_chunks]:
|
||||
tokens += c.token_count
|
||||
if tokens > self.config.max_tokens_in_context:
|
||||
for c in chunks[: query_config.max_chunks]:
|
||||
metadata = c.metadata
|
||||
tokens += metadata["token_count"]
|
||||
if tokens > query_config.max_tokens_in_context:
|
||||
log.error(
|
||||
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
|
||||
)
|
||||
break
|
||||
picked.append(f"id:{c.document_id}; content:{c.content}")
|
||||
picked.append(
|
||||
TextContentItem(
|
||||
text=f"id:{metadata['document_id']}; content:{c.content}",
|
||||
)
|
||||
)
|
||||
|
||||
return RAGQueryResult(
|
||||
content=[
|
||||
TextContentItem(
|
||||
text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
|
||||
),
|
||||
*picked,
|
||||
TextContentItem(
|
||||
text="\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
# Parameters are not listed since these methods are not yet invoked automatically
|
||||
# by the LLM. The method is only implemented so things like /tools can list without
|
||||
# encountering fatals.
|
||||
return [
|
||||
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
|
||||
*picked,
|
||||
"\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
ToolDef(
|
||||
name="rag_tool.query_context",
|
||||
description="Retrieve context from memory",
|
||||
),
|
||||
ToolDef(
|
||||
name="rag_tool.insert_documents",
|
||||
description="Insert documents into memory",
|
||||
),
|
||||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
self, tool_name: str, kwargs: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
|
||||
final_args = tool_group.args or {}
|
||||
final_args.update(args)
|
||||
config = MemoryToolConfig()
|
||||
if tool.metadata and tool.metadata.get("config") is not None:
|
||||
config = MemoryToolConfig(**tool.metadata["config"])
|
||||
if "memory_bank_ids" in final_args:
|
||||
bank_ids = final_args["memory_bank_ids"]
|
||||
else:
|
||||
bank_ids = [
|
||||
bank_config.bank_id for bank_config in config.memory_bank_configs
|
||||
]
|
||||
if "messages" not in final_args:
|
||||
raise ValueError("messages are required")
|
||||
context = await self._retrieve_context(
|
||||
final_args["messages"],
|
||||
bank_ids,
|
||||
)
|
||||
if context is None:
|
||||
context = []
|
||||
return ToolInvocationResult(
|
||||
content=concat_interleaved_content(context), error_code=0
|
||||
raise RuntimeError(
|
||||
"This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue