[memory refactor][3/n] Introduce RAGToolRuntime as a specialized sub-protocol (#832)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

Third part:
- We need the `tool_runtime.rag_tool.query_context()` and
`tool_runtime.rag_tool.insert_documents()` methods to work smoothly with
complete type safety. To that end, we introduce a sub-resource path
`tool-runtime/rag-tool/` and update the resolver so routing works (a
sketch of the sub-protocol and the agent-side call follows this list).
- The PR updates the agents implementation to call these typed APIs
directly for memory accesses rather than going through the complex,
untyped `invoke_tool` API. The code looks much nicer and simpler
(expectedly).
- A number of hacks remain in the server resolver implementation; we
will live with some and fix others.
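
For concreteness, here is a minimal sketch of what the specialized sub-protocol might look like. The method names (`query_context`, `insert_documents`) and the `RAGDocument` field names come from this PR; the parameter lists and return types are illustrative assumptions, not the exact definitions:

```python
# A minimal sketch of the RAGToolRuntime sub-protocol. Method names come from
# the PR description; parameter and return types are illustrative assumptions.
from typing import Any, Dict, List, Optional, Protocol

from pydantic import BaseModel, Field


class RAGDocument(BaseModel):
    # Field names mirror the renamed MemoryBankDocument in the diff below;
    # `content` is simplified to str here (the real type is
    # InterleavedContent | URL).
    document_id: str
    content: str
    mime_type: Optional[str] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)


class RAGQueryResult(BaseModel):
    content: Optional[str] = None  # simplified


class RAGToolRuntime(Protocol):
    """Typed sub-protocol mounted at the sub-resource path tool-runtime/rag-tool/."""

    async def insert_documents(
        self,
        documents: List[RAGDocument],
        vector_db_id: str,
        chunk_size_in_tokens: int,
    ) -> None: ...

    async def query_context(
        self,
        content: str,
        vector_db_ids: List[str],
    ) -> RAGQueryResult: ...
```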
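The agent-side memory path then reduces to a direct typed call. Again a sketch: the attribute path `tool_runtime_api.rag_tool` and the argument names are assumptions based on the description above:

```python
# Hypothetical agent-side call; attribute path and argument names are
# assumptions. Previously this went through the generic, untyped dispatch,
# roughly:
#
#     result = await self.tool_runtime_api.invoke_tool(
#         tool_name="query_memory",
#         kwargs={"messages": [...], "memory_bank_ids": [...]},
#     )
#
# With the typed sub-protocol it becomes a plain method call:
result = await self.tool_runtime_api.rag_tool.query_context(
    content=user_query,
    vector_db_ids=[vector_db_id],
)
```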

Note that we must also make sure the client SDKs can handle this
sub-resource complexity. Stainless has support for sub-resources, so
this should be possible, but beware.
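
From the client side, the sub-resource should surface as nested attribute access, roughly like the following (hypothetical usage; the actual Stainless-generated surface may differ):

```python
# Hypothetical client-side usage; the generated SDK surface may differ.
from llama_stack_client import LlamaStackClient  # assumed client package

client = LlamaStackClient(base_url="http://localhost:5000")
result = client.tool_runtime.rag_tool.query_context(
    content="What does the design doc say about sub-resources?",
    vector_db_ids=["my-docs"],
)
print(result.content)
```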

## Test Plan

Our RAG test is sad (it doesn't actually check the RAG output), but I
verified that the implementation works. I will fix the RAG test
afterwards.

```bash
pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B
```
The commit touches 33 files, with 1648 additions and 1345 deletions. Representative hunks:


```diff
@@ -19,7 +19,6 @@ import numpy as np
 from llama_models.llama3.api.tokenizer import Tokenizer
 from numpy.typing import NDArray
-from pydantic import BaseModel, Field
 from pypdf import PdfReader
 
 from llama_stack.apis.common.content_types import (
@@ -27,6 +26,7 @@ from llama_stack.apis.common.content_types import (
     TextContentItem,
     URL,
 )
+from llama_stack.apis.tools import RAGDocument
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
 from llama_stack.providers.datatypes import Api
@@ -34,17 +34,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 
 log = logging.getLogger(__name__)
 
 
-class MemoryBankDocument(BaseModel):
-    document_id: str
-    content: InterleavedContent | URL
-    mime_type: str | None = None
-    metadata: Dict[str, Any] = Field(default_factory=dict)
-
-
 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string
     pdf_bytes = io.BytesIO(data)
@@ -122,7 +114,7 @@ def concat_interleaved_content(content: List[InterleavedContent]) -> Interleaved
     return ret
 
 
-async def content_from_doc(doc: MemoryBankDocument) -> str:
+async def content_from_doc(doc: RAGDocument) -> str:
     if isinstance(doc.content, URL):
         if doc.content.uri.startswith("data:"):
             return content_from_data(doc.content.uri)
@@ -161,7 +153,13 @@ def make_overlapped_chunks(
         chunk = tokenizer.decode(toks)
         # chunk is a string
         chunks.append(
-            Chunk(content=chunk, token_count=len(toks), document_id=document_id)
+            Chunk(
+                content=chunk,
+                metadata={
+                    "token_count": len(toks),
+                    "document_id": document_id,
+                },
+            )
         )
 
     return chunks
```
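
One behavioral note on that last hunk: `token_count` and `document_id` move from top-level `Chunk` constructor arguments into the `metadata` dict, so downstream readers change accordingly. A hypothetical before/after, assuming attribute-style access on `Chunk`:

```python
# Hypothetical consumer before/after; exact accessors are assumptions.
# Before: the values were top-level Chunk fields.
# doc_id = chunk.document_id
# token_count = chunk.token_count

# After: they live in the chunk's metadata dict.
doc_id = chunk.metadata["document_id"]
token_count = chunk.metadata["token_count"]
```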