llama-stack/llama_stack/providers/inline/tool_runtime/rag/memory.py
Ashwin Bharambe ce33d02443
fix(tools): do not index tools, only index toolgroups (#2261)
When registering a MCP endpoint, we cannot list tools (like we used to)
since the MCP endpoint may be behind an auth wall. Registration can
happen much sooner (via run.yaml).

Instead, we do listing only when the _user_ actually calls listing.
Furthermore, we cache the list in-memory in the server. Currently, the
cache is not invalidated -- we may want to periodically re-list for MCP
servers. Note that they must call `list_tools` before calling
`invoke_tool` -- we use this critically.

This will enable us to list MCP servers in run.yaml

## Test Plan

Existing tests, updated tests accordingly.
2025-05-25 13:27:52 -07:00

221 lines
7.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import logging
import secrets
import string
from typing import Any
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
make_overlapped_chunks,
)
from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query
log = logging.getLogger(__name__)
def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__(
self,
config: RagToolRuntimeConfig,
vector_io_api: VectorIO,
inference_api: Inference,
):
self.config = config
self.vector_io_api = vector_io_api
self.inference_api = inference_api
async def initialize(self):
pass
async def shutdown(self):
pass
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
async def insert(
self,
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
chunks = []
for doc in documents:
content = await content_from_doc(doc)
chunks.extend(
make_overlapped_chunks(
doc.document_id,
content,
chunk_size_in_tokens,
chunk_size_in_tokens // 4,
doc.metadata,
)
)
if not chunks:
return
await self.vector_io_api.insert_chunks(
chunks=chunks,
vector_db_id=vector_db_id,
)
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
if not vector_db_ids:
raise ValueError(
"No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
)
query_config = query_config or RAGQueryConfig()
query = await generate_rag_query(
query_config.query_generator_config,
content,
inference_api=self.inference_api,
)
tasks = [
self.vector_io_api.query_chunks(
vector_db_id=vector_db_id,
query=query,
params={
"max_chunks": query_config.max_chunks,
"mode": query_config.mode,
},
)
for vector_db_id in vector_db_ids
]
results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
if not chunks:
return RAGQueryResult(content=None)
# sort by score
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False) # type: ignore
chunks = chunks[: query_config.max_chunks]
tokens = 0
picked: list[InterleavedContentItem] = [
TextContentItem(
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
]
for i, chunk in enumerate(chunks):
metadata = chunk.metadata
tokens += metadata["token_count"]
tokens += metadata.get("metadata_token_count", 0)
if tokens > query_config.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
picked.append(TextContentItem(text=text_content))
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
picked.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
)
)
return RAGQueryResult(
content=picked,
metadata={
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
},
)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
# Parameters are not listed since these methods are not yet invoked automatically
# by the LLM. The method is only implemented so things like /tools can list without
# encountering fatals.
return ListToolDefsResponse(
data=[
ToolDef(
name="insert_into_memory",
description="Insert documents into memory",
),
ToolDef(
name="knowledge_search",
description="Search for information in a database.",
parameters=[
ToolParameter(
name="query",
description="The query to search for. Can be a natural language sentence or keywords.",
parameter_type="string",
),
],
),
]
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
vector_db_ids = kwargs.get("vector_db_ids", [])
query_config = kwargs.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
# handle someone passing an empty dict
query_config = RAGQueryConfig()
query = kwargs["query"]
result = await self.query(
content=query,
vector_db_ids=vector_db_ids,
query_config=query_config,
)
return ToolInvocationResult(
content=result.content,
metadata=result.metadata,
)