[memory refactor][5/n] Migrate all vector_io providers (#835)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

This PR finishes off all the stragglers and migrates everything to the
new naming.
This commit is contained in:
Ashwin Bharambe 2025-01-22 10:17:59 -08:00 committed by GitHub
parent 63f37f9b7c
commit c9e5578151
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
78 changed files with 504 additions and 623 deletions

View file

@ -53,7 +53,7 @@ async def eval_stack(
"inference",
"agents",
"safety",
"memory",
"vector_io",
"tool_runtime",
]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
@ -69,7 +69,7 @@ async def eval_stack(
Api.scoring,
Api.agents,
Api.safety,
Api.memory,
Api.vector_io,
Api.tool_runtime,
],
providers,

View file

@ -83,7 +83,7 @@ async def tools_stack(
providers = {}
provider_data = {}
for key in ["inference", "memory", "tool_runtime"]:
for key in ["inference", "vector_io", "tool_runtime"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if key == "inference":
@ -117,7 +117,12 @@ async def tools_stack(
)
test_stack = await construct_stack_for_test(
[Api.tool_groups, Api.inference, Api.memory, Api.tool_runtime],
[
Api.tool_groups,
Api.inference,
Api.vector_io,
Api.tool_runtime,
],
providers,
provider_data,
models=models,

View file

@ -8,10 +8,7 @@ import os
import pytest
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.memory import MemoryBankDocument
from llama_stack.apis.memory_banks import VectorMemoryBankParams
from llama_stack.apis.tools import ToolInvocationResult
from llama_stack.apis.tools import RAGDocument, RAGQueryResult, ToolInvocationResult
from llama_stack.providers.datatypes import Api
@ -36,7 +33,7 @@ def sample_documents():
"lora_finetune.rst",
]
return [
MemoryBankDocument(
RAGDocument(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
@ -57,7 +54,7 @@ class TestTools:
# Execute the tool
response = await tools_impl.invoke_tool(
tool_name="web_search", args={"query": sample_search_query}
tool_name="web_search", kwargs={"query": sample_search_query}
)
# Verify the response
@ -75,7 +72,7 @@ class TestTools:
tools_impl = tools_stack.impls[Api.tool_runtime]
response = await tools_impl.invoke_tool(
tool_name="wolfram_alpha", args={"query": sample_wolfram_alpha_query}
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
)
# Verify the response
@ -85,43 +82,33 @@ class TestTools:
assert isinstance(response.content, str)
@pytest.mark.asyncio
async def test_memory_tool(self, tools_stack, sample_documents):
async def test_rag_tool(self, tools_stack, sample_documents):
"""Test the memory tool functionality."""
memory_banks_impl = tools_stack.impls[Api.memory_banks]
memory_impl = tools_stack.impls[Api.memory]
vector_dbs_impl = tools_stack.impls[Api.vector_dbs]
tools_impl = tools_stack.impls[Api.tool_runtime]
# Register memory bank
await memory_banks_impl.register_memory_bank(
memory_bank_id="test_bank",
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
await vector_dbs_impl.register(
vector_db_id="test_bank",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_id="faiss",
)
# Insert documents into memory
await memory_impl.insert_documents(
bank_id="test_bank",
await tools_impl.rag_tool.insert_documents(
documents=sample_documents,
vector_db_id="test_bank",
chunk_size_in_tokens=512,
)
# Execute the memory tool
response = await tools_impl.invoke_tool(
tool_name="memory",
args={
"messages": [
UserMessage(
content="What are the main topics covered in the documentation?",
)
],
"memory_bank_ids": ["test_bank"],
},
response = await tools_impl.rag_tool.query_context(
content="What are the main topics covered in the documentation?",
vector_db_ids=["test_bank"],
)
# Verify the response
assert isinstance(response, ToolInvocationResult)
assert isinstance(response, RAGQueryResult)
assert response.content is not None
assert len(response.content) > 0