[memory refactor][6/n] Update naming and routes (#839)

Making a few small naming changes as per feedback:

- RAGToolRuntime methods are called `insert` and `query` to keep them
more general
- The tool names are changed to non-namespaced forms
`insert_into_memory` and `query_from_memory`
- The REST endpoints are more REST-ful
This commit is contained in:
Ashwin Bharambe 2025-01-22 10:39:13 -08:00 committed by GitHub
parent c9e5578151
commit a63a43c646
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 319 additions and 330 deletions

View file

@ -74,8 +74,8 @@ class RAGQueryConfig(BaseModel):
@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST")
async def insert_documents(
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
async def insert(
self,
documents: List[RAGDocument],
vector_db_id: str,
@ -84,12 +84,12 @@ class RAGToolRuntime(Protocol):
"""Index documents so they can be used by the RAG system"""
...
@webmethod(route="/tool-runtime/rag-tool/query-context", method="POST")
async def query_context(
@webmethod(route="/tool-runtime/rag-tool/query", method="POST")
async def query(
self,
content: InterleavedContent,
query_config: RAGQueryConfig,
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
"""Query the RAG system for context; typically invoked by the agent"""
...

View file

@ -38,7 +38,7 @@ class VectorDBStore(Protocol):
class VectorIO(Protocol):
vector_db_store: VectorDBStore
# this will just block now until documents are inserted, but it should
# this will just block now until chunks are inserted, but it should
# probably return a Job instance which can be polled for completion
@webmethod(route="/vector-io/insert", method="POST")
async def insert_chunks(

View file

@ -414,25 +414,25 @@ class ToolRuntimeRouter(ToolRuntime):
) -> None:
self.routing_table = routing_table
async def query_context(
async def query(
self,
content: InterleavedContent,
query_config: RAGQueryConfig,
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
return await self.routing_table.get_provider_impl(
"rag_tool.query_context"
).query_context(content, query_config, vector_db_ids)
"query_from_memory"
).query(content, vector_db_ids, query_config)
async def insert_documents(
async def insert(
self,
documents: List[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
return await self.routing_table.get_provider_impl(
"rag_tool.insert_documents"
).insert_documents(documents, vector_db_id, chunk_size_in_tokens)
"insert_into_memory"
).insert(documents, vector_db_id, chunk_size_in_tokens)
def __init__(
self,
@ -441,10 +441,9 @@ class ToolRuntimeRouter(ToolRuntime):
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
# TODO: make sure rag_tool vs builtin::memory is correct everywhere
self.rag_tool = self.RagToolImpl(routing_table)
setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)
for method in ("query", "insert"):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
pass

View file

@ -84,7 +84,7 @@ def make_random_string(length: int = 8):
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_QUERY_TOOL = "rag_tool.query_context"
MEMORY_QUERY_TOOL = "query_from_memory"
WEB_SEARCH_TOOL = "web_search"
MEMORY_GROUP = "builtin::memory"
@ -432,16 +432,16 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
result = await self.tool_runtime_api.rag_tool.query_context(
result = await self.tool_runtime_api.rag_tool.query(
content=concat_interleaved_content(
[msg.content for msg in input_messages]
),
vector_db_ids=vector_db_ids,
query_config=RAGQueryConfig(
query_generator_config=DefaultRAGQueryGeneratorConfig(),
max_tokens_in_context=4096,
max_chunks=5,
),
vector_db_ids=vector_db_ids,
)
retrieved_context = result.content
@ -882,7 +882,7 @@ class ChatAgent(ShieldRunnerMixin):
)
for a in data
]
await self.tool_runtime_api.rag_tool.insert_documents(
await self.tool_runtime_api.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
chunk_size_in_tokens=512,

View file

@ -61,7 +61,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def shutdown(self):
pass
async def insert_documents(
async def insert(
self,
documents: List[RAGDocument],
vector_db_id: str,
@ -87,15 +87,16 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
vector_db_id=vector_db_id,
)
async def query_context(
async def query(
self,
content: InterleavedContent,
query_config: RAGQueryConfig,
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
if not vector_db_ids:
return RAGQueryResult(content=None)
query_config = query_config or RAGQueryConfig()
query = await generate_rag_query(
query_config.query_generator_config,
content,
@ -159,11 +160,11 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
# encountering fatals.
return [
ToolDef(
name="rag_tool.query_context",
name="query_from_memory",
description="Retrieve context from memory",
),
ToolDef(
name="rag_tool.insert_documents",
name="insert_into_memory",
description="Insert documents into memory",
),
]

View file

@ -96,14 +96,14 @@ class TestTools:
)
# Insert documents into memory
await tools_impl.rag_tool.insert_documents(
await tools_impl.rag_tool.insert(
documents=sample_documents,
vector_db_id="test_bank",
chunk_size_in_tokens=512,
)
# Execute the memory tool
response = await tools_impl.rag_tool.query_context(
response = await tools_impl.rag_tool.query(
content="What are the main topics covered in the documentation?",
vector_db_ids=["test_bank"],
)

View file

@ -11,11 +11,9 @@ from pathlib import Path
import pytest
from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
MemoryBankDocument,
URL,
)
from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.utils.memory.vector_store import content_from_doc, URL
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
@ -41,33 +39,33 @@ class TestVectorStore:
@pytest.mark.asyncio
async def test_returns_content_from_pdf_data_uri(self):
data_uri = data_url_from_file(DUMMY_PDF_PATH)
doc = MemoryBankDocument(
doc = RAGDocument(
document_id="dummy",
content=data_uri,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content == "Dummy PDF file"
assert content == "Dumm y PDF file"
@pytest.mark.asyncio
async def test_downloads_pdf_and_returns_content(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = MemoryBankDocument(
doc = RAGDocument(
document_id="dummy",
content=url,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content == "Dummy PDF file"
assert content == "Dumm y PDF file"
@pytest.mark.asyncio
async def test_downloads_pdf_and_returns_content_with_url_object(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = MemoryBankDocument(
doc = RAGDocument(
document_id="dummy",
content=URL(
uri=url,
@ -76,4 +74,4 @@ class TestVectorStore:
metadata={},
)
content = await content_from_doc(doc)
assert content == "Dummy PDF file"
assert content == "Dumm y PDF file"