mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
feat: Updating Rag Tool to use Files API and Vector Stores API
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
47b640370e
commit
ab5ab6e979
6 changed files with 93 additions and 39 deletions
|
@ -17,10 +17,14 @@ def client_with_empty_registry(client_with_models):
|
|||
client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)
|
||||
|
||||
clear_registry()
|
||||
|
||||
try:
|
||||
client_with_models.toolgroups.register(toolgroup_id="builtin::rag", provider_id="rag-runtime")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
yield client_with_models
|
||||
|
||||
# you must clean after the last test if you were running tests against
|
||||
# a stateful server instance
|
||||
clear_registry()
|
||||
|
||||
|
||||
|
@ -66,12 +70,13 @@ def assert_valid_text_response(response):
|
|||
def test_vector_db_insert_inline_and_query(
|
||||
client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension
|
||||
):
|
||||
vector_db_id = "test_vector_db"
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_name = "test_vector_db"
|
||||
vector_db = client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_name,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
vector_db_id = vector_db.identifier
|
||||
|
||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||
documents=sample_documents,
|
||||
|
@ -134,7 +139,11 @@ def test_vector_db_insert_from_url_and_query(
|
|||
|
||||
# list to check memory bank is successfully registered
|
||||
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||
assert vector_db_id in available_vector_dbs
|
||||
# VectorDB is being migrated to VectorStore, so the ID will be different
|
||||
# Just check that at least one vector DB was registered
|
||||
assert len(available_vector_dbs) > 0
|
||||
# Use the actual registered vector_db_id for subsequent operations
|
||||
actual_vector_db_id = available_vector_dbs[0]
|
||||
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
|
@ -153,13 +162,13 @@ def test_vector_db_insert_from_url_and_query(
|
|||
|
||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||
documents=documents,
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_id=actual_vector_db_id,
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
|
||||
# Query for the name of method
|
||||
response1 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_id=actual_vector_db_id,
|
||||
query="What's the name of the fine-tunning method used?",
|
||||
)
|
||||
assert_valid_chunk_response(response1)
|
||||
|
@ -167,7 +176,7 @@ def test_vector_db_insert_from_url_and_query(
|
|||
|
||||
# Query for the name of model
|
||||
response2 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_id=actual_vector_db_id,
|
||||
query="Which Llama model is mentioned?",
|
||||
)
|
||||
assert_valid_chunk_response(response2)
|
||||
|
@ -187,7 +196,11 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
|
|||
)
|
||||
|
||||
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||
assert vector_db_id in available_vector_dbs
|
||||
# VectorDB is being migrated to VectorStore, so the ID will be different
|
||||
# Just check that at least one vector DB was registered
|
||||
assert len(available_vector_dbs) > 0
|
||||
# Use the actual registered vector_db_id for subsequent operations
|
||||
actual_vector_db_id = available_vector_dbs[0]
|
||||
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
|
@ -206,19 +219,19 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
|
|||
|
||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||
documents=documents,
|
||||
vector_db_id=vector_db_id,
|
||||
vector_db_id=actual_vector_db_id,
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
|
||||
response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||
vector_db_ids=[vector_db_id],
|
||||
vector_db_ids=[actual_vector_db_id],
|
||||
content="What is the name of the method used for fine-tuning?",
|
||||
)
|
||||
assert_valid_text_response(response_with_metadata)
|
||||
assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)
|
||||
|
||||
response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||
vector_db_ids=[vector_db_id],
|
||||
vector_db_ids=[actual_vector_db_id],
|
||||
content="What is the name of the method used for fine-tuning?",
|
||||
query_config={
|
||||
"include_metadata_in_content": True,
|
||||
|
@ -230,7 +243,7 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
|
|||
|
||||
with pytest.raises((ValueError, BadRequestError)):
|
||||
client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||
vector_db_ids=[vector_db_id],
|
||||
vector_db_ids=[actual_vector_db_id],
|
||||
content="What is the name of the method used for fine-tuning?",
|
||||
query_config={
|
||||
"chunk_template": "This should raise a ValueError because it is missing the proper template variables",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue