From 2f76de16430ecfc31203a6675dc52a0600ec467d Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Tue, 21 Jan 2025 12:13:44 -0800
Subject: [PATCH] Introduce RAGToolRuntime as a specialized sub-protocol

---
 llama_stack/apis/tools/__init__.py            |   1 +
 llama_stack/apis/tools/rag_tool.py            |  95 ++++++++++
 llama_stack/apis/tools/tools.py               |  10 +-
 llama_stack/distribution/resolver.py          |   2 +
 llama_stack/distribution/routers/routers.py   |   6 +
 llama_stack/distribution/server/endpoints.py  |  11 ++
 .../inline/agents/meta_reference/__init__.py  |   3 +-
 .../agents/meta_reference/agent_instance.py   |  83 +++++-----
 .../inline/agents/meta_reference/agents.py    |  12 +-
 .../inline/tool_runtime/memory/__init__.py    |   4 +-
 .../inline/tool_runtime/memory/config.py      |  83 +---------
 .../tool_runtime/memory/context_retriever.py  |   2 +-
 .../inline/tool_runtime/memory/memory.py      | 143 +++++++++---------
 .../providers/registry/tool_runtime.py        |   2 +-
 .../tests/vector_io/test_vector_io.py         |  15 +-
 .../providers/utils/memory/vector_store.py    |  12 +-
 16 files changed, 260 insertions(+), 224 deletions(-)
 create mode 100644 llama_stack/apis/tools/rag_tool.py

diff --git a/llama_stack/apis/tools/__init__.py b/llama_stack/apis/tools/__init__.py
index f747fcdc2..8cd798ebf 100644
--- a/llama_stack/apis/tools/__init__.py
+++ b/llama_stack/apis/tools/__init__.py
@@ -5,3 +5,4 @@
 # the root directory of this source tree.
 
 from .tools import *  # noqa: F401 F403
+from .rag_tool import *  # noqa: F401 F403
diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py
new file mode 100644
index 000000000..d8e085410
--- /dev/null
+++ b/llama_stack/apis/tools/rag_tool.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+from typing import Any, Dict, List, Literal, Optional, Union
+
+from llama_models.schema_utils import json_schema_type, register_schema, webmethod
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated, Protocol, runtime_checkable
+
+from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+
+
+@json_schema_type
+class RAGDocument(BaseModel):
+    document_id: str
+    content: InterleavedContent | URL
+    mime_type: str | None = None
+    metadata: Dict[str, Any] = Field(default_factory=dict)
+
+
+@json_schema_type
+class RAGQueryResult(BaseModel):
+    content: Optional[InterleavedContent] = None
+
+
+@json_schema_type
+class RAGQueryGenerator(Enum):
+    default = "default"
+    llm = "llm"
+    custom = "custom"
+
+
+@json_schema_type
+class DefaultRAGQueryGeneratorConfig(BaseModel):
+    type: Literal["default"] = "default"
+    separator: str = " "
+
+
+@json_schema_type
+class LLMRAGQueryGeneratorConfig(BaseModel):
+    type: Literal["llm"] = "llm"
+    model: str
+    template: str
+
+
+RAGQueryGeneratorConfig = register_schema(
+    Annotated[
+        Union[
+            DefaultRAGQueryGeneratorConfig,
+            LLMRAGQueryGeneratorConfig,
+        ],
+        Field(discriminator="type"),
+    ],
+    name="RAGQueryGeneratorConfig",
+)
+
+
+@json_schema_type
+class RAGQueryConfig(BaseModel):
+    # This config defines how a query is generated using the messages
+    # for memory bank retrieval.
+    query_generator_config: RAGQueryGeneratorConfig = Field(
+        default=DefaultRAGQueryGeneratorConfig()
+    )
+    max_tokens_in_context: int = 4096
+    max_chunks: int = 5
+
+
+@runtime_checkable
+@trace_protocol
+class RAGToolRuntime(Protocol):
+    @webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
+    async def insert_documents(
+        self,
+        documents: List[RAGDocument],
+        vector_db_ids: List[str],
+        chunk_size_in_tokens: int = 512,
+    ) -> None:
+        """Index documents so they can be used by the RAG system"""
+        ...
+
+    @webmethod(route="/tool-runtime/rag-tool/query", method="POST")
+    async def query_context(
+        self,
+        content: InterleavedContent,
+        query_config: RAGQueryConfig,
+        vector_db_ids: List[str],
+    ) -> RAGQueryResult:
+        """Query the RAG system for context; typically invoked by the agent"""
+        ...
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index fb990cc41..65e54b40d 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -15,6 +15,8 @@ from llama_stack.apis.common.content_types import InterleavedContent, URL
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
+from .rag_tool import RAGToolRuntime
+
 
 @json_schema_type
 class ToolParameter(BaseModel):
@@ -130,11 +132,17 @@ class ToolGroups(Protocol):
         ...
 
 
+class SpecialToolGroups(Enum):
+    rag_tool = "rag_tool"
+
+
 @runtime_checkable
 @trace_protocol
 class ToolRuntime(Protocol):
     tool_store: ToolStore
 
+    rag_tool: RAGToolRuntime
+
     # TODO: This needs to be renamed once the OpenAPI generator name-conflict issue is fixed.
     @webmethod(route="/tool-runtime/list-tools", method="GET")
     async def list_runtime_tools(
@@ -143,7 +151,7 @@ class ToolRuntime(Protocol):
 
     @webmethod(route="/tool-runtime/invoke", method="POST")
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         """Run a tool with the given arguments"""
         ...
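[Reviewer note] A minimal sketch of how the new sub-protocol is meant to be called
once this lands. The `tool_runtime` handle and the "my_docs" vector DB id are
illustrative assumptions, not part of this patch; the routes and signatures are the
ones defined in rag_tool.py above.

    from llama_stack.apis.tools import (
        DefaultRAGQueryGeneratorConfig,
        RAGDocument,
        RAGQueryConfig,
    )

    # Index a document, then pull context back out through the same sub-protocol.
    await tool_runtime.rag_tool.insert_documents(
        documents=[
            RAGDocument(
                document_id="doc1",
                content="Llamas are members of the camelid family.",
            )
        ],
        vector_db_ids=["my_docs"],  # assumed to be an already-registered vector DB
        chunk_size_in_tokens=512,
    )
    result = await tool_runtime.rag_tool.query_context(
        content="What family do llamas belong to?",
        query_config=RAGQueryConfig(
            query_generator_config=DefaultRAGQueryGeneratorConfig(separator=" "),
            max_tokens_in_context=4096,
            max_chunks=5,
        ),
        vector_db_ids=["my_docs"],
    )
    # result.content is InterleavedContent (or None if nothing was retrieved).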
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index bd5a9ae98..dd6d4be6f 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -333,6 +333,8 @@ async def instantiate_provider(
     impl.__provider_spec__ = provider_spec
     impl.__provider_config__ = config
 
+    # TODO: check compliance for special tool groups; the impl should be for
+    # Api.tool_runtime, the name should be the special tool group, and the protocol should be the special tool group's protocol
     check_protocol_compliance(impl, protocols[provider_spec.api])
     if (
         not isinstance(provider_spec, AutoRoutedProviderSpec)
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 979c68b72..0eb1a33e4 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -406,6 +406,12 @@ class ToolRuntimeRouter(ToolRuntime):
     ) -> None:
         self.routing_table = routing_table
 
+        # TODO: this should be in sync with "get_all_api_endpoints()"
+        # TODO: make sure rag_tool vs builtin::memory is correct everywhere
+        self.rag_tool = self.routing_table.get_provider_impl("builtin::memory")
+        setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
+        setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)
+
     async def initialize(self) -> None:
         pass
diff --git a/llama_stack/distribution/server/endpoints.py b/llama_stack/distribution/server/endpoints.py
index af429e020..1033eaa23 100644
--- a/llama_stack/distribution/server/endpoints.py
+++ b/llama_stack/distribution/server/endpoints.py
@@ -9,6 +9,8 @@ from typing import Dict, List
 
 from pydantic import BaseModel
 
+from llama_stack.apis.tools import SpecialToolGroups
+
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION
 
 from llama_stack.distribution.resolver import api_protocol_map
@@ -29,6 +31,15 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
     for api, protocol in protocols.items():
         endpoints = []
         protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
+        if api == Api.tool_runtime:
+            for tool_group in SpecialToolGroups:
+                print(f"tool_group: {tool_group}")
+                sub_protocol = getattr(protocol, tool_group.value)
+                sub_protocol_methods = inspect.getmembers(
+                    sub_protocol, predicate=inspect.isfunction
+                )
+                for name, method in sub_protocol_methods:
+                    protocol_methods.append((f"{tool_group.value}.{name}", method))
 
         for name, method in protocol_methods:
             if not hasattr(method, "__webmethod__"):
diff --git a/llama_stack/providers/inline/agents/meta_reference/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py
index 50f61fb42..de34b8d2c 100644
--- a/llama_stack/providers/inline/agents/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/agents/meta_reference/__init__.py
@@ -19,9 +19,8 @@ async def get_provider_impl(
     impl = MetaReferenceAgentsImpl(
         config,
         deps[Api.inference],
-        deps[Api.memory],
+        deps[Api.vector_io],
         deps[Api.safety],
-        deps[Api.memory_banks],
         deps[Api.tool_runtime],
         deps[Api.tool_groups],
     )
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 2ebc7ded1..0f076ef87 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -59,13 +59,18 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
-from llama_stack.apis.memory import Memory, MemoryBankDocument
-from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.tools import (
+    DefaultRAGQueryGeneratorConfig,
+    RAGDocument,
+    RAGQueryConfig,
+    ToolGroups,
+    ToolRuntime,
+)
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
-
 from .persistence import AgentPersistence
 from .safety import SafetyException, ShieldRunnerMixin
@@ -79,7 +84,7 @@ def make_random_string(length: int = 8):
 
 
 TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
 
-MEMORY_QUERY_TOOL = "query_memory"
+MEMORY_QUERY_TOOL = "rag_tool.query_context"
 WEB_SEARCH_TOOL = "web_search"
 MEMORY_GROUP = "builtin::memory"
@@ -91,20 +96,18 @@ class ChatAgent(ShieldRunnerMixin):
         agent_config: AgentConfig,
         tempdir: str,
         inference_api: Inference,
-        memory_api: Memory,
-        memory_banks_api: MemoryBanks,
         safety_api: Safety,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
+        vector_io_api: VectorIO,
         persistence_store: KVStore,
     ):
         self.agent_id = agent_id
         self.agent_config = agent_config
         self.tempdir = tempdir
         self.inference_api = inference_api
-        self.memory_api = memory_api
-        self.memory_banks_api = memory_banks_api
         self.safety_api = safety_api
+        self.vector_io_api = vector_io_api
         self.storage = AgentPersistence(agent_id, persistence_store)
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
@@ -384,10 +387,7 @@ class ChatAgent(ShieldRunnerMixin):
             await self.handle_documents(
                 session_id, documents, input_messages, tool_defs
             )
-        if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
-            memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
-            if memory_tool_group is None:
-                raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
+        if "builtin::memory" in toolgroup_args and len(input_messages) > 0:
             with tracing.span(MEMORY_QUERY_TOOL) as span:
                 step_id = str(uuid.uuid4())
                 yield AgentTurnResponseStreamChunk(
@@ -398,17 +398,15 @@
                         )
                     )
                 )
-                query_args = {
-                    "messages": [msg.content for msg in input_messages],
-                    **toolgroup_args.get(memory_tool_group, {}),
-                }
+                args = toolgroup_args["builtin::memory"]
+                vector_db_ids = args.get("vector_db_ids", [])
 
                 session_info = await self.storage.get_session_info(session_id)
+                # if the session has a memory bank id, let the memory tool use it
                 if session_info.memory_bank_id:
-                    if "memory_bank_ids" not in query_args:
-                        query_args["memory_bank_ids"] = []
-                    query_args["memory_bank_ids"].append(session_info.memory_bank_id)
+                    vector_db_ids.append(session_info.memory_bank_id)
+
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepProgressPayload(
@@ -425,9 +423,16 @@
                             )
                         )
                     )
                 )
-                result = await self.tool_runtime_api.invoke_tool(
-                    tool_name=MEMORY_QUERY_TOOL,
-                    args=query_args,
+                retrieved_context = await self.tool_runtime_api.rag_tool.query_context(
+                    content=concat_interleaved_content(
+                        [msg.content for msg in input_messages]
+                    ),
+                    query_config=RAGQueryConfig(
+                        query_generator_config=DefaultRAGQueryGeneratorConfig(),
+                        max_tokens_in_context=4096,
+                        max_chunks=5,
+                    ),
+                    vector_db_ids=vector_db_ids,
                 )
                 yield AgentTurnResponseStreamChunk(
@@ -449,7 +454,7 @@
                             ToolResponse(
                                 call_id="",
                                 tool_name=MEMORY_QUERY_TOOL,
-                                content=result.content,
+                                content=retrieved_context or [],
                             )
                         ],
                     ),
@@ -459,13 +464,11 @@
                 span.set_attribute(
                     "input", [m.model_dump_json() for m in input_messages]
                 )
-                span.set_attribute("output", result.content)
-                span.set_attribute("error_code", result.error_code)
-                span.set_attribute("error_message", result.error_message)
-                span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
-                if result.error_code == 0:
+                span.set_attribute("output", retrieved_context)
+                span.set_attribute("tool_name", "builtin::memory")
+                if retrieved_context:
                     last_message = input_messages[-1]
-                    last_message.context = result.content
+                    last_message.context = retrieved_context
 
         output_attachments = []
@@ -842,12 +845,13 @@ class ChatAgent(ShieldRunnerMixin):
 
         if session_info.memory_bank_id is None:
             bank_id = f"memory_bank_{session_id}"
-            await self.memory_banks_api.register_memory_bank(
-                memory_bank_id=bank_id,
-                params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
+
+            # TODO: the semantics of registration is definitely not "creation",
+            # so we need to fix it if we expect the agent to create a new vector db
+            # for each session
+            await self.vector_io_api.register_vector_db(
+                vector_db_id=bank_id,
+                embedding_model="all-MiniLM-L6-v2",
             )
             await self.storage.add_memory_bank_to_session(session_id, bank_id)
         else:
@@ -860,7 +864,7 @@ class ChatAgent(ShieldRunnerMixin):
     ) -> None:
         bank_id = await self._ensure_memory_bank(session_id)
         documents = [
-            MemoryBankDocument(
+            RAGDocument(
                 document_id=str(uuid.uuid4()),
                 content=a.content,
                 mime_type=a.mime_type,
@@ -868,9 +872,10 @@ class ChatAgent(ShieldRunnerMixin):
             )
             for a in data
         ]
-        await self.memory_api.insert_documents(
-            bank_id=bank_id,
+        await self.tool_runtime_api.rag_tool.insert_documents(
             documents=documents,
+            vector_db_ids=[bank_id],
+            chunk_size_in_tokens=512,
         )
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index d22ef82ab..b1844f4d0 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -26,10 +26,9 @@ from llama_stack.apis.agents import (
     Turn,
 )
 from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
 
 from .agent_instance import ChatAgent
@@ -44,17 +43,15 @@ class MetaReferenceAgentsImpl(Agents):
         self,
         config: MetaReferenceAgentsImplConfig,
         inference_api: Inference,
-        memory_api: Memory,
+        vector_io_api: VectorIO,
         safety_api: Safety,
-        memory_banks_api: MemoryBanks,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
     ):
         self.config = config
         self.inference_api = inference_api
-        self.memory_api = memory_api
+        self.vector_io_api = vector_io_api
         self.safety_api = safety_api
-        self.memory_banks_api = memory_banks_api
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
 
@@ -114,8 +111,7 @@ class MetaReferenceAgentsImpl(Agents):
             tempdir=self.tempdir,
             inference_api=self.inference_api,
             safety_api=self.safety_api,
-            memory_api=self.memory_api,
-            memory_banks_api=self.memory_banks_api,
+            vector_io_api=self.vector_io_api,
             tool_runtime_api=self.tool_runtime_api,
             tool_groups_api=self.tool_groups_api,
             persistence_store=(
diff --git a/llama_stack/providers/inline/tool_runtime/memory/__init__.py b/llama_stack/providers/inline/tool_runtime/memory/__init__.py
index 928afa484..42a0a6b01 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/__init__.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/__init__.py
@@ -13,8 +13,6 @@ from .memory import MemoryToolRuntimeImpl
 
 
 async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
-    impl = MemoryToolRuntimeImpl(
-        config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
-    )
+    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/tool_runtime/memory/config.py b/llama_stack/providers/inline/tool_runtime/memory/config.py
index 6ff242c6b..4a20c986c 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/config.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/config.py
@@ -4,87 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from enum import Enum
-from typing import Annotated, List, Literal, Union
-
-from pydantic import BaseModel, Field
-
-
-class _MemoryBankConfigCommon(BaseModel):
-    bank_id: str
-
-
-class VectorMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["vector"] = "vector"
-
-
-class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyvalue"] = "keyvalue"
-    keys: List[str]  # what keys to focus on
-
-
-class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyword"] = "keyword"
-
-
-class GraphMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["graph"] = "graph"
-    entities: List[str]  # what entities to focus on
-
-
-MemoryBankConfig = Annotated[
-    Union[
-        VectorMemoryBankConfig,
-        KeyValueMemoryBankConfig,
-        KeywordMemoryBankConfig,
-        GraphMemoryBankConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-class MemoryQueryGenerator(Enum):
-    default = "default"
-    llm = "llm"
-    custom = "custom"
-
-
-class DefaultMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.default.value] = (
-        MemoryQueryGenerator.default.value
-    )
-    sep: str = " "
-
-
-class LLMMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
-    model: str
-    template: str
-
-
-class CustomMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
-
-
-MemoryQueryGeneratorConfig = Annotated[
-    Union[
-        DefaultMemoryQueryGeneratorConfig,
-        LLMMemoryQueryGeneratorConfig,
-        CustomMemoryQueryGeneratorConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-class MemoryToolConfig(BaseModel):
-    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+from pydantic import BaseModel
 
 
 class MemoryToolRuntimeConfig(BaseModel):
-    # This config defines how a query is generated using the messages
-    # for memory bank retrieval.
-    query_generator_config: MemoryQueryGeneratorConfig = Field(
-        default=DefaultMemoryQueryGeneratorConfig()
-    )
-    max_tokens_in_context: int = 4096
-    max_chunks: int = 5
+    pass
diff --git a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py
index 803981f07..6684a6aa8 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py
@@ -47,7 +47,7 @@ async def default_rag_query_generator(
     messages: List[InterleavedContent],
     **kwargs,
 ):
-    return config.sep.join(interleaved_content_as_str(m) for m in messages)
+    return config.separator.join(interleaved_content_as_str(m) for m in messages)
 
 
 async def llm_rag_query_generator(
diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py
index fe6325abb..a2eeefa02 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py
@@ -10,20 +10,25 @@ import secrets
 import string
 from typing import Any, Dict, List, Optional
 
-from llama_stack.apis.common.content_types import URL
-from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.memory import Memory, QueryDocumentsResponse
-from llama_stack.apis.memory_banks import MemoryBanks
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    TextContentItem,
+    URL,
+)
+from llama_stack.apis.inference import Inference
 from llama_stack.apis.tools import (
+    RAGDocument,
+    RAGQueryConfig,
+    RAGQueryResult,
+    RAGToolRuntime,
     ToolDef,
     ToolInvocationResult,
-    ToolParameter,
     ToolRuntime,
 )
+from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
-from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 
-from .config import MemoryToolConfig, MemoryToolRuntimeConfig
+from .config import MemoryToolRuntimeConfig
 from .context_retriever import generate_rag_query
 
 log = logging.getLogger(__name__)
@@ -35,65 +40,61 @@ def make_random_string(length: int = 8):
     )
 
 
-class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
+class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     def __init__(
         self,
         config: MemoryToolRuntimeConfig,
-        memory_api: Memory,
-        memory_banks_api: MemoryBanks,
+        vector_io_api: VectorIO,
         inference_api: Inference,
     ):
         self.config = config
-        self.memory_api = memory_api
-        self.memory_banks_api = memory_banks_api
+        self.vector_io_api = vector_io_api
         self.inference_api = inference_api
 
     async def initialize(self):
         pass
 
-    async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
-    ) -> List[ToolDef]:
-        return [
-            ToolDef(
-                name="query_memory",
-                description="Retrieve context from memory",
-                parameters=[
-                    ToolParameter(
-                        name="messages",
-                        description="The input messages to search for",
-                        parameter_type="array",
-                    ),
-                ],
-            )
-        ]
+    async def shutdown(self):
+        pass
+
+    async def insert_documents(
+        self,
+        documents: List[RAGDocument],
+        vector_db_ids: List[str],
+        chunk_size_in_tokens: int = 512,
+    ) -> None:
+        pass
+
+    async def query_context(
+        self,
+        content: InterleavedContent,
+        query_config: RAGQueryConfig,
+        vector_db_ids: List[str],
+    ) -> RAGQueryResult:
+        if not vector_db_ids:
+            return RAGQueryResult(content=None)
 
-    async def _retrieve_context(
-        self, input_messages: List[InterleavedContent], bank_ids: List[str]
-    ) -> Optional[List[InterleavedContent]]:
-        if not bank_ids:
-            return None
         query = await generate_rag_query(
-            self.config.query_generator_config,
-            input_messages,
+            query_config.query_generator_config,
+            content,
             inference_api=self.inference_api,
         )
         tasks = [
-            self.memory_api.query_documents(
-                bank_id=bank_id,
+            self.vector_io_api.query_chunks(
+                vector_db_id=vector_db_id,
                 query=query,
                 params={
-                    "max_chunks": self.config.max_chunks,
+                    "max_chunks": query_config.max_chunks,
                 },
             )
-            for bank_id in bank_ids
+            for vector_db_id in vector_db_ids
         ]
-        results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
+        results: List[QueryChunksResponse] = await asyncio.gather(*tasks)
         chunks = [c for r in results for c in r.chunks]
         scores = [s for r in results for s in r.scores]
 
         if not chunks:
-            return None
+            return RAGQueryResult(content=None)
 
         # sort by score
         chunks, scores = zip(
@@ -102,45 +103,47 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
 
         tokens = 0
         picked = []
-        for c in chunks[: self.config.max_chunks]:
+        for c in chunks[: query_config.max_chunks]:
             tokens += c.token_count
-            if tokens > self.config.max_tokens_in_context:
+            if tokens > query_config.max_tokens_in_context:
                 log.error(
                     f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                 )
                 break
             picked.append(f"id:{c.document_id}; content:{c.content}")
 
+        return RAGQueryResult(
+            content=[
+                TextContentItem(
+                    text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
+                ),
+                *picked,
+                TextContentItem(
+                    text="\n=== END-RETRIEVED-CONTEXT ===\n",
+                ),
+            ],
+        )
+
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        # Parameters are not listed since these methods are not yet invoked automatically
+        # by the LLM. They are only implemented so that things like /tools can list without
+        # encountering fatal errors.
         return [
-            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-            *picked,
-            "\n=== END-RETRIEVED-CONTEXT ===\n",
+            ToolDef(
+                name="rag_tool.query_context",
+                description="Retrieve context from memory",
+            ),
+            ToolDef(
+                name="rag_tool.insert_documents",
+                description="Insert documents into memory",
+            ),
         ]
 
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
-        tool = await self.tool_store.get_tool(tool_name)
-        tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
-        final_args = tool_group.args or {}
-        final_args.update(args)
-        config = MemoryToolConfig()
-        if tool.metadata and tool.metadata.get("config") is not None:
-            config = MemoryToolConfig(**tool.metadata["config"])
-        if "memory_bank_ids" in final_args:
-            bank_ids = final_args["memory_bank_ids"]
-        else:
-            bank_ids = [
-                bank_config.bank_id for bank_config in config.memory_bank_configs
-            ]
-        if "messages" not in final_args:
-            raise ValueError("messages are required")
-        context = await self._retrieve_context(
-            final_args["messages"],
-            bank_ids,
-        )
-        if context is None:
-            context = []
-        return ToolInvocationResult(
-            content=concat_interleaved_content(context), error_code=0
+        raise RuntimeError(
+            "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
         )
diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py
index b3ea68949..426fe22f2 100644
--- a/llama_stack/providers/registry/tool_runtime.py
+++ b/llama_stack/providers/registry/tool_runtime.py
@@ -23,7 +23,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.inline.tool_runtime.memory",
             config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
-            api_dependencies=[Api.vector_io, Api.vector_dbs, Api.inference],
+            api_dependencies=[Api.vector_io, Api.inference],
         ),
         InlineProviderSpec(
             api=Api.tool_runtime,
diff --git a/llama_stack/providers/tests/vector_io/test_vector_io.py b/llama_stack/providers/tests/vector_io/test_vector_io.py
index 901b8bd11..521131f63 100644
--- a/llama_stack/providers/tests/vector_io/test_vector_io.py
+++ b/llama_stack/providers/tests/vector_io/test_vector_io.py
@@ -8,13 +8,12 @@ import uuid
 
 import pytest
 
+from llama_stack.apis.tools import RAGDocument
+
 from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
 from llama_stack.apis.vector_io import QueryChunksResponse
 
-from llama_stack.providers.utils.memory.vector_store import (
-    make_overlapped_chunks,
-    MemoryBankDocument,
-)
+from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks
 
 # How to run this test:
 #
@@ -26,22 +25,22 @@ from llama_stack.providers.utils.memory.vector_store import (
 @pytest.fixture(scope="session")
 def sample_chunks():
     docs = [
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc1",
             content="Python is a high-level programming language.",
             metadata={"category": "programming", "difficulty": "beginner"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc2",
             content="Machine learning is a subset of artificial intelligence.",
             metadata={"category": "AI", "difficulty": "advanced"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc3",
             content="Data structures are fundamental to computer science.",
             metadata={"category": "computer science", "difficulty": "intermediate"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc4",
             content="Neural networks are inspired by biological neural networks.",
             metadata={"category": "AI", "difficulty": "advanced"},
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index c2de6c714..c31ee57d8 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -19,7 +19,6 @@ import numpy as np
 from llama_models.llama3.api.tokenizer import Tokenizer
 from numpy.typing import NDArray
 
-from pydantic import BaseModel, Field
 from pypdf import PdfReader
 
 from llama_stack.apis.common.content_types import (
@@ -27,6 +26,7 @@ from llama_stack.apis.common.content_types import (
     TextContentItem,
     URL,
 )
+from llama_stack.apis.tools import RAGDocument
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
 from llama_stack.providers.datatypes import Api
@@ -34,17 +34,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 
-
 log = logging.getLogger(__name__)
 
 
-class MemoryBankDocument(BaseModel):
-    document_id: str
-    content: InterleavedContent | URL
-    mime_type: str | None = None
-    metadata: Dict[str, Any] = Field(default_factory=dict)
-
-
 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string
     pdf_bytes = io.BytesIO(data)
@@ -122,7 +114,7 @@ def concat_interleaved_content(content: List[InterleavedContent]) -> Interleave
     return ret
 
 
-async def content_from_doc(doc: MemoryBankDocument) -> str:
+async def content_from_doc(doc: RAGDocument) -> str:
     if isinstance(doc.content, URL):
         if doc.content.uri.startswith("data:"):
             return content_from_data(doc.content.uri)
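
[Reviewer note] The endpoint flattening added to endpoints.py keys sub-protocol
methods as "<group>.<name>" before the usual webmethod filtering. A rough sketch of
what that loop produces, assuming rag_tool is the only special tool group; the
print-out format is illustrative, not part of this patch:

    import inspect

    from llama_stack.apis.tools import RAGToolRuntime

    # Mirrors the loop in get_all_api_endpoints(): enumerate the sub-protocol's
    # functions, then keep only those carrying webmethod routing metadata.
    for name, method in inspect.getmembers(RAGToolRuntime, predicate=inspect.isfunction):
        if hasattr(method, "__webmethod__"):
            print(f"rag_tool.{name} -> {method.__webmethod__.route}")

    # Expected, given the routes declared in rag_tool.py:
    #   rag_tool.insert_documents -> /tool-runtime/rag-tool/insert
    #   rag_tool.query_context -> /tool-runtime/rag-tool/query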