mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-16 01:53:10 +00:00
[memory refactor][3/n] Introduce RAGToolRuntime as a specialized sub-protocol (#832)
See https://github.com/meta-llama/llama-stack/issues/827 for the broader design. Third part: - we need to make `tool_runtime.rag_tool.query_context()` and `tool_runtime.rag_tool.insert_documents()` methods work smoothly with complete type safety. To that end, we introduce a sub-resource path `tool-runtime/rag-tool/` and make changes to the resolver to make things work. - the PR updates the agents implementation to directly call these typed APIs for memory accesses rather than going through the complex, untyped "invoke_tool" API. the code looks much nicer and simpler (expectedly.) - there are a number of hacks in the server resolver implementation still, we will live with some and fix some Note that we must make sure the client SDKs are able to handle this subresource complexity also. Stainless has support for subresources, so this should be possible but beware. ## Test Plan Our RAG test is sad (doesn't actually test for actual RAG output) but I verified that the implementation works. I will work on fixing the RAG test afterwards. ```bash pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B ```
This commit is contained in:
parent
78a481bb22
commit
1a7490470a
33 changed files with 1648 additions and 1345 deletions
|
@ -12,10 +12,10 @@ from ..conftest import (
|
|||
get_test_config_for_api,
|
||||
)
|
||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||
from ..memory.fixtures import MEMORY_FIXTURES
|
||||
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
|
||||
|
||||
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
|
||||
from ..vector_io.fixtures import VECTOR_IO_FIXTURES
|
||||
from .fixtures import AGENTS_FIXTURES
|
||||
|
||||
DEFAULT_PROVIDER_COMBINATIONS = [
|
||||
|
@ -23,7 +23,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
{
|
||||
"inference": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
"memory": "faiss",
|
||||
"vector_io": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory_and_search",
|
||||
},
|
||||
|
@ -34,7 +34,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
{
|
||||
"inference": "ollama",
|
||||
"safety": "llama_guard",
|
||||
"memory": "faiss",
|
||||
"vector_io": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory_and_search",
|
||||
},
|
||||
|
@ -46,7 +46,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
"inference": "together",
|
||||
"safety": "llama_guard",
|
||||
# make this work with Weaviate which is what the together distro supports
|
||||
"memory": "faiss",
|
||||
"vector_io": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory_and_search",
|
||||
},
|
||||
|
@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
{
|
||||
"inference": "fireworks",
|
||||
"safety": "llama_guard",
|
||||
"memory": "faiss",
|
||||
"vector_io": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory_and_search",
|
||||
},
|
||||
|
@ -68,7 +68,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
{
|
||||
"inference": "remote",
|
||||
"safety": "remote",
|
||||
"memory": "remote",
|
||||
"vector_io": "remote",
|
||||
"agents": "remote",
|
||||
"tool_runtime": "memory_and_search",
|
||||
},
|
||||
|
@ -115,7 +115,7 @@ def pytest_generate_tests(metafunc):
|
|||
available_fixtures = {
|
||||
"inference": INFERENCE_FIXTURES,
|
||||
"safety": SAFETY_FIXTURES,
|
||||
"memory": MEMORY_FIXTURES,
|
||||
"vector_io": VECTOR_IO_FIXTURES,
|
||||
"agents": AGENTS_FIXTURES,
|
||||
"tool_runtime": TOOL_RUNTIME_FIXTURES,
|
||||
}
|
||||
|
|
|
@ -69,7 +69,7 @@ async def agents_stack(
|
|||
|
||||
providers = {}
|
||||
provider_data = {}
|
||||
for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
|
||||
for key in ["inference", "safety", "vector_io", "agents", "tool_runtime"]:
|
||||
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||
providers[key] = fixture.providers
|
||||
if key == "inference":
|
||||
|
@ -118,7 +118,7 @@ async def agents_stack(
|
|||
)
|
||||
|
||||
test_stack = await construct_stack_for_test(
|
||||
[Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
|
||||
[Api.agents, Api.inference, Api.safety, Api.vector_io, Api.tool_runtime],
|
||||
providers,
|
||||
provider_data,
|
||||
models=models,
|
||||
|
|
|
@ -214,9 +214,11 @@ class TestAgents:
|
|||
turn_response = [
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
|
||||
# FIXME: we need to check the content of the turn response and ensure
|
||||
# RAG actually worked
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn_with_tavily_search(
|
||||
self, agents_stack, search_query_messages, common_params
|
||||
|
|
|
@ -8,13 +8,12 @@ import uuid
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.tools import RAGDocument
|
||||
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
make_overlapped_chunks,
|
||||
MemoryBankDocument,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
|
@ -26,22 +25,22 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
@pytest.fixture(scope="session")
|
||||
def sample_chunks():
|
||||
docs = [
|
||||
MemoryBankDocument(
|
||||
RAGDocument(
|
||||
document_id="doc1",
|
||||
content="Python is a high-level programming language.",
|
||||
metadata={"category": "programming", "difficulty": "beginner"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
RAGDocument(
|
||||
document_id="doc2",
|
||||
content="Machine learning is a subset of artificial intelligence.",
|
||||
metadata={"category": "AI", "difficulty": "advanced"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
RAGDocument(
|
||||
document_id="doc3",
|
||||
content="Data structures are fundamental to computer science.",
|
||||
metadata={"category": "computer science", "difficulty": "intermediate"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
RAGDocument(
|
||||
document_id="doc4",
|
||||
content="Neural networks are inspired by biological neural networks.",
|
||||
metadata={"category": "AI", "difficulty": "advanced"},
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue