[memory refactor][5/n] Migrate all vector_io providers (#835)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

This PR finishes off all the stragglers and migrates everything to the
new naming.
This commit is contained in:
Ashwin Bharambe 2025-01-22 10:17:59 -08:00 committed by GitHub
parent 63f37f9b7c
commit c9e5578151
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
78 changed files with 504 additions and 623 deletions

View file

@ -8,10 +8,7 @@ import os
import pytest
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.memory import MemoryBankDocument
from llama_stack.apis.memory_banks import VectorMemoryBankParams
from llama_stack.apis.tools import ToolInvocationResult
from llama_stack.apis.tools import RAGDocument, RAGQueryResult, ToolInvocationResult
from llama_stack.providers.datatypes import Api
@ -36,7 +33,7 @@ def sample_documents():
"lora_finetune.rst",
]
return [
MemoryBankDocument(
RAGDocument(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
@ -57,7 +54,7 @@ class TestTools:
# Execute the tool
response = await tools_impl.invoke_tool(
tool_name="web_search", args={"query": sample_search_query}
tool_name="web_search", kwargs={"query": sample_search_query}
)
# Verify the response
@ -75,7 +72,7 @@ class TestTools:
tools_impl = tools_stack.impls[Api.tool_runtime]
response = await tools_impl.invoke_tool(
tool_name="wolfram_alpha", args={"query": sample_wolfram_alpha_query}
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
)
# Verify the response
@ -85,43 +82,33 @@ class TestTools:
assert isinstance(response.content, str)
@pytest.mark.asyncio
async def test_memory_tool(self, tools_stack, sample_documents):
async def test_rag_tool(self, tools_stack, sample_documents):
"""Test the memory tool functionality."""
memory_banks_impl = tools_stack.impls[Api.memory_banks]
memory_impl = tools_stack.impls[Api.memory]
vector_dbs_impl = tools_stack.impls[Api.vector_dbs]
tools_impl = tools_stack.impls[Api.tool_runtime]
# Register the vector DB (formerly a memory bank)
await memory_banks_impl.register_memory_bank(
memory_bank_id="test_bank",
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
await vector_dbs_impl.register(
vector_db_id="test_bank",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_id="faiss",
)
# Insert documents into the vector DB
await memory_impl.insert_documents(
bank_id="test_bank",
await tools_impl.rag_tool.insert_documents(
documents=sample_documents,
vector_db_id="test_bank",
chunk_size_in_tokens=512,
)
# Query the RAG tool for context (formerly the memory tool)
response = await tools_impl.invoke_tool(
tool_name="memory",
args={
"messages": [
UserMessage(
content="What are the main topics covered in the documentation?",
)
],
"memory_bank_ids": ["test_bank"],
},
response = await tools_impl.rag_tool.query_context(
content="What are the main topics covered in the documentation?",
vector_db_ids=["test_bank"],
)
# Verify the response
assert isinstance(response, ToolInvocationResult)
assert isinstance(response, RAGQueryResult)
assert response.content is not None
assert len(response.content) > 0