Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 00:34:44 +00:00)

feat: Adding support for metadata in RAG insertion and querying

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

Parent 473a07f624, commit e50a546bc0

8 changed files with 149 additions and 25 deletions
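Taken together, the change carries per-document metadata into every chunk at insertion time and lets callers opt in to having that metadata rendered into the knowledge_search context at query time. A minimal end-to-end sketch of the new usage, assuming an already-configured client and a registered vector DB; the `Document` import path and the client setup are assumptions, not part of this diff:

```python
# Sketch of the new flow; `client` is assumed to be a configured LlamaStackClient
# and `vector_db_id` an already-registered vector DB.
from llama_stack_client.types import Document  # import path assumed

documents = [
    Document(
        document_id="doc1",
        content="LoRA is a parameter-efficient fine-tuning method ...",
        mime_type="text/plain",
        metadata={"author": "Jane Doe", "source": "tuning-notes"},  # now carried into every chunk
    )
]

# Insertion: the document's metadata is copied onto each chunk it produces.
client.tool_runtime.rag_tool.insert(
    documents=documents,
    vector_db_id=vector_db_id,
    chunk_size_in_tokens=512,
)

# Querying: opt in to have the stored metadata included in the returned context.
results = client.tool_runtime.rag_tool.query(
    vector_db_ids=[vector_db_id],
    content="What do you know about fine-tuning?",
    query_config={"include_metadata_in_content": True},
)
```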
docs/_static/llama-stack-spec.html (vendored): 7 changes

@@ -11133,13 +11133,18 @@
                 "max_chunks": {
                   "type": "integer",
                   "default": 5
+                },
+                "include_metadata_in_content": {
+                  "type": "boolean",
+                  "default": false
                 }
               },
               "additionalProperties": false,
               "required": [
                 "query_generator_config",
                 "max_tokens_in_context",
-                "max_chunks"
+                "max_chunks",
+                "include_metadata_in_content"
               ],
               "title": "RAGQueryConfig"
             },
docs/_static/llama-stack-spec.yaml (vendored): 4 changes

@@ -7690,11 +7690,15 @@ components:
         max_chunks:
           type: integer
           default: 5
+        include_metadata_in_content:
+          type: boolean
+          default: false
       additionalProperties: false
       required:
         - query_generator_config
         - max_tokens_in_context
         - max_chunks
+        - include_metadata_in_content
       title: RAGQueryConfig
     RAGQueryGeneratorConfig:
       oneOf:
@@ -51,6 +51,7 @@ chunks = [
         "mime_type": "text/plain",
         "metadata": {
             "document_id": "doc1",
+            "author": "Jane Doe",
         },
     },
 ]

@@ -98,6 +99,17 @@ results = client.tool_runtime.rag_tool.query(
 )
 ```
 
+You can configure adding metadata to the context if you find it useful for your application. Simply add:
+```python
+# Query documents
+results = client.tool_runtime.rag_tool.query(
+    vector_db_ids=[vector_db_id],
+    content="What do you know about...",
+    query_config={
+        "include_metadata_in_content": True,
+    },
+)
+```
 ### Building RAG-Enhanced Agents
 
 One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:

@@ -115,6 +127,12 @@ agent = Agent(
             "name": "builtin::rag/knowledge_search",
             "args": {
                 "vector_db_ids": [vector_db_id],
+                # Defaults
+                "query_config": {
+                    "chunk_size_in_tokens": 512,
+                    "chunk_overlap_in_tokens": 0,
+                    "include_metadata_in_content": False,
+                },
             },
         }
     ],
@@ -72,6 +72,7 @@ class RAGQueryConfig(BaseModel):
     query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
     max_tokens_in_context: int = 4096
     max_chunks: int = 5
+    include_metadata_in_content: bool = False
 
 
 @runtime_checkable
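The new field defaults to False, so existing query configs keep today's behaviour and callers opt in explicitly. A small sketch of constructing the updated model; the import path and the pydantic-v2 `model_dump` call are assumptions based on the surrounding code, not part of this diff:

```python
from llama_stack.apis.tools import RAGQueryConfig  # import path assumed

# All fields have defaults, so the plain constructor still works.
default_config = RAGQueryConfig()
assert default_config.include_metadata_in_content is False

# Opt in to metadata in the rendered context; other fields keep their defaults.
config = RAGQueryConfig(include_metadata_in_content=True, max_chunks=3)
print(config.model_dump())
```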
@@ -87,6 +87,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                     content,
                     chunk_size_in_tokens,
                     chunk_size_in_tokens // 4,
+                    doc.metadata,
                 )
             )
 

@@ -140,19 +141,29 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                 text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
             )
         ]
-        for i, c in enumerate(chunks):
-            metadata = c.metadata
+        for i, chunk in enumerate(chunks):
+            metadata = chunk.metadata
             tokens += metadata["token_count"]
+            if query_config.include_metadata_in_content:
+                tokens += metadata["metadata_token_count"]
+
             if tokens > query_config.max_tokens_in_context:
                 log.error(
                     f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                 )
                 break
-            picked.append(
-                TextContentItem(
-                    text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n",
-                )
-            )
+
+            text_content = f"Result {i + 1}:\n"
+            if query_config.include_metadata_in_content:
+                metadata_subset = {
+                    k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]
+                }
+                text_content += f"\nMetadata: {metadata_subset}"
+            else:
+                text_content += f"Document_id:{metadata['document_id'][:5]}"
+            text_content += f"\nContent: {chunk.content}\n"
+            picked.append(TextContentItem(text=text_content))
 
         picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
         picked.append(
             TextContentItem(
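For reference, each picked chunk is now rendered either with a truncated document id (the old behaviour) or with the stored metadata minus the two bookkeeping counts. A standalone sketch of that formatting branch, using a plain dict in place of a real chunk; the `render_result` helper is illustrative, not part of the diff, but its output shape is why the integration test later in this commit looks for a `Metadata:` prefix:

```python
def render_result(i: int, content: str, metadata: dict, include_metadata_in_content: bool) -> str:
    # Mirrors the formatting branch added above (illustrative helper, not part of the diff).
    text = f"Result {i + 1}:\n"
    if include_metadata_in_content:
        metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
        text += f"\nMetadata: {metadata_subset}"
    else:
        text += f"Document_id:{metadata['document_id'][:5]}"
    text += f"\nContent: {content}\n"
    return text


meta = {"document_id": "doc-123", "author": "llama", "token_count": 128, "metadata_token_count": 12}
print(render_result(0, "LoRA is a parameter-efficient fine-tuning method.", meta, True))
# Result 1:
#
# Metadata: {'document_id': 'doc-123', 'author': 'llama'}
# Content: LoRA is a parameter-efficient fine-tuning method.
```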
@@ -139,22 +139,27 @@ async def content_from_doc(doc: RAGDocument) -> str:
     return interleaved_content_as_str(doc.content)
 
 
-def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> list[Chunk]:
+def make_overlapped_chunks(
+    document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
+) -> list[Chunk]:
     tokenizer = Tokenizer.get_instance()
     tokens = tokenizer.encode(text, bos=False, eos=False)
+    metadata_tokens = tokenizer.encode(str(metadata), bos=False, eos=False)
 
     chunks = []
     for i in range(0, len(tokens), window_len - overlap_len):
         toks = tokens[i : i + window_len]
         chunk = tokenizer.decode(toks)
+        chunk_metadata = metadata.copy()
+        chunk_metadata["document_id"] = document_id
+        chunk_metadata["token_count"] = len(toks)
+        chunk_metadata["metadata_token_count"] = len(metadata_tokens)
+
         # chunk is a string
         chunks.append(
             Chunk(
                 content=chunk,
-                metadata={
-                    "token_count": len(toks),
-                    "document_id": document_id,
-                },
+                metadata=chunk_metadata,
             )
         )
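Two details worth noting: the chunking stride is `window_len - overlap_len`, and `metadata_token_count` is computed once from `str(metadata)` and stamped onto every chunk, which is what the query path adds to its token budget when metadata is included. Below is a toy sketch of the stride arithmetic with a fake token list (the real code uses the Llama Tokenizer); it lines up with the parametrized unit test below if the sample sentence encodes to roughly ten or eleven tokens:

```python
import math


def expected_chunk_count(num_tokens: int, window_len: int, overlap_len: int) -> int:
    # A chunk starts every (window_len - overlap_len) tokens, so the loop
    # `range(0, num_tokens, stride)` yields ceil(num_tokens / stride) chunks.
    stride = window_len - overlap_len
    return math.ceil(num_tokens / stride)


fake_tokens = list(range(11))  # stand-in for tokenizer.encode(text, bos=False, eos=False)
print(expected_chunk_count(len(fake_tokens), window_len=5, overlap_len=2))  # 4 (starts at 0, 3, 6, 9)
print(expected_chunk_count(len(fake_tokens), window_len=4, overlap_len=1))  # 4 (starts at 0, 3, 6, 9)
```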
@@ -49,7 +49,7 @@ def sample_documents():
     ]
 
 
-def assert_valid_response(response):
+def assert_valid_chunk_response(response):
     assert len(response.chunks) > 0
     assert len(response.scores) > 0
     assert len(response.chunks) == len(response.scores)

@@ -57,6 +57,11 @@ def assert_valid_response(response):
         assert isinstance(chunk.content, str)
 
 
+def assert_valid_text_response(response):
+    assert len(response.content) > 0
+    assert all(isinstance(chunk.text, str) for chunk in response.content)
+
+
 def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
     vector_db_id = "test_vector_db"
     client_with_empty_registry.vector_dbs.register(

@@ -77,7 +82,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
         vector_db_id=vector_db_id,
         query=query1,
     )
-    assert_valid_response(response1)
+    assert_valid_chunk_response(response1)
     assert any("Python" in chunk.content for chunk in response1.chunks)
 
     # Query with semantic similarity

@@ -86,7 +91,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
         vector_db_id=vector_db_id,
         query=query2,
     )
-    assert_valid_response(response2)
+    assert_valid_chunk_response(response2)
     assert any("neural networks" in chunk.content.lower() for chunk in response2.chunks)
 
     # Query with limit on number of results (max_chunks=2)

@@ -96,7 +101,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
         query=query3,
         params={"max_chunks": 2},
     )
-    assert_valid_response(response3)
+    assert_valid_chunk_response(response3)
    assert len(response3.chunks) <= 2
 
     # Query with threshold on similarity score

@@ -106,7 +111,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
         query=query4,
         params={"score_threshold": 0.01},
     )
-    assert_valid_response(response4)
+    assert_valid_chunk_response(response4)
     assert all(score >= 0.01 for score in response4.scores)
 
 

@@ -126,9 +131,6 @@ def test_vector_db_insert_from_url_and_query
     available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
     assert vector_db_id in available_vector_dbs
 
-    # URLs of documents to insert
-    # TODO: Move to test/memory/resources then update the url to
-    # https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/memory/resources/{url}
     urls = [
         "memory_optimizations.rst",
         "chat.rst",

@@ -139,7 +141,7 @@ def test_vector_db_insert_from_url_and_query
             document_id=f"num-{i}",
             content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
             mime_type="text/plain",
-            metadata={},
+            metadata={"author": "llama", "source": url},
         )
         for i, url in enumerate(urls)
     ]

@@ -155,7 +157,7 @@ def test_vector_db_insert_from_url_and_query
         vector_db_id=vector_db_id,
         query="What's the name of the fine-tunning method used?",
     )
-    assert_valid_response(response1)
+    assert_valid_chunk_response(response1)
     assert any("lora" in chunk.content.lower() for chunk in response1.chunks)
 
     # Query for the name of model

@@ -163,5 +165,52 @@ def test_vector_db_insert_from_url_and_query
         vector_db_id=vector_db_id,
         query="Which Llama model is mentioned?",
     )
-    assert_valid_response(response2)
+    assert_valid_chunk_response(response2)
     assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
+
+
+def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id):
+    providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
+    assert len(providers) > 0
+
+    vector_db_id = "test_vector_db"
+
+    client_with_empty_registry.vector_dbs.register(
+        vector_db_id=vector_db_id,
+        embedding_model=embedding_model_id,
+        embedding_dimension=384,
+    )
+
+    available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
+    assert vector_db_id in available_vector_dbs
+
+    urls = [
+        "memory_optimizations.rst",
+        "chat.rst",
+        "llama3.rst",
+    ]
+    documents = [
+        Document(
+            document_id=f"num-{i}",
+            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
+            mime_type="text/plain",
+            metadata={"author": "llama", "source": url},
+        )
+        for i, url in enumerate(urls)
+    ]
+
+    client_with_empty_registry.tool_runtime.rag_tool.insert(
+        documents=documents,
+        vector_db_id=vector_db_id,
+        chunk_size_in_tokens=512,
+    )
+
+    response = client_with_empty_registry.tool_runtime.rag_tool.query(
+        vector_db_ids=[vector_db_id],
+        content="What is the name of the method used for fine-tuning?",
+        query_config={
+            "include_metadata_in_content": True,
+        },
+    )
+    assert_valid_text_response(response)
+    assert any("metadata:" in chunk.text.lower() for chunk in response.content)
@@ -12,7 +12,7 @@ from pathlib import Path
 import pytest
 
 from llama_stack.apis.tools import RAGDocument
-from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc
+from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc, make_overlapped_chunks
 
 DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
 # Depending on the machine, this can get parsed a couple of ways

@@ -76,3 +76,34 @@ class TestVectorStore:
         )
         content = await content_from_doc(doc)
         assert content in DUMMY_PDF_TEXT_CHOICES
+
+    @pytest.mark.parametrize(
+        "window_len, overlap_len, expected_chunks",
+        [
+            (5, 2, 4),  # Create 4 chunks with window of 5 and overlap of 2
+            (4, 1, 4),  # Create 4 chunks with window of 4 and overlap of 1
+        ],
+    )
+    def test_make_overlapped_chunks(self, window_len, overlap_len, expected_chunks):
+        document_id = "test_doc_123"
+        text = "This is a sample document for testing the chunking behavior"
+        original_metadata = {"source": "test", "date": "2023-01-01", "author": "llama"}
+        len_metadata_tokens = 24  # specific to the metadata above
+
+        chunks = make_overlapped_chunks(document_id, text, window_len, overlap_len, original_metadata)
+
+        assert len(chunks) == expected_chunks
+
+        # Check that each chunk has the right metadata
+        for chunk in chunks:
+            # Original metadata should be preserved
+            assert chunk.metadata["source"] == "test"
+            assert chunk.metadata["date"] == "2023-01-01"
+            assert chunk.metadata["author"] == "llama"
+
+            # New metadata should be added
+            assert chunk.metadata["document_id"] == document_id
+            assert "token_count" in chunk.metadata
+            assert isinstance(chunk.metadata["token_count"], int)
+            assert chunk.metadata["token_count"] > 0
+            assert chunk.metadata["metadata_token_count"] == len_metadata_tokens