forked from phoenix-oss/llama-stack-mirror
# What does this PR do? his PR allows users to customize the template used for chunks when inserted into the context. Additionally, this enables metadata injection into the context of an LLM for RAG. This makes a naive and crude assumption that each chunk should include the metadata, this is obviously redundant when multiple chunks are returned from the same document. In order to remove any sort of duplication of chunks, we'd have to make much more significant changes so this is a reasonable first step that unblocks users requesting this enhancement in https://github.com/meta-llama/llama-stack/issues/1767. In the future, this can be extended to support citations. List of Changes: - `llama_stack/apis/tools/rag_tool.py` - Added `chunk_template` field in `RAGQueryConfig`. - Added `field_validator` to validate the `chunk_template` field in `RAGQueryConfig`. - Ensured the `chunk_template` field includes placeholders `{index}` and `{chunk.content}`. - Updated the `query` method to use the `chunk_template` for formatting chunk text content. - `llama_stack/providers/inline/tool_runtime/rag/memory.py` - Modified the `insert` method to pass `doc.metadata` for chunk creation. - Enhanced the `query` method to format results using `chunk_template` and exclude unnecessary metadata fields like `token_count`. - `llama_stack/providers/utils/memory/vector_store.py` - Updated `make_overlapped_chunks` to include metadata serialization and token count for both content and metadata. - Added error handling for metadata serialization issues. - `pyproject.toml` - Added `pydantic.field_validator` as a recognized `classmethod` decorator in the linting configuration. - `tests/integration/tool_runtime/test_rag_tool.py` - Refactored test assertions to separate `assert_valid_chunk_response` and `assert_valid_text_response`. - Added integration tests to validate `chunk_template` functionality with and without metadata inclusion. 
- Included a test case to ensure `chunk_template` validation errors are raised appropriately. - `tests/unit/rag/test_vector_store.py` - Added unit tests for `make_overlapped_chunks`, verifying chunk creation with overlapping tokens and metadata integrity. - Added tests to handle metadata serialization errors, ensuring proper exception handling. - `docs/_static/llama-stack-spec.html` - Added a new `chunk_template` field of type `string` with a default template for formatting retrieved chunks in RAGQueryConfig. - Updated the `required` fields to include `chunk_template`. - `docs/_static/llama-stack-spec.yaml` - Introduced `chunk_template` field with a default value for RAGQueryConfig. - Updated the required configuration list to include `chunk_template`. - `docs/source/building_applications/rag.md` - Documented the `chunk_template` configuration, explaining how to customize metadata formatting in RAG queries. - Added examples demonstrating the usage of the `chunk_template` field in RAG tool queries. - Highlighted default values for `RAG` agent configurations. # Resolves https://github.com/meta-llama/llama-stack/issues/1767 ## Test Plan Updated both `test_vector_store.py` and `test_rag_tool.py` and tested end-to-end with a script. I also tested the quickstart to enable this and specified this metadata: ```python document = RAGDocument( document_id="document_1", content=source, mime_type="text/html", metadata={"author": "Paul Graham", "title": "How to do great work"}, ) ``` Which produced the output below:  This highlights the usefulness of the additional metadata. Notice how the metadata is redundant for different chunks of the same document. I think we can update that in a subsequent PR. # Documentation I've added a brief comment about this in the documentation to outline this to users and updated the API documentation. --------- Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
220 lines
7.2 KiB
Python
220 lines
7.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import logging
|
|
import secrets
|
|
import string
|
|
from typing import Any
|
|
|
|
from pydantic import TypeAdapter
|
|
|
|
from llama_stack.apis.common.content_types import (
|
|
URL,
|
|
InterleavedContent,
|
|
InterleavedContentItem,
|
|
TextContentItem,
|
|
)
|
|
from llama_stack.apis.inference import Inference
|
|
from llama_stack.apis.tools import (
|
|
ListToolDefsResponse,
|
|
RAGDocument,
|
|
RAGQueryConfig,
|
|
RAGQueryResult,
|
|
RAGToolRuntime,
|
|
Tool,
|
|
ToolDef,
|
|
ToolInvocationResult,
|
|
ToolParameter,
|
|
ToolRuntime,
|
|
)
|
|
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
|
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
|
from llama_stack.providers.utils.memory.vector_store import (
|
|
content_from_doc,
|
|
make_overlapped_chunks,
|
|
)
|
|
|
|
from .config import RagToolRuntimeConfig
|
|
from .context_retriever import generate_rag_query
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def make_random_string(length: int = 8):
    """Return a random alphanumeric string of the given length.

    Uses the ``secrets`` module, so the result is suitable for identifiers
    that should not be guessable.
    """
    alphabet = string.ascii_letters + string.digits
    chars = [secrets.choice(alphabet) for _ in range(length)]
    return "".join(chars)
|
|
|
|
|
|
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
    """Inline RAG tool runtime backed by a VectorIO provider.

    ``insert`` splits documents into overlapping chunks and stores them in a
    vector DB; ``query`` retrieves the top-scoring chunks across one or more
    vector DBs and renders them into LLM context text using the
    ``chunk_template`` from ``RAGQueryConfig``.
    """

    def __init__(
        self,
        config: RagToolRuntimeConfig,
        vector_io_api: VectorIO,
        inference_api: Inference,
    ):
        self.config = config
        self.vector_io_api = vector_io_api
        self.inference_api = inference_api

    async def initialize(self):
        # No startup work needed for the inline runtime.
        pass

    async def shutdown(self):
        # No resources to release.
        pass

    async def register_tool(self, tool: Tool) -> None:
        # Tools are statically defined (see list_runtime_tools); nothing to register.
        pass

    async def unregister_tool(self, tool_id: str) -> None:
        return

    async def insert(
        self,
        documents: list[RAGDocument],
        vector_db_id: str,
        chunk_size_in_tokens: int = 512,
    ) -> None:
        """Chunk each document and insert the chunks into ``vector_db_id``.

        Consecutive chunks overlap by a quarter of ``chunk_size_in_tokens``;
        each document's metadata is propagated onto every chunk created from
        it. No-op when the documents yield no chunks.
        """
        chunks = []
        for doc in documents:
            content = await content_from_doc(doc)
            chunks.extend(
                make_overlapped_chunks(
                    doc.document_id,
                    content,
                    chunk_size_in_tokens,
                    chunk_size_in_tokens // 4,  # overlap between consecutive chunks
                    doc.metadata,
                )
            )

        if not chunks:
            return

        await self.vector_io_api.insert_chunks(
            chunks=chunks,
            vector_db_id=vector_db_id,
        )

    async def query(
        self,
        content: InterleavedContent,
        vector_db_ids: list[str],
        query_config: RAGQueryConfig | None = None,
    ) -> RAGQueryResult:
        """Search ``vector_db_ids`` for chunks relevant to ``content``.

        The query is generated from ``content`` via the configured query
        generator, every vector DB is queried concurrently, and the merged
        results are sorted by score and truncated to ``max_chunks`` and
        ``max_tokens_in_context``.

        Args:
            content: the user's query content.
            vector_db_ids: IDs of the vector DBs to search; must be non-empty.
            query_config: optional query configuration; defaults to
                ``RAGQueryConfig()``.

        Returns:
            A ``RAGQueryResult`` whose content interleaves a BEGIN header, one
            formatted text item per picked chunk (rendered with
            ``query_config.chunk_template``), an END marker, and a trailing
            note; its metadata lists the document IDs of the picked chunks.
            Returns ``RAGQueryResult(content=None)`` when nothing matches.

        Raises:
            ValueError: if ``vector_db_ids`` is empty.
        """
        if not vector_db_ids:
            raise ValueError(
                "No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
            )

        query_config = query_config or RAGQueryConfig()
        query = await generate_rag_query(
            query_config.query_generator_config,
            content,
            inference_api=self.inference_api,
        )
        # Query all DBs concurrently, then merge their chunk/score lists.
        tasks = [
            self.vector_io_api.query_chunks(
                vector_db_id=vector_db_id,
                query=query,
                params={
                    "max_chunks": query_config.max_chunks,
                },
            )
            for vector_db_id in vector_db_ids
        ]
        results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
        chunks = [c for r in results for c in r.chunks]
        scores = [s for r in results for s in r.scores]

        if not chunks:
            return RAGQueryResult(content=None)

        # sort by score (descending) and keep the global top-k across all DBs
        chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)  # type: ignore
        chunks = chunks[: query_config.max_chunks]

        tokens = 0
        picked: list[InterleavedContentItem] = [
            TextContentItem(
                text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
            )
        ]
        # Number of chunk items actually added to the context; `picked` also
        # holds framing text items, so len(picked) must not be used for this.
        num_picked_chunks = 0
        for i, chunk in enumerate(chunks):
            metadata = chunk.metadata
            tokens += metadata["token_count"]
            tokens += metadata["metadata_token_count"]

            if tokens > query_config.max_tokens_in_context:
                # BUGFIX: previously logged len(picked), which over-counted by
                # one because it included the BEGIN header item.
                log.error(
                    f"Using {num_picked_chunks} chunks; reached max tokens in context: {tokens}",
                )
                break

            # Strip bookkeeping keys so only user-supplied metadata reaches the prompt.
            metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
            text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
            picked.append(TextContentItem(text=text_content))
            num_picked_chunks += 1

        picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
        picked.append(
            TextContentItem(
                text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
            )
        )

        return RAGQueryResult(
            content=picked,
            metadata={
                # BUGFIX: slice by the number of picked chunks. The previous
                # chunks[: len(picked)] over-counted because `picked` also
                # contains the BEGIN/END framing and the trailing note, so it
                # could report document IDs of chunks never placed in context.
                "document_ids": [c.metadata["document_id"] for c in chunks[:num_picked_chunks]],
            },
        )

    async def list_runtime_tools(
        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
    ) -> ListToolDefsResponse:
        # Parameters are not listed since these methods are not yet invoked automatically
        # by the LLM. The method is only implemented so things like /tools can list without
        # encountering fatals.
        return ListToolDefsResponse(
            data=[
                ToolDef(
                    name="insert_into_memory",
                    description="Insert documents into memory",
                ),
                ToolDef(
                    name="knowledge_search",
                    description="Search for information in a database.",
                    parameters=[
                        ToolParameter(
                            name="query",
                            description="The query to search for. Can be a natural language sentence or keywords.",
                            parameter_type="string",
                        ),
                    ],
                ),
            ]
        )

    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
        """Dispatch a tool invocation; all invocations route to ``query``.

        ``kwargs`` must contain "query" and may contain "vector_db_ids" and
        "query_config" (validated into a ``RAGQueryConfig``).
        """
        vector_db_ids = kwargs.get("vector_db_ids", [])
        query_config = kwargs.get("query_config")
        if query_config:
            query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
        else:
            # handle someone passing an empty dict
            query_config = RAGQueryConfig()

        query = kwargs["query"]
        result = await self.query(
            content=query,
            vector_db_ids=vector_db_ids,
            query_config=query_config,
        )

        return ToolInvocationResult(
            content=result.content,
            metadata=result.metadata,
        )