# What does this PR do?
This PR allows users to customize the template used for chunks when they are inserted into the context. Additionally, it enables metadata injection into the context of an LLM for RAG. This makes the naive and crude assumption that each chunk should include its metadata, which is obviously redundant when multiple chunks are returned from the same document. Removing that duplication would require much more significant changes, so this is a reasonable first step that unblocks users requesting this enhancement in https://github.com/meta-llama/llama-stack/issues/1767. In the future, this can be extended to support citations.

List of Changes:

- `llama_stack/apis/tools/rag_tool.py`
  - Added a `chunk_template` field to `RAGQueryConfig`.
  - Added a `field_validator` to validate the `chunk_template` field in `RAGQueryConfig`.
  - Ensured the `chunk_template` field includes the placeholders `{index}` and `{chunk.content}`.
  - Updated the `query` method to use `chunk_template` for formatting chunk text content.
- `llama_stack/providers/inline/tool_runtime/rag/memory.py`
  - Modified the `insert` method to pass `doc.metadata` for chunk creation.
  - Enhanced the `query` method to format results using `chunk_template` and exclude unnecessary metadata fields like `token_count`.
- `llama_stack/providers/utils/memory/vector_store.py`
  - Updated `make_overlapped_chunks` to include metadata serialization and token counts for both content and metadata.
  - Added error handling for metadata serialization issues.
- `pyproject.toml`
  - Added `pydantic.field_validator` as a recognized `classmethod` decorator in the linting configuration.
- `tests/integration/tool_runtime/test_rag_tool.py`
  - Refactored test assertions into separate `assert_valid_chunk_response` and `assert_valid_text_response` helpers.
  - Added integration tests to validate `chunk_template` functionality with and without metadata inclusion.
  - Included a test case to ensure `chunk_template` validation errors are raised appropriately.
- `tests/unit/rag/test_vector_store.py`
  - Added unit tests for `make_overlapped_chunks`, verifying chunk creation with overlapping tokens and metadata integrity.
  - Added tests for metadata serialization errors, ensuring proper exception handling.
- `docs/_static/llama-stack-spec.html`
  - Added a new `chunk_template` field of type `string`, with a default template, for formatting retrieved chunks in `RAGQueryConfig`.
  - Updated the `required` fields to include `chunk_template`.
- `docs/_static/llama-stack-spec.yaml`
  - Introduced the `chunk_template` field with a default value in `RAGQueryConfig`.
  - Updated the required configuration list to include `chunk_template`.
- `docs/source/building_applications/rag.md`
  - Documented the `chunk_template` configuration, explaining how to customize metadata formatting in RAG queries.
  - Added examples demonstrating usage of the `chunk_template` field in RAG tool queries.
  - Highlighted default values for `RAG` agent configurations.

# Resolves https://github.com/meta-llama/llama-stack/issues/1767

## Test Plan
Updated both `test_vector_store.py` and `test_rag_tool.py` and tested end-to-end with a script.

I also tested the quickstart with this change enabled and specified the following metadata:

```python
document = RAGDocument(
    document_id="document_1",
    content=source,
    mime_type="text/html",
    metadata={"author": "Paul Graham", "title": "How to do great work"},
)
```

Which produced the output below:

*(screenshot of the model output, with the chunk metadata visible in the retrieved context)*

This highlights the usefulness of the additional metadata.
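If the metadata line is not wanted, or a different layout is needed, the template can be overridden per query. A minimal sketch mirroring the integration test below; the client handle and vector DB id are placeholders:

```python
# Sketch: overriding the chunk template at query time. A template must keep the
# {index} and {chunk.content} placeholders, otherwise RAGQueryConfig rejects it.
response = client.tool_runtime.rag_tool.query(
    vector_db_ids=["my_vector_db"],  # placeholder id
    content="What is the name of the method used for fine-tuning?",
    query_config={
        "chunk_template": "Result {index}\nContent: {chunk.content}\n",  # no metadata line
    },
)
```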
Notice how the metadata is redundant for different chunks of the same document; I think we can update that in a subsequent PR.

# Documentation
I've added a brief note about this in the documentation to explain it to users, and updated the API documentation.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
The new integration tests in `tests/integration/tool_runtime/test_rag_tool.py` (233 lines, Python):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from llama_stack_client.types import Document


@pytest.fixture(scope="function")
def client_with_empty_registry(client_with_models):
    def clear_registry():
        vector_dbs = [vector_db.identifier for vector_db in client_with_models.vector_dbs.list()]
        for vector_db_id in vector_dbs:
            client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)

    clear_registry()
    yield client_with_models

    # you must clean after the last test if you were running tests against
    # a stateful server instance
    clear_registry()


@pytest.fixture(scope="session")
def sample_documents():
    return [
        Document(
            document_id="test-doc-1",
            content="Python is a high-level programming language.",
            metadata={"category": "programming", "difficulty": "beginner"},
        ),
        Document(
            document_id="test-doc-2",
            content="Machine learning is a subset of artificial intelligence.",
            metadata={"category": "AI", "difficulty": "advanced"},
        ),
        Document(
            document_id="test-doc-3",
            content="Data structures are fundamental to computer science.",
            metadata={"category": "computer science", "difficulty": "intermediate"},
        ),
        Document(
            document_id="test-doc-4",
            content="Neural networks are inspired by biological neural networks.",
            metadata={"category": "AI", "difficulty": "advanced"},
        ),
    ]


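# Assertion helpers split out in this PR: chunk-level checks for
# vector_io.query responses and text-level checks for rag_tool.query responses.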
def assert_valid_chunk_response(response):
    assert len(response.chunks) > 0
    assert len(response.scores) > 0
    assert len(response.chunks) == len(response.scores)
    for chunk in response.chunks:
        assert isinstance(chunk.content, str)


def assert_valid_text_response(response):
    assert len(response.content) > 0
    assert all(isinstance(chunk.text, str) for chunk in response.content)


def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
    vector_db_id = "test_vector_db"
    client_with_empty_registry.vector_dbs.register(
        vector_db_id=vector_db_id,
        embedding_model=embedding_model_id,
        embedding_dimension=384,
    )

    client_with_empty_registry.tool_runtime.rag_tool.insert(
        documents=sample_documents,
        chunk_size_in_tokens=512,
        vector_db_id=vector_db_id,
    )

    # Query with a direct match
    query1 = "programming language"
    response1 = client_with_empty_registry.vector_io.query(
        vector_db_id=vector_db_id,
        query=query1,
    )
    assert_valid_chunk_response(response1)
    assert any("Python" in chunk.content for chunk in response1.chunks)

    # Query with semantic similarity
    query2 = "AI and brain-inspired computing"
    response2 = client_with_empty_registry.vector_io.query(
        vector_db_id=vector_db_id,
        query=query2,
    )
    assert_valid_chunk_response(response2)
    assert any("neural networks" in chunk.content.lower() for chunk in response2.chunks)

    # Query with limit on number of results (max_chunks=2)
    query3 = "computer"
    response3 = client_with_empty_registry.vector_io.query(
        vector_db_id=vector_db_id,
        query=query3,
        params={"max_chunks": 2},
    )
    assert_valid_chunk_response(response3)
    assert len(response3.chunks) <= 2

    # Query with threshold on similarity score
    query4 = "computer"
    response4 = client_with_empty_registry.vector_io.query(
        vector_db_id=vector_db_id,
        query=query4,
        params={"score_threshold": 0.01},
    )
    assert_valid_chunk_response(response4)
    assert all(score >= 0.01 for score in response4.scores)


def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
    providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
    assert len(providers) > 0

    vector_db_id = "test_vector_db"

    client_with_empty_registry.vector_dbs.register(
        vector_db_id=vector_db_id,
        embedding_model=embedding_model_id,
        embedding_dimension=384,
    )

    # list to check the vector DB was successfully registered
    available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
    assert vector_db_id in available_vector_dbs

    urls = [
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
    ]
    documents = [
        Document(
            document_id=f"num-{i}",
            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
            mime_type="text/plain",
            metadata={},
        )
        for i, url in enumerate(urls)
    ]

    client_with_empty_registry.tool_runtime.rag_tool.insert(
        documents=documents,
        vector_db_id=vector_db_id,
        chunk_size_in_tokens=512,
    )

    # Query for the name of the method
    response1 = client_with_empty_registry.vector_io.query(
        vector_db_id=vector_db_id,
        query="What's the name of the fine-tuning method used?",
    )
    assert_valid_chunk_response(response1)
    assert any("lora" in chunk.content.lower() for chunk in response1.chunks)

    # Query for the name of the model
    response2 = client_with_empty_registry.vector_io.query(
        vector_db_id=vector_db_id,
        query="Which Llama model is mentioned?",
    )
    assert_valid_chunk_response(response2)
    assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)


def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id):
    providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
    assert len(providers) > 0

    vector_db_id = "test_vector_db"

    client_with_empty_registry.vector_dbs.register(
        vector_db_id=vector_db_id,
        embedding_model=embedding_model_id,
        embedding_dimension=384,
    )

    available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
    assert vector_db_id in available_vector_dbs

    urls = [
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
    ]
    documents = [
        Document(
            document_id=f"num-{i}",
            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
            mime_type="text/plain",
            metadata={"author": "llama", "source": url},
        )
        for i, url in enumerate(urls)
    ]

    client_with_empty_registry.tool_runtime.rag_tool.insert(
        documents=documents,
        vector_db_id=vector_db_id,
        chunk_size_in_tokens=512,
    )
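
    # The default query config renders each chunk with a "Metadata:" line, so
    # the document metadata attached above should appear in the returned context.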
    response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
        vector_db_ids=[vector_db_id],
        content="What is the name of the method used for fine-tuning?",
    )
    assert_valid_text_response(response_with_metadata)
    assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)

    # A custom chunk_template without a metadata placeholder keeps the document
    # metadata out of the rendered context.
    response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
        vector_db_ids=[vector_db_id],
        content="What is the name of the method used for fine-tuning?",
        query_config={
            "include_metadata_in_content": True,
            "chunk_template": "Result {index}\nContent: {chunk.content}\n",
        },
    )
    assert_valid_text_response(response_without_metadata)
    assert not any("metadata:" in chunk.text.lower() for chunk in response_without_metadata.content)
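
    # chunk_template must contain both {index} and {chunk.content}; the
    # field_validator on RAGQueryConfig rejects templates that omit them.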
    with pytest.raises(ValueError):
        client_with_empty_registry.tool_runtime.rag_tool.query(
            vector_db_ids=[vector_db_id],
            content="What is the name of the method used for fine-tuning?",
            query_config={
                "chunk_template": "This should raise a ValueError because it is missing the proper template variables",
            },
        )
```