chore: update documentation, add exception handling for Vector Stores in RAG Tool, add migration tests, and migrate context_retriever for RAG off of inference_api (#3367)

# What does this PR do?

- Update documentation on migrating from the RAG Tool to the Vector Stores
and Files APIs (see the ingestion sketch below)
- Add exception handling for Vector Stores in the RAG Tool
- Add more tests covering the migration from the RAG Tool to Vector Stores
- Migrate context_retriever for RAG off of inference_api's legacy
chat_completion and onto openai_chat_completion
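
As a rough orientation, here is a minimal sketch of the Files + Vector Stores ingestion flow the RAG Tool now delegates to. The method names mirror the diff below; the `ingest_one` wrapper, its arguments, and the import path for `OpenAIFilePurpose` are assumptions for illustration, not code from this PR.

```python
# Hedged sketch of the Files + Vector Stores ingestion path.
# Call names come from the diff below; this wrapper itself is hypothetical.
from llama_stack.apis.files import OpenAIFilePurpose  # import path assumed


async def ingest_one(
    files_api, vector_io_api, vector_store_id: str, upload_file, metadata: dict, chunking_strategy
) -> None:
    # 1) Upload the raw document bytes through the Files API.
    created_file = await files_api.openai_upload_file(
        file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
    )
    # 2) Attach the uploaded file to the vector store, which chunks and embeds it.
    await vector_io_api.openai_attach_file_to_vector_store(
        vector_store_id=vector_store_id,
        file_id=created_file.id,
        attributes=metadata,
        chunking_strategy=chunking_strategy,
    )
```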


## Test Plan
Integration and unit tests added

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

5 changed files with 355 additions and 45 deletions


```diff
@@ -8,7 +8,7 @@
 from jinja2 import Template

 from llama_stack.apis.common.content_types import InterleavedContent
-from llama_stack.apis.inference import UserMessage
+from llama_stack.apis.inference import OpenAIUserMessageParam
 from llama_stack.apis.tools.rag_tool import (
     DefaultRAGQueryGeneratorConfig,
     LLMRAGQueryGeneratorConfig,
@@ -61,16 +61,16 @@ async def llm_rag_query_generator(
     messages = [interleaved_content_as_str(content)]

     template = Template(config.template)
-    content = template.render({"messages": messages})
+    rendered_content: str = template.render({"messages": messages})

     model = config.model
-    message = UserMessage(content=content)
-    response = await inference_api.chat_completion(
-        model_id=model,
+    message = OpenAIUserMessageParam(content=rendered_content)
+    response = await inference_api.openai_chat_completion(
+        model=model,
         messages=[message],
         stream=False,
     )

-    query = response.completion_message.content
+    query = response.choices[0].message.content

     return query
```
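
In short, the query generator now goes through the OpenAI-compatible chat completion surface. A minimal sketch of the new call shape (the `generate_query` wrapper is hypothetical; the message type, method name, and response access all come from the diff above):

```python
# Hedged sketch: the OpenAI-compatible call shape now used by the RAG query
# generator. `inference_api` is assumed to expose openai_chat_completion.
from llama_stack.apis.inference import OpenAIUserMessageParam


async def generate_query(inference_api, model: str, rendered_content: str) -> str:
    response = await inference_api.openai_chat_completion(
        model=model,
        messages=[OpenAIUserMessageParam(content=rendered_content)],
        stream=False,
    )
    # OpenAI-style responses expose the text at choices[0].message.content,
    # not completion_message.content as the legacy API did.
    return response.choices[0].message.content
```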


```diff
@@ -45,10 +45,7 @@ from llama_stack.apis.vector_io import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
-from llama_stack.providers.utils.memory.vector_store import (
-    content_from_doc,
-    parse_data_url,
-)
+from llama_stack.providers.utils.memory.vector_store import parse_data_url

 from .config import RagToolRuntimeConfig
 from .context_retriever import generate_rag_query
```
```diff
@@ -60,6 +57,47 @@ def make_random_string(length: int = 8):
     return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))


+async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
+    """Get raw binary data and mime type from a RAGDocument for file upload."""
+    if isinstance(doc.content, URL):
+        if doc.content.uri.startswith("data:"):
+            parts = parse_data_url(doc.content.uri)
+            mime_type = parts["mimetype"]
+            data = parts["data"]
+            if parts["is_base64"]:
+                file_data = base64.b64decode(data)
+            else:
+                file_data = data.encode("utf-8")
+            return file_data, mime_type
+        else:
+            async with httpx.AsyncClient() as client:
+                r = await client.get(doc.content.uri)
+                r.raise_for_status()
+                mime_type = r.headers.get("content-type", "application/octet-stream")
+                return r.content, mime_type
+    else:
+        if isinstance(doc.content, str):
+            content_str = doc.content
+        else:
+            content_str = interleaved_content_as_str(doc.content)
+        if content_str.startswith("data:"):
+            parts = parse_data_url(content_str)
+            mime_type = parts["mimetype"]
+            data = parts["data"]
+            if parts["is_base64"]:
+                file_data = base64.b64decode(data)
+            else:
+                file_data = data.encode("utf-8")
+            return file_data, mime_type
+        else:
+            return content_str.encode("utf-8"), "text/plain"
+
+
 class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     def __init__(
         self,
```
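
A quick usage sketch for the new helper. The sample data-URL document is made up, and the import paths for `RAGDocument` and `URL` are assumptions; `raw_data_from_doc` is the function added in the hunk above.

```python
# Hedged usage sketch for raw_data_from_doc; the document here is hypothetical.
import asyncio

from llama_stack.apis.common.content_types import URL  # import path assumed
from llama_stack.apis.tools import RAGDocument  # import path assumed


async def demo() -> None:
    doc = RAGDocument(
        document_id="doc-1",
        content=URL(uri="data:text/plain;base64,aGVsbG8gd29ybGQ="),  # "hello world"
        metadata={},
    )
    # Decodes the data URL and reports its mime type.
    file_data, mime_type = await raw_data_from_doc(doc)
    assert mime_type == "text/plain"
    assert file_data == b"hello world"


asyncio.run(demo())
```
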
```diff
@@ -95,46 +133,52 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
             return

         for doc in documents:
-            if isinstance(doc.content, URL):
-                if doc.content.uri.startswith("data:"):
-                    parts = parse_data_url(doc.content.uri)
-                    file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
-                    mime_type = parts["mimetype"]
-                else:
-                    async with httpx.AsyncClient() as client:
-                        response = await client.get(doc.content.uri)
-                        file_data = response.content
-                        mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
-            else:
-                content_str = await content_from_doc(doc)
-                file_data = content_str.encode("utf-8")
-                mime_type = doc.mime_type or "text/plain"
+            try:
+                try:
+                    file_data, mime_type = await raw_data_from_doc(doc)
+                except Exception as e:
+                    log.error(f"Failed to extract content from document {doc.document_id}: {e}")
+                    continue

-            file_extension = mimetypes.guess_extension(mime_type) or ".txt"
-            filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
+                file_extension = mimetypes.guess_extension(mime_type) or ".txt"
+                filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")

-            file_obj = io.BytesIO(file_data)
-            file_obj.name = filename
+                file_obj = io.BytesIO(file_data)
+                file_obj.name = filename

-            upload_file = UploadFile(file=file_obj, filename=filename)
+                upload_file = UploadFile(file=file_obj, filename=filename)

-            created_file = await self.files_api.openai_upload_file(
-                file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
-            )
+                try:
+                    created_file = await self.files_api.openai_upload_file(
+                        file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
+                    )
+                except Exception as e:
+                    log.error(f"Failed to upload file for document {doc.document_id}: {e}")
+                    continue

-            chunking_strategy = VectorStoreChunkingStrategyStatic(
-                static=VectorStoreChunkingStrategyStaticConfig(
-                    max_chunk_size_tokens=chunk_size_in_tokens,
-                    chunk_overlap_tokens=chunk_size_in_tokens // 4,
-                )
-            )
+                chunking_strategy = VectorStoreChunkingStrategyStatic(
+                    static=VectorStoreChunkingStrategyStaticConfig(
+                        max_chunk_size_tokens=chunk_size_in_tokens,
+                        chunk_overlap_tokens=chunk_size_in_tokens // 4,
+                    )
+                )

-            await self.vector_io_api.openai_attach_file_to_vector_store(
-                vector_store_id=vector_db_id,
-                file_id=created_file.id,
-                attributes=doc.metadata,
-                chunking_strategy=chunking_strategy,
-            )
+                try:
+                    await self.vector_io_api.openai_attach_file_to_vector_store(
+                        vector_store_id=vector_db_id,
+                        file_id=created_file.id,
+                        attributes=doc.metadata,
+                        chunking_strategy=chunking_strategy,
+                    )
+                except Exception as e:
+                    log.error(
+                        f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
+                    )
+                    continue
+            except Exception as e:
+                log.error(f"Unexpected error processing document {doc.document_id}: {e}")
+                continue

     async def query(
         self,
```
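
The net effect is per-document failure isolation: extraction, upload, and attachment each get their own try/except, and a failure logs and skips that document instead of aborting the whole batch. A self-contained sketch of the pattern (the stub `extract` stage and the document names are hypothetical):

```python
# Hedged, self-contained sketch of the per-document error-isolation pattern.
import asyncio
import logging

log = logging.getLogger("rag-demo")


async def extract(doc: str) -> bytes:
    # Stand-in for raw_data_from_doc; fails for one document on purpose.
    if doc == "bad-doc":
        raise ValueError("unparseable content")
    return doc.encode("utf-8")


async def ingest(documents: list[str]) -> list[str]:
    ingested = []
    for doc in documents:
        try:
            await extract(doc)  # analogous to the inner try in the diff
        except Exception as e:
            log.error(f"Failed to extract content from document {doc}: {e}")
            continue  # skip this document; keep processing the rest
        ingested.append(doc)
    return ingested


print(asyncio.run(ingest(["good-doc", "bad-doc", "other-doc"])))
# -> ['good-doc', 'other-doc']
```
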
```diff
@@ -274,7 +318,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         if query_config:
             query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
         else:
-            # handle someone passing an empty dict
             query_config = RAGQueryConfig()

         query = kwargs["query"]
@@ -285,6 +328,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         )

         return ToolInvocationResult(
-            content=result.content,
+            content=result.content or [],
             metadata=result.metadata,
         )
```
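
Finally, the `or []` fallback means a query that produces no content yields an empty list rather than `None`, so consumers of `ToolInvocationResult.content` can iterate without a None check. A tiny illustration (the `None` stands in for an empty query result):

```python
# Hedged illustration of the new fallback; None represents a RAGQueryResult
# whose query matched nothing.
result_content = None
content = result_content or []
assert content == []  # safe to iterate; previously this could be None
```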