Introduce RAGToolRuntime as a specialized sub-protocol

Ashwin Bharambe 2025-01-21 12:13:44 -08:00
parent 78a481bb22
commit 2f76de1643
16 changed files with 260 additions and 224 deletions
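At a glance, the change moves RAG from the generic invoke_tool entry point to a typed rag_tool sub-protocol hanging off ToolRuntime. A minimal sketch of the before/after call shapes (the tool_runtime handle and the argument values are illustrative, not code from this commit):

    # Before: RAG went through the untyped tool-invocation path.
    result = await tool_runtime.invoke_tool(
        tool_name="query_memory",
        args={"messages": [...], "memory_bank_ids": ["bank-1"]},
    )

    # After: RAG is a first-class sub-protocol with typed methods.
    result = await tool_runtime.rag_tool.query_context(
        content="What does the document say about chunking?",
        query_config=RAGQueryConfig(),
        vector_db_ids=["bank-1"],
    )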

View file

@@ -5,3 +5,4 @@
 # the root directory of this source tree.
 from .tools import * # noqa: F401 F403
+from .rag_tool import * # noqa: F401 F403

View file

@@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union

from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated, Protocol, runtime_checkable

from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol


@json_schema_type
class RAGDocument(BaseModel):
    document_id: str
    content: InterleavedContent | URL
    mime_type: str | None = None
    metadata: Dict[str, Any] = Field(default_factory=dict)


@json_schema_type
class RAGQueryResult(BaseModel):
    content: Optional[InterleavedContent] = None


@json_schema_type
class RAGQueryGenerator(Enum):
    default = "default"
    llm = "llm"
    custom = "custom"


@json_schema_type
class DefaultRAGQueryGeneratorConfig(BaseModel):
    type: Literal["default"] = "default"
    separator: str = " "


@json_schema_type
class LLMRAGQueryGeneratorConfig(BaseModel):
    type: Literal["llm"] = "llm"
    model: str
    template: str


RAGQueryGeneratorConfig = register_schema(
    Annotated[
        Union[
            DefaultRAGQueryGeneratorConfig,
            LLMRAGQueryGeneratorConfig,
        ],
        Field(discriminator="type"),
    ],
    name="RAGQueryGeneratorConfig",
)


@json_schema_type
class RAGQueryConfig(BaseModel):
    # This config defines how a query is generated using the messages
    # for memory bank retrieval.
    query_generator_config: RAGQueryGeneratorConfig = Field(
        default=DefaultRAGQueryGeneratorConfig()
    )
    max_tokens_in_context: int = 4096
    max_chunks: int = 5


@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
    @webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
    async def insert_documents(
        self,
        documents: List[RAGDocument],
        vector_db_ids: List[str],
        chunk_size_in_tokens: int = 512,
    ) -> None:
        """Index documents so they can be used by the RAG system"""
        ...

    @webmethod(route="/tool-runtime/rag-tool/query", method="POST")
    async def query_context(
        self,
        content: InterleavedContent,
        query_config: RAGQueryConfig,
        vector_db_ids: List[str],
    ) -> RAGQueryResult:
        """Query the RAG system for context; typically invoked by the agent"""
        ...
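Taken together, the models above define the full request/response surface of the sub-protocol. A short usage sketch, assuming a caller already holds a ToolRuntime implementation that exposes rag_tool (the variable names and ids are illustrative):

    # Index a document into one or more vector databases...
    await tool_runtime.rag_tool.insert_documents(
        documents=[
            RAGDocument(
                document_id="getting-started",
                content="Llama Stack separates APIs from provider implementations.",
                mime_type="text/plain",
            )
        ],
        vector_db_ids=["docs-db"],
        chunk_size_in_tokens=512,
    )

    # ...then retrieve context for a query; the default query generator simply
    # concatenates the input content with the configured separator.
    result = await tool_runtime.rag_tool.query_context(
        content="How are providers wired to APIs?",
        query_config=RAGQueryConfig(
            query_generator_config=DefaultRAGQueryGeneratorConfig(separator=" "),
            max_tokens_in_context=4096,
            max_chunks=5,
        ),
        vector_db_ids=["docs-db"],
    )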

View file

@@ -15,6 +15,8 @@ from llama_stack.apis.common.content_types import InterleavedContent, URL
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+from .rag_tool import RAGToolRuntime

 @json_schema_type
 class ToolParameter(BaseModel):
@@ -130,11 +132,17 @@ class ToolGroups(Protocol):
     ...

+class SpecialToolGroups(Enum):
+    rag_tool = "rag_tool"
+
 @runtime_checkable
 @trace_protocol
 class ToolRuntime(Protocol):
     tool_store: ToolStore
+    rag_tool: RAGToolRuntime

     # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
     @webmethod(route="/tool-runtime/list-tools", method="GET")
     async def list_runtime_tools(
@@ -143,7 +151,7 @@ class ToolRuntime(Protocol):
     @webmethod(route="/tool-runtime/invoke", method="POST")
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         """Run a tool with the given arguments"""
         ...

View file

@@ -333,6 +333,8 @@ async def instantiate_provider(
     impl.__provider_spec__ = provider_spec
     impl.__provider_config__ = config

+    # TODO: check compliance for special tool groups
+    # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
     check_protocol_compliance(impl, protocols[provider_spec.api])

     if (
         not isinstance(provider_spec, AutoRoutedProviderSpec)

View file

@@ -406,6 +406,12 @@ class ToolRuntimeRouter(ToolRuntime):
     ) -> None:
         self.routing_table = routing_table

+        # TODO: this should be in sync with "get_all_api_endpoints()"
+        # TODO: make sure rag_tool vs builtin::memory is correct everywhere
+        self.rag_tool = self.routing_table.get_provider_impl("builtin::memory")
+        setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
+        setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)
+
     async def initialize(self) -> None:
         pass
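The router binds rag_tool to whichever provider backs the builtin::memory group, and the setattr calls publish the sub-protocol methods under their dotted endpoint names. A minimal sketch of the dispatch this enables (the lookup shown is an assumption about how the server resolves endpoint names, not code from this commit):

    # "rag_tool.query_context" is a plain attribute name here, so the server can
    # resolve the dotted endpoint string straight to the bound coroutine.
    handler = getattr(tool_runtime_router, "rag_tool.query_context")
    result = await handler(
        content="example query",
        query_config=RAGQueryConfig(),
        vector_db_ids=["docs-db"],
    )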

View file

@@ -9,6 +9,8 @@ from typing import Dict, List
 from pydantic import BaseModel

+from llama_stack.apis.tools import SpecialToolGroups
+
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION
 from llama_stack.distribution.resolver import api_protocol_map
@@ -29,6 +31,15 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
     for api, protocol in protocols.items():
         endpoints = []
         protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)

+        if api == Api.tool_runtime:
+            for tool_group in SpecialToolGroups:
+                print(f"tool_group: {tool_group}")
+                sub_protocol = getattr(protocol, tool_group.value)
+                sub_protocol_methods = inspect.getmembers(
+                    sub_protocol, predicate=inspect.isfunction
+                )
+                for name, method in sub_protocol_methods:
+                    protocol_methods.append((f"{tool_group.value}.{name}", method))
+
         for name, method in protocol_methods:
             if not hasattr(method, "__webmethod__"):
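For Api.tool_runtime, the extra loop folds the sub-protocol's webmethods into the same endpoint list under dotted names. Roughly, the resulting entries look like the listing below (routes come from the @webmethod decorators in rag_tool.py; the exact ApiEndpoint shape is not shown here):

    # list_runtime_tools        -> GET  /tool-runtime/list-tools
    # invoke_tool               -> POST /tool-runtime/invoke
    # rag_tool.insert_documents -> POST /tool-runtime/rag-tool/insert
    # rag_tool.query_context    -> POST /tool-runtime/rag-tool/query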

View file

@@ -19,9 +19,8 @@ async def get_provider_impl(
     impl = MetaReferenceAgentsImpl(
         config,
         deps[Api.inference],
-        deps[Api.memory],
+        deps[Api.vector_io],
         deps[Api.safety],
-        deps[Api.memory_banks],
         deps[Api.tool_runtime],
         deps[Api.tool_groups],
     )

View file

@@ -59,13 +59,18 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
-from llama_stack.apis.memory import Memory, MemoryBankDocument
-from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.tools import (
+    DefaultRAGQueryGeneratorConfig,
+    RAGDocument,
+    RAGQueryConfig,
+    ToolGroups,
+    ToolRuntime,
+)
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing

 from .persistence import AgentPersistence
 from .safety import SafetyException, ShieldRunnerMixin
@@ -79,7 +84,7 @@ def make_random_string(length: int = 8):
 TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
-MEMORY_QUERY_TOOL = "query_memory"
+MEMORY_QUERY_TOOL = "rag_tool.query_context"
 WEB_SEARCH_TOOL = "web_search"
 MEMORY_GROUP = "builtin::memory"
@@ -91,20 +96,18 @@ class ChatAgent(ShieldRunnerMixin):
         agent_config: AgentConfig,
         tempdir: str,
         inference_api: Inference,
-        memory_api: Memory,
-        memory_banks_api: MemoryBanks,
         safety_api: Safety,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
+        vector_io_api: VectorIO,
         persistence_store: KVStore,
     ):
         self.agent_id = agent_id
         self.agent_config = agent_config
         self.tempdir = tempdir
         self.inference_api = inference_api
-        self.memory_api = memory_api
-        self.memory_banks_api = memory_banks_api
         self.safety_api = safety_api
+        self.vector_io_api = vector_io_api
         self.storage = AgentPersistence(agent_id, persistence_store)
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
@@ -384,10 +387,7 @@ class ChatAgent(ShieldRunnerMixin):
         await self.handle_documents(
             session_id, documents, input_messages, tool_defs
         )
-        if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
-            memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
-            if memory_tool_group is None:
-                raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
+        if "builtin::memory" in toolgroup_args and len(input_messages) > 0:
             with tracing.span(MEMORY_QUERY_TOOL) as span:
                 step_id = str(uuid.uuid4())
                 yield AgentTurnResponseStreamChunk(
@@ -398,17 +398,15 @@ class ChatAgent(ShieldRunnerMixin):
                        )
                    )
                )
-                query_args = {
-                    "messages": [msg.content for msg in input_messages],
-                    **toolgroup_args.get(memory_tool_group, {}),
-                }
+                args = toolgroup_args["builtin::memory"]
+                vector_db_ids = args.get("vector_db_ids", [])
                 session_info = await self.storage.get_session_info(session_id)
                 # if the session has a memory bank id, let the memory tool use it
                 if session_info.memory_bank_id:
-                    if "memory_bank_ids" not in query_args:
-                        query_args["memory_bank_ids"] = []
-                    query_args["memory_bank_ids"].append(session_info.memory_bank_id)
+                    vector_db_ids.append(session_info.memory_bank_id)
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepProgressPayload(
@@ -425,9 +423,16 @@ class ChatAgent(ShieldRunnerMixin):
                        )
                    )
                )
-                result = await self.tool_runtime_api.invoke_tool(
-                    tool_name=MEMORY_QUERY_TOOL,
-                    args=query_args,
+                retrieved_context = await self.tool_runtime_api.rag_tool.query_context(
+                    content=concat_interleaved_content(
+                        [msg.content for msg in input_messages]
+                    ),
+                    query_config=RAGQueryConfig(
+                        query_generator_config=DefaultRAGQueryGeneratorConfig(),
+                        max_tokens_in_context=4096,
+                        max_chunks=5,
+                    ),
+                    vector_db_ids=vector_db_ids,
                 )

                 yield AgentTurnResponseStreamChunk(
@@ -449,7 +454,7 @@ class ChatAgent(ShieldRunnerMixin):
                                     ToolResponse(
                                         call_id="",
                                         tool_name=MEMORY_QUERY_TOOL,
-                                        content=result.content,
+                                        content=retrieved_context or [],
                                     )
                                 ],
                             ),
@@ -459,13 +464,11 @@ class ChatAgent(ShieldRunnerMixin):
                 span.set_attribute(
                     "input", [m.model_dump_json() for m in input_messages]
                 )
-                span.set_attribute("output", result.content)
-                span.set_attribute("error_code", result.error_code)
-                span.set_attribute("error_message", result.error_message)
-                span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
-                if result.error_code == 0:
+                span.set_attribute("output", retrieved_context)
+                span.set_attribute("tool_name", "builtin::memory")
+                if retrieved_context:
                     last_message = input_messages[-1]
-                    last_message.context = result.content
+                    last_message.context = retrieved_context

         output_attachments = []
@@ -842,12 +845,13 @@ class ChatAgent(ShieldRunnerMixin):
         if session_info.memory_bank_id is None:
             bank_id = f"memory_bank_{session_id}"
-            await self.memory_banks_api.register_memory_bank(
-                memory_bank_id=bank_id,
-                params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
+            # TODO: the semantic for registration is definitely not "creation"
+            # so we need to fix it if we expect the agent to create a new vector db
+            # for each session
+            await self.vector_io_api.register_vector_db(
+                vector_db_id=bank_id,
+                embedding_model="all-MiniLM-L6-v2",
             )
             await self.storage.add_memory_bank_to_session(session_id, bank_id)
         else:
@@ -860,7 +864,7 @@ class ChatAgent(ShieldRunnerMixin):
     ) -> None:
         bank_id = await self._ensure_memory_bank(session_id)
         documents = [
-            MemoryBankDocument(
+            RAGDocument(
                 document_id=str(uuid.uuid4()),
                 content=a.content,
                 mime_type=a.mime_type,
@@ -868,9 +872,10 @@ class ChatAgent(ShieldRunnerMixin):
             )
             for a in data
         ]
-        await self.memory_api.insert_documents(
-            bank_id=bank_id,
+        await self.tool_runtime_api.rag_tool.insert_documents(
             documents=documents,
+            vector_db_ids=[bank_id],
+            chunk_size_in_tokens=512,
        )
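The agent-side contract is now: per-toolgroup args supply the vector databases, and the per-session bank created on demand is appended to that list before rag_tool.query_context is called. A sketch of the toolgroup_args shape the memory branch above expects (values are illustrative):

    toolgroup_args = {
        "builtin::memory": {"vector_db_ids": ["docs-db"]},
    }
    # If the session has registered its own vector db (e.g. "memory_bank_<session_id>"),
    # that id is appended to vector_db_ids before the query is issued.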

View file

@@ -26,10 +26,9 @@ from llama_stack.apis.agents import (
     Turn,
 )
 from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl

 from .agent_instance import ChatAgent
@@ -44,17 +43,15 @@ class MetaReferenceAgentsImpl(Agents):
         self,
         config: MetaReferenceAgentsImplConfig,
         inference_api: Inference,
-        memory_api: Memory,
+        vector_io_api: VectorIO,
         safety_api: Safety,
-        memory_banks_api: MemoryBanks,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
     ):
         self.config = config
         self.inference_api = inference_api
-        self.memory_api = memory_api
+        self.vector_io_api = vector_io_api
         self.safety_api = safety_api
-        self.memory_banks_api = memory_banks_api
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
@@ -114,8 +111,7 @@ class MetaReferenceAgentsImpl(Agents):
             tempdir=self.tempdir,
             inference_api=self.inference_api,
             safety_api=self.safety_api,
-            memory_api=self.memory_api,
-            memory_banks_api=self.memory_banks_api,
+            vector_io_api=self.vector_io_api,
             tool_runtime_api=self.tool_runtime_api,
             tool_groups_api=self.tool_groups_api,
             persistence_store=(

View file

@@ -13,8 +13,6 @@ from .memory import MemoryToolRuntimeImpl

 async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
-    impl = MemoryToolRuntimeImpl(
-        config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
-    )
+    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
     await impl.initialize()
     return impl

View file

@@ -4,87 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from enum import Enum
-from typing import Annotated, List, Literal, Union
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
-
-class _MemoryBankConfigCommon(BaseModel):
-    bank_id: str
-
-class VectorMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["vector"] = "vector"
-
-class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyvalue"] = "keyvalue"
-    keys: List[str]  # what keys to focus on
-
-class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyword"] = "keyword"
-
-class GraphMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["graph"] = "graph"
-    entities: List[str]  # what entities to focus on
-
-MemoryBankConfig = Annotated[
-    Union[
-        VectorMemoryBankConfig,
-        KeyValueMemoryBankConfig,
-        KeywordMemoryBankConfig,
-        GraphMemoryBankConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-class MemoryQueryGenerator(Enum):
-    default = "default"
-    llm = "llm"
-    custom = "custom"
-
-class DefaultMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.default.value] = (
-        MemoryQueryGenerator.default.value
-    )
-    sep: str = " "
-
-class LLMMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
-    model: str
-    template: str
-
-class CustomMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
-
-MemoryQueryGeneratorConfig = Annotated[
-    Union[
-        DefaultMemoryQueryGeneratorConfig,
-        LLMMemoryQueryGeneratorConfig,
-        CustomMemoryQueryGeneratorConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-class MemoryToolConfig(BaseModel):
-    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)

 class MemoryToolRuntimeConfig(BaseModel):
-    # This config defines how a query is generated using the messages
-    # for memory bank retrieval.
-    query_generator_config: MemoryQueryGeneratorConfig = Field(
-        default=DefaultMemoryQueryGeneratorConfig()
-    )
-    max_tokens_in_context: int = 4096
-    max_chunks: int = 5
+    pass
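With the provider config reduced to an empty model, the knobs that used to live here (query generator, max_chunks, max_tokens_in_context) now travel with each request as a RAGQueryConfig. For example, a caller could tune them per query like this (the model name and template text are placeholders, not values from this commit):

    query_config = RAGQueryConfig(
        query_generator_config=LLMRAGQueryGeneratorConfig(
            model="example-query-rewriter-model",
            template="Rewrite the conversation into a single search query.",
        ),
        max_tokens_in_context=2048,
        max_chunks=3,
    )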

View file

@@ -47,7 +47,7 @@ async def default_rag_query_generator(
     messages: List[InterleavedContent],
     **kwargs,
 ):
-    return config.sep.join(interleaved_content_as_str(m) for m in messages)
+    return config.separator.join(interleaved_content_as_str(m) for m in messages)

 async def llm_rag_query_generator(

View file

@@ -10,20 +10,25 @@ import secrets
 import string
 from typing import Any, Dict, List, Optional

-from llama_stack.apis.common.content_types import URL
-from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.memory import Memory, QueryDocumentsResponse
-from llama_stack.apis.memory_banks import MemoryBanks
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    TextContentItem,
+    URL,
+)
+from llama_stack.apis.inference import Inference
 from llama_stack.apis.tools import (
+    RAGDocument,
+    RAGQueryConfig,
+    RAGQueryResult,
+    RAGToolRuntime,
     ToolDef,
     ToolInvocationResult,
-    ToolParameter,
     ToolRuntime,
 )
+from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
-from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content

-from .config import MemoryToolConfig, MemoryToolRuntimeConfig
+from .config import MemoryToolRuntimeConfig
 from .context_retriever import generate_rag_query

 log = logging.getLogger(__name__)
@@ -35,65 +40,61 @@ def make_random_string(length: int = 8):
     )

-class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
+class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     def __init__(
         self,
         config: MemoryToolRuntimeConfig,
-        memory_api: Memory,
-        memory_banks_api: MemoryBanks,
+        vector_io_api: VectorIO,
         inference_api: Inference,
     ):
         self.config = config
-        self.memory_api = memory_api
-        self.memory_banks_api = memory_banks_api
+        self.vector_io_api = vector_io_api
         self.inference_api = inference_api

     async def initialize(self):
         pass

-    async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
-    ) -> List[ToolDef]:
-        return [
-            ToolDef(
-                name="query_memory",
-                description="Retrieve context from memory",
-                parameters=[
-                    ToolParameter(
-                        name="messages",
-                        description="The input messages to search for",
-                        parameter_type="array",
-                    ),
-                ],
-            )
-        ]
-
-    async def _retrieve_context(
-        self, input_messages: List[InterleavedContent], bank_ids: List[str]
-    ) -> Optional[List[InterleavedContent]]:
-        if not bank_ids:
-            return None
+    async def shutdown(self):
+        pass
+
+    async def insert_documents(
+        self,
+        documents: List[RAGDocument],
+        vector_db_ids: List[str],
+        chunk_size_in_tokens: int = 512,
+    ) -> None:
+        pass
+
+    async def query_context(
+        self,
+        content: InterleavedContent,
+        query_config: RAGQueryConfig,
+        vector_db_ids: List[str],
+    ) -> RAGQueryResult:
+        if not vector_db_ids:
+            return RAGQueryResult(content=None)
+
         query = await generate_rag_query(
-            self.config.query_generator_config,
-            input_messages,
+            query_config.query_generator_config,
+            content,
             inference_api=self.inference_api,
         )
         tasks = [
-            self.memory_api.query_documents(
-                bank_id=bank_id,
+            self.vector_io_api.query_chunks(
+                vector_db_id=vector_db_id,
                 query=query,
                 params={
-                    "max_chunks": self.config.max_chunks,
+                    "max_chunks": query_config.max_chunks,
                 },
             )
-            for bank_id in bank_ids
+            for vector_db_id in vector_db_ids
         ]
-        results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
+        results: List[QueryChunksResponse] = await asyncio.gather(*tasks)
         chunks = [c for r in results for c in r.chunks]
         scores = [s for r in results for s in r.scores]

         if not chunks:
-            return None
+            return RAGQueryResult(content=None)

         # sort by score
         chunks, scores = zip(
@@ -102,45 +103,47 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         tokens = 0
         picked = []
-        for c in chunks[: self.config.max_chunks]:
+        for c in chunks[: query_config.max_chunks]:
             tokens += c.token_count
-            if tokens > self.config.max_tokens_in_context:
+            if tokens > query_config.max_tokens_in_context:
                 log.error(
                     f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                 )
                 break
             picked.append(f"id:{c.document_id}; content:{c.content}")

-        return [
-            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-            *picked,
-            "\n=== END-RETRIEVED-CONTEXT ===\n",
-        ]
+        return RAGQueryResult(
+            content=[
+                TextContentItem(
+                    text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
+                ),
+                *picked,
+                TextContentItem(
+                    text="\n=== END-RETRIEVED-CONTEXT ===\n",
+                ),
+            ],
+        )
+
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        # Parameters are not listed since these methods are not yet invoked automatically
+        # by the LLM. The method is only implemented so things like /tools can list without
+        # encountering fatals.
+        return [
+            ToolDef(
+                name="rag_tool.query_context",
+                description="Retrieve context from memory",
+            ),
+            ToolDef(
+                name="rag_tool.insert_documents",
+                description="Insert documents into memory",
+            ),
+        ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
-        tool = await self.tool_store.get_tool(tool_name)
-        tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
-        final_args = tool_group.args or {}
-        final_args.update(args)
-        config = MemoryToolConfig()
-        if tool.metadata and tool.metadata.get("config") is not None:
-            config = MemoryToolConfig(**tool.metadata["config"])
-        if "memory_bank_ids" in final_args:
-            bank_ids = final_args["memory_bank_ids"]
-        else:
-            bank_ids = [
-                bank_config.bank_id for bank_config in config.memory_bank_configs
-            ]
-        if "messages" not in final_args:
-            raise ValueError("messages are required")
-        context = await self._retrieve_context(
-            final_args["messages"],
-            bank_ids,
-        )
-        if context is None:
-            context = []
-        return ToolInvocationResult(
-            content=concat_interleaved_content(context), error_code=0
-        )
+        raise RuntimeError(
+            "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
+        )

View file

@@ -23,7 +23,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.inline.tool_runtime.memory",
             config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
-            api_dependencies=[Api.vector_io, Api.vector_dbs, Api.inference],
+            api_dependencies=[Api.vector_io, Api.inference],
         ),
         InlineProviderSpec(
             api=Api.tool_runtime,

View file

@@ -8,13 +8,12 @@ import uuid
 import pytest

+from llama_stack.apis.tools import RAGDocument
 from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
 from llama_stack.apis.vector_io import QueryChunksResponse
-from llama_stack.providers.utils.memory.vector_store import (
-    make_overlapped_chunks,
-    MemoryBankDocument,
-)
+from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks

 # How to run this test:
 #
@@ -26,22 +25,22 @@ from llama_stack.providers.utils.memory.vector_store import (
 @pytest.fixture(scope="session")
 def sample_chunks():
     docs = [
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc1",
             content="Python is a high-level programming language.",
             metadata={"category": "programming", "difficulty": "beginner"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc2",
             content="Machine learning is a subset of artificial intelligence.",
             metadata={"category": "AI", "difficulty": "advanced"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc3",
             content="Data structures are fundamental to computer science.",
             metadata={"category": "computer science", "difficulty": "intermediate"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc4",
             content="Neural networks are inspired by biological neural networks.",
             metadata={"category": "AI", "difficulty": "advanced"},

View file

@@ -19,7 +19,6 @@ import numpy as np
 from llama_models.llama3.api.tokenizer import Tokenizer
 from numpy.typing import NDArray
-from pydantic import BaseModel, Field
 from pypdf import PdfReader

 from llama_stack.apis.common.content_types import (
@@ -27,6 +26,7 @@ from llama_stack.apis.common.content_types import (
     TextContentItem,
     URL,
 )
+from llama_stack.apis.tools import RAGDocument
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
 from llama_stack.providers.datatypes import Api
@@ -34,17 +34,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

 log = logging.getLogger(__name__)

-class MemoryBankDocument(BaseModel):
-    document_id: str
-    content: InterleavedContent | URL
-    mime_type: str | None = None
-    metadata: Dict[str, Any] = Field(default_factory=dict)

 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string
     pdf_bytes = io.BytesIO(data)
@@ -122,7 +114,7 @@ def concat_interleaved_content(content: List[InterleavedContent]) -> Interleaved
     return ret

-async def content_from_doc(doc: MemoryBankDocument) -> str:
+async def content_from_doc(doc: RAGDocument) -> str:
     if isinstance(doc.content, URL):
         if doc.content.uri.startswith("data:"):
             return content_from_data(doc.content.uri)