feat: Updating Rag Tool to use Files API and Vector Stores API (#3344)

Francisco Arceo, 2025-09-06 07:26:34 -06:00, committed by GitHub
parent 47b640370e
commit 7cd1c2c238
6 changed files with 93 additions and 39 deletions
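In short: instead of chunking each document locally and writing the chunks with vector_io.insert_chunks, the RAG tool now converts each document to raw bytes, uploads it through the Files API (openai_upload_file), and attaches the uploaded file to the vector store (openai_attach_file_to_vector_store) with a static chunking strategy, leaving chunk construction to the Vector Stores implementation. To support this, the provider gains a new files_api dependency.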

llama_stack/providers/inline/tool_runtime/rag/__init__.py
@@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig
 async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
     from .memory import MemoryToolRuntimeImpl

-    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
+    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
     await impl.initialize()
     return impl
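For deps[Api.files] to be populated here, the provider's registry entry must also declare the files API as a dependency. A hypothetical sketch of that entry, assuming it follows the existing inline-provider pattern (the registry file and its field values are not shown in this excerpt):

    # Hypothetical sketch, not taken from this commit's visible diff:
    from llama_stack.providers.datatypes import Api, InlineProviderSpec

    rag_provider = InlineProviderSpec(
        api=Api.tool_runtime,
        provider_type="inline::rag-runtime",  # assumed provider id
        module="llama_stack.providers.inline.tool_runtime.rag",
        config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
        api_dependencies=[Api.vector_io, Api.inference, Api.files],  # Api.files is the addition
    )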

llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -5,10 +5,15 @@
 # the root directory of this source tree.

 import asyncio
+import base64
+import io
+import mimetypes
 import secrets
 import string
 from typing import Any

+import httpx
+from fastapi import UploadFile
 from pydantic import TypeAdapter

 from llama_stack.apis.common.content_types import (
@@ -17,6 +22,7 @@ from llama_stack.apis.common.content_types import (
     InterleavedContentItem,
     TextContentItem,
 )
+from llama_stack.apis.files import Files, OpenAIFilePurpose
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
@@ -30,13 +36,18 @@ from llama_stack.apis.tools import (
     ToolParameter,
     ToolRuntime,
 )
-from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
+from llama_stack.apis.vector_io import (
+    QueryChunksResponse,
+    VectorIO,
+    VectorStoreChunkingStrategyStatic,
+    VectorStoreChunkingStrategyStaticConfig,
+)
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 from llama_stack.providers.utils.memory.vector_store import (
     content_from_doc,
-    make_overlapped_chunks,
+    parse_data_url,
 )

 from .config import RagToolRuntimeConfig
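Note that make_overlapped_chunks drops out of this import block: chunk construction moves out of the tool runtime and into the vector store. parse_data_url is used below to decode data: URIs; judging only from that usage, it returns a dict with at least mimetype, is_base64, and data keys:

    # Shape inferred from the usage in the insert hunk below, not from the helper's source:
    parts = parse_data_url("data:text/plain;base64,aGVsbG8=")
    # parts == {"mimetype": "text/plain", "is_base64": True, "data": "aGVsbG8=", ...}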
@@ -55,10 +66,12 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         config: RagToolRuntimeConfig,
         vector_io_api: VectorIO,
         inference_api: Inference,
+        files_api: Files,
     ):
         self.config = config
         self.vector_io_api = vector_io_api
         self.inference_api = inference_api
+        self.files_api = files_api

     async def initialize(self):
         pass
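A minimal sketch of wiring the new dependency, e.g. in a unit test; the mock_* objects are illustrative stand-ins, not part of this commit:

    impl = MemoryToolRuntimeImpl(
        config=RagToolRuntimeConfig(),
        vector_io_api=mock_vector_io,   # anything satisfying the VectorIO protocol
        inference_api=mock_inference,   # anything satisfying the Inference protocol
        files_api=mock_files,           # anything satisfying the Files protocol (new)
    )
    await impl.initialize()            # currently a no-op, per the diff above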
@@ -78,27 +91,50 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
-        chunks = []
+        if not documents:
+            return
+
         for doc in documents:
-            content = await content_from_doc(doc)
-            # TODO: we should add enrichment here as URLs won't be added to the metadata by default
-            chunks.extend(
-                make_overlapped_chunks(
-                    doc.document_id,
-                    content,
-                    chunk_size_in_tokens,
-                    chunk_size_in_tokens // 4,
-                    doc.metadata,
-                )
-            )
+            if isinstance(doc.content, URL):
+                if doc.content.uri.startswith("data:"):
+                    parts = parse_data_url(doc.content.uri)
+                    file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
+                    mime_type = parts["mimetype"]
+                else:
+                    async with httpx.AsyncClient() as client:
+                        response = await client.get(doc.content.uri)
+                        file_data = response.content
+                        mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
+            else:
+                content_str = await content_from_doc(doc)
+                file_data = content_str.encode("utf-8")
+                mime_type = doc.mime_type or "text/plain"
+
+            file_extension = mimetypes.guess_extension(mime_type) or ".txt"
+            filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
+
+            file_obj = io.BytesIO(file_data)
+            file_obj.name = filename
+            upload_file = UploadFile(file=file_obj, filename=filename)
+
+            created_file = await self.files_api.openai_upload_file(
+                file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
+            )
+
+            chunking_strategy = VectorStoreChunkingStrategyStatic(
+                static=VectorStoreChunkingStrategyStaticConfig(
+                    max_chunk_size_tokens=chunk_size_in_tokens,
+                    chunk_overlap_tokens=chunk_size_in_tokens // 4,
+                )
+            )
-
-        if not chunks:
-            return
-
-        await self.vector_io_api.insert_chunks(
-            chunks=chunks,
-            vector_db_id=vector_db_id,
-        )
+
+            await self.vector_io_api.openai_attach_file_to_vector_store(
+                vector_store_id=vector_db_id,
+                file_id=created_file.id,
+                attributes=doc.metadata,
+                chunking_strategy=chunking_strategy,
+            )

     async def query(
         self,
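Taken together, a caller-side sketch of the new ingestion path. The method name (insert, from the RAGToolRuntime protocol) and the RAGDocument shape are assumptions based on identifiers visible in this diff, not new code in this commit:

    from llama_stack.apis.common.content_types import URL
    from llama_stack.apis.tools import RAGDocument  # assumed import location

    doc = RAGDocument(
        document_id="doc-1",
        content=URL(uri="data:text/plain;base64,aGVsbG8gd29ybGQ="),  # decodes to "hello world"
        mime_type="text/plain",
        metadata={"filename": "hello.txt"},
    )
    # Uploads hello.txt via the Files API, then attaches the file to the vector
    # store, which chunks it (512-token chunks, 128-token overlap) and indexes it.
    await impl.insert(documents=[doc], vector_db_id="my-vector-store", chunk_size_in_tokens=512)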