mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
[memory refactor][5/n] Migrate all vector_io providers (#835)
See https://github.com/meta-llama/llama-stack/issues/827 for the broader design. This PR finishes off all the stragglers and migrates everything to the new naming.
This commit is contained in:
parent
63f37f9b7c
commit
c9e5578151
78 changed files with 504 additions and 623 deletions
|
@ -8,10 +8,7 @@ import os
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import UserMessage
|
||||
from llama_stack.apis.memory import MemoryBankDocument
|
||||
from llama_stack.apis.memory_banks import VectorMemoryBankParams
|
||||
from llama_stack.apis.tools import ToolInvocationResult
|
||||
from llama_stack.apis.tools import RAGDocument, RAGQueryResult, ToolInvocationResult
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
||||
|
@ -36,7 +33,7 @@ def sample_documents():
|
|||
"lora_finetune.rst",
|
||||
]
|
||||
return [
|
||||
MemoryBankDocument(
|
||||
RAGDocument(
|
||||
document_id=f"num-{i}",
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
|
@ -57,7 +54,7 @@ class TestTools:
|
|||
|
||||
# Execute the tool
|
||||
response = await tools_impl.invoke_tool(
|
||||
tool_name="web_search", args={"query": sample_search_query}
|
||||
tool_name="web_search", kwargs={"query": sample_search_query}
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
|
@ -75,7 +72,7 @@ class TestTools:
|
|||
tools_impl = tools_stack.impls[Api.tool_runtime]
|
||||
|
||||
response = await tools_impl.invoke_tool(
|
||||
tool_name="wolfram_alpha", args={"query": sample_wolfram_alpha_query}
|
||||
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
|
@ -85,43 +82,33 @@ class TestTools:
|
|||
assert isinstance(response.content, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_tool(self, tools_stack, sample_documents):
|
||||
async def test_rag_tool(self, tools_stack, sample_documents):
|
||||
"""Test the memory tool functionality."""
|
||||
memory_banks_impl = tools_stack.impls[Api.memory_banks]
|
||||
memory_impl = tools_stack.impls[Api.memory]
|
||||
vector_dbs_impl = tools_stack.impls[Api.vector_dbs]
|
||||
tools_impl = tools_stack.impls[Api.tool_runtime]
|
||||
|
||||
# Register memory bank
|
||||
await memory_banks_impl.register_memory_bank(
|
||||
memory_bank_id="test_bank",
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
await vector_dbs_impl.register(
|
||||
vector_db_id="test_bank",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_id="faiss",
|
||||
)
|
||||
|
||||
# Insert documents into memory
|
||||
await memory_impl.insert_documents(
|
||||
bank_id="test_bank",
|
||||
await tools_impl.rag_tool.insert_documents(
|
||||
documents=sample_documents,
|
||||
vector_db_id="test_bank",
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
|
||||
# Execute the memory tool
|
||||
response = await tools_impl.invoke_tool(
|
||||
tool_name="memory",
|
||||
args={
|
||||
"messages": [
|
||||
UserMessage(
|
||||
content="What are the main topics covered in the documentation?",
|
||||
)
|
||||
],
|
||||
"memory_bank_ids": ["test_bank"],
|
||||
},
|
||||
response = await tools_impl.rag_tool.query_context(
|
||||
content="What are the main topics covered in the documentation?",
|
||||
vector_db_ids=["test_bank"],
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, ToolInvocationResult)
|
||||
assert isinstance(response, RAGQueryResult)
|
||||
assert response.content is not None
|
||||
assert len(response.content) > 0
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue