feat: Adding support for metadata in RAG insertion and querying

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-05-09 23:38:47 -04:00
parent 473a07f624
commit e50a546bc0
8 changed files with 149 additions and 25 deletions

View file

@ -49,7 +49,7 @@ def sample_documents():
]
def assert_valid_response(response):
def assert_valid_chunk_response(response):
assert len(response.chunks) > 0
assert len(response.scores) > 0
assert len(response.chunks) == len(response.scores)
@ -57,6 +57,11 @@ def assert_valid_response(response):
assert isinstance(chunk.content, str)
def assert_valid_text_response(response):
assert len(response.content) > 0
assert all(isinstance(chunk.text, str) for chunk in response.content)
def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
vector_db_id = "test_vector_db"
client_with_empty_registry.vector_dbs.register(
@ -77,7 +82,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do
vector_db_id=vector_db_id,
query=query1,
)
assert_valid_response(response1)
assert_valid_chunk_response(response1)
assert any("Python" in chunk.content for chunk in response1.chunks)
# Query with semantic similarity
@ -86,7 +91,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do
vector_db_id=vector_db_id,
query=query2,
)
assert_valid_response(response2)
assert_valid_chunk_response(response2)
assert any("neural networks" in chunk.content.lower() for chunk in response2.chunks)
# Query with limit on number of results (max_chunks=2)
@ -96,7 +101,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do
query=query3,
params={"max_chunks": 2},
)
assert_valid_response(response3)
assert_valid_chunk_response(response3)
assert len(response3.chunks) <= 2
# Query with threshold on similarity score
@ -106,7 +111,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do
query=query4,
params={"score_threshold": 0.01},
)
assert_valid_response(response4)
assert_valid_chunk_response(response4)
assert all(score >= 0.01 for score in response4.scores)
@ -126,9 +131,6 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert vector_db_id in available_vector_dbs
# URLs of documents to insert
# TODO: Move to test/memory/resources then update the url to
# https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/memory/resources/{url}
urls = [
"memory_optimizations.rst",
"chat.rst",
@ -139,7 +141,7 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={},
metadata={"author": "llama", "source": url},
)
for i, url in enumerate(urls)
]
@ -155,7 +157,7 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_
vector_db_id=vector_db_id,
query="What's the name of the fine-tunning method used?",
)
assert_valid_response(response1)
assert_valid_chunk_response(response1)
assert any("lora" in chunk.content.lower() for chunk in response1.chunks)
# Query for the name of model
@ -163,5 +165,52 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_
vector_db_id=vector_db_id,
query="Which Llama model is mentioned?",
)
assert_valid_response(response2)
assert_valid_chunk_response(response2)
assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id):
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
assert len(providers) > 0
vector_db_id = "test_vector_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=384,
)
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert vector_db_id in available_vector_dbs
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
]
documents = [
Document(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={"author": "llama", "source": url},
)
for i, url in enumerate(urls)
]
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
chunk_size_in_tokens=512,
)
response = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[vector_db_id],
content="What is the name of the method used for fine-tuning?",
query_config={
"include_metadata_in_content": True,
},
)
assert_valid_text_response(response)
assert any("metadata:" in chunk.text.lower() for chunk in response.content)