# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import logging
import secrets
import string
from typing import Any

from pydantic import TypeAdapter

from llama_stack.apis.common.content_types import (
    URL,
    InterleavedContent,
    InterleavedContentItem,
    TextContentItem,
)
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
    ListToolDefsResponse,
    RAGDocument,
    RAGQueryConfig,
    RAGQueryResult,
    RAGToolRuntime,
    Tool,
    ToolDef,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
    content_from_doc,
    make_overlapped_chunks,
)

from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query

log = logging.getLogger(__name__)


def make_random_string(length: int = 8) -> str:
    """Return a random alphanumeric string of the given length."""
    return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))


class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
    """Inline RAG tool runtime, backed by a VectorIO provider for chunk storage
    and retrieval and an Inference provider for query generation."""

    def __init__(
        self,
        config: RagToolRuntimeConfig,
        vector_io_api: VectorIO,
        inference_api: Inference,
    ):
        self.config = config
        self.vector_io_api = vector_io_api
        self.inference_api = inference_api

    async def initialize(self):
        pass

    async def shutdown(self):
        pass

    async def register_tool(self, tool: Tool) -> None:
        pass

    async def unregister_tool(self, tool_id: str) -> None:
        return

    async def insert(
        self,
        documents: list[RAGDocument],
        vector_db_id: str,
        chunk_size_in_tokens: int = 512,
    ) -> None:
        """Split each document into overlapping chunks (25% overlap) and store them
        in the given vector DB."""
        chunks = []
        for doc in documents:
            content = await content_from_doc(doc)
            chunks.extend(
                make_overlapped_chunks(
                    doc.document_id,
                    content,
                    chunk_size_in_tokens,
                    chunk_size_in_tokens // 4,
                    doc.metadata,
                )
            )

        if not chunks:
            return

        await self.vector_io_api.insert_chunks(
            chunks=chunks,
            vector_db_id=vector_db_id,
        )

    async def query(
        self,
        content: InterleavedContent,
        vector_db_ids: list[str],
        query_config: RAGQueryConfig | None = None,
    ) -> RAGQueryResult:
        """Search the given vector DBs and assemble a token-budgeted context from the
        highest-scoring chunks."""
        if not vector_db_ids:
            raise ValueError(
                "No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
            )

        query_config = query_config or RAGQueryConfig()
        # Turn the raw content into a search query; depending on the generator
        # config, this may call the inference API.
        query = await generate_rag_query(
            query_config.query_generator_config,
            content,
            inference_api=self.inference_api,
        )
        # Query all vector DBs concurrently and merge their results.
        tasks = [
            self.vector_io_api.query_chunks(
                vector_db_id=vector_db_id,
                query=query,
                params={
                    "max_chunks": query_config.max_chunks,
                },
            )
            for vector_db_id in vector_db_ids
        ]
        results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
        chunks = [c for r in results for c in r.chunks]
        scores = [s for r in results for s in r.scores]

        if not chunks:
            return RAGQueryResult(content=None)

        # Sort the merged results by score, descending, and keep the top max_chunks.
        chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)  # type: ignore
        chunks = chunks[: query_config.max_chunks]

        tokens = 0
        picked: list[InterleavedContentItem] = [
            TextContentItem(
                text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
            )
        ]
        picked_document_ids: list[str] = []
        for i, chunk in enumerate(chunks):
            metadata = chunk.metadata
            tokens += metadata["token_count"]
            tokens += metadata["metadata_token_count"]

            # Stop once the context token budget is exhausted.
            if tokens > query_config.max_tokens_in_context:
                log.error(
                    f"Using {len(picked_document_ids)} chunks; reached max tokens in context: {tokens}",
                )
                break

            metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
            text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
            picked.append(TextContentItem(text=text_content))
            picked_document_ids.append(metadata["document_id"])

        picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
        picked.append(
            TextContentItem(
                text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
            )
        )

        return RAGQueryResult(
            content=picked,
            metadata={
                # Report only the documents whose chunks actually made it into the context.
                "document_ids": picked_document_ids,
            },
        )

    async def list_runtime_tools(
        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
    ) -> ListToolDefsResponse:
        # Parameters are not listed since these methods are not yet invoked
        # automatically by the LLM. They are implemented only so that endpoints
        # like /tools can list tools without fatal errors.
        return ListToolDefsResponse(
            data=[
                ToolDef(
                    name="insert_into_memory",
                    description="Insert documents into memory",
                ),
                ToolDef(
                    name="knowledge_search",
                    description="Search for information in a database.",
                    parameters=[
                        ToolParameter(
                            name="query",
                            description="The query to search for. Can be a natural language sentence or keywords.",
                            parameter_type="string",
                        ),
                    ],
                ),
            ]
        )

    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
        vector_db_ids = kwargs.get("vector_db_ids", [])
        query_config = kwargs.get("query_config")
        if query_config:
            query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
        else:
            # Handle a missing or empty query_config by falling back to defaults.
            query_config = RAGQueryConfig()

        query = kwargs["query"]
        result = await self.query(
            content=query,
            vector_db_ids=vector_db_ids,
            query_config=query_config,
        )
        return ToolInvocationResult(
            content=result.content,
            metadata=result.metadata,
        )
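

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module's API): how a caller might drive
# the runtime directly through invoke_tool. `runtime` is assumed to be a fully
# constructed MemoryToolRuntimeImpl and "my_docs" an existing vector DB ID;
# both names are hypothetical.
#
#     result = await runtime.invoke_tool(
#         tool_name="knowledge_search",
#         kwargs={
#             "query": "How do I rotate my API keys?",
#             "vector_db_ids": ["my_docs"],
#             "query_config": {"max_chunks": 5},
#         },
#     )
#     print(result.content)
#
# The query_config dict is validated into a RAGQueryConfig by invoke_tool, so
# callers can pass plain JSON-style dicts rather than model instances.
# ---------------------------------------------------------------------------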