Merge branch 'main' into chroma

This commit is contained in:
Bwook (Byoungwook) Kim 2025-09-11 21:30:50 +09:00 committed by GitHub
commit 60318b659d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 355 additions and 45 deletions

View file

@ -93,10 +93,31 @@ chunks_response = client.vector_io.query(
### Using the RAG Tool ### Using the RAG Tool
> **⚠️ DEPRECATION NOTICE**: The RAG Tool is being deprecated in favor of directly using the OpenAI-compatible Search
> API. We recommend migrating to the OpenAI APIs for better compatibility and future support.
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc.
and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the
[appendix](#more-ragdocument-examples). [appendix](#more-ragdocument-examples).
#### OpenAI API Integration & Migration
The RAG tool has been updated to use OpenAI-compatible APIs. This provides several benefits:
- **Files API Integration**: Documents are now uploaded using OpenAI's file upload endpoints
- **Vector Stores API**: Vector storage operations use OpenAI's vector store format with configurable chunking strategies
- **Error Resilience**: When processing multiple documents, individual failures are logged but don't crash the operation. Failed documents are skipped while successful ones continue processing.
**Migration Path:**
We recommend migrating to the OpenAI-compatible Search API for:
1. **Better OpenAI Ecosystem Integration**: Direct compatibility with OpenAI tools and workflows including the Responses API
2. **Future-Proof**: Continued support and feature development
3. **Full OpenAI Compatibility**: Vector Stores, Files, and Search APIs are fully compatible with OpenAI's Responses API
The OpenAI APIs are used under the hood, so you can continue to use your existing RAG Tool code with minimal changes.
However, we recommend updating your code to use the new OpenAI-compatible APIs for better long-term support. If any
documents fail to process, the failures are logged and those documents are skipped; they will not cause the entire operation to fail.
```python ```python
from llama_stack_client import RAGDocument from llama_stack_client import RAGDocument

View file

@ -8,7 +8,7 @@
from jinja2 import Template from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import ( from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig, DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig, LLMRAGQueryGeneratorConfig,
@ -61,16 +61,16 @@ async def llm_rag_query_generator(
messages = [interleaved_content_as_str(content)] messages = [interleaved_content_as_str(content)]
template = Template(config.template) template = Template(config.template)
content = template.render({"messages": messages}) rendered_content: str = template.render({"messages": messages})
model = config.model model = config.model
message = UserMessage(content=content) message = OpenAIUserMessageParam(content=rendered_content)
response = await inference_api.chat_completion( response = await inference_api.openai_chat_completion(
model_id=model, model=model,
messages=[message], messages=[message],
stream=False, stream=False,
) )
query = response.completion_message.content query = response.choices[0].message.content
return query return query

View file

@ -45,10 +45,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import parse_data_url
content_from_doc,
parse_data_url,
)
from .config import RagToolRuntimeConfig from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query from .context_retriever import generate_rag_query
@ -60,6 +57,47 @@ def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
def _bytes_from_data_url(data_url: str) -> tuple[bytes, str]:
    """Decode a ``data:`` URL into its payload bytes and the mime type it declares."""
    parts = parse_data_url(data_url)
    payload = parts["data"]
    if parts["is_base64"]:
        raw = base64.b64decode(payload)
    else:
        # NOTE(review): non-base64 data URL payloads are normally percent-encoded
        # (RFC 2397). The payload is encoded as-is here, exactly as before -- confirm
        # whether parse_data_url already unquotes it.
        raw = payload.encode("utf-8")
    return raw, parts["mimetype"]


async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
    """Get raw binary data and mime type from a RAGDocument for file upload.

    Resolution rules, in order:

    * ``URL`` content with a ``data:`` URI is decoded in-process.
    * Any other ``URL`` content is fetched with ``httpx``.
    * Non-URL content is flattened to a string; a string starting with ``data:``
      is decoded as a data URL, anything else is uploaded as UTF-8 ``text/plain``.

    Args:
        doc: The document whose ``content`` should be turned into bytes.

    Returns:
        A ``(file_data, mime_type)`` tuple. For fetched URLs the mime type is the
        response's ``Content-Type`` with any parameters (``; charset=...``) removed,
        or ``application/octet-stream`` when the header is absent or empty.

    Raises:
        Whatever ``httpx`` raises for transport failures or, through
        ``raise_for_status()``, for error responses; whatever ``parse_data_url``
        or ``base64.b64decode`` raise for malformed data URLs.
    """
    if isinstance(doc.content, URL):
        uri = doc.content.uri
        if uri.startswith("data:"):
            return _bytes_from_data_url(uri)

        async with httpx.AsyncClient() as client:
            r = await client.get(uri)
            r.raise_for_status()
            content_type = r.headers.get("content-type", "application/octet-stream")
            # Servers commonly send "type/subtype; charset=..."; callers feed the result
            # to mimetypes.guess_extension(), which only recognises the bare type.
            mime_type = content_type.split(";", 1)[0].strip() or "application/octet-stream"
            return r.content, mime_type

    if isinstance(doc.content, str):
        content_str = doc.content
    else:
        content_str = interleaved_content_as_str(doc.content)

    if content_str.startswith("data:"):
        return _bytes_from_data_url(content_str)

    # NOTE(review): plain-string content that looks like an http(s) URL is NOT fetched
    # here; it is uploaded verbatim as text. Confirm this is intended for callers that
    # pass URLs as bare strings rather than URL objects.
    return content_str.encode("utf-8"), "text/plain"
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime): class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__( def __init__(
self, self,
@ -95,46 +133,52 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
return return
for doc in documents: for doc in documents:
if isinstance(doc.content, URL): try:
if doc.content.uri.startswith("data:"): try:
parts = parse_data_url(doc.content.uri) file_data, mime_type = await raw_data_from_doc(doc)
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode() except Exception as e:
mime_type = parts["mimetype"] log.error(f"Failed to extract content from document {doc.document_id}: {e}")
else: continue
async with httpx.AsyncClient() as client:
response = await client.get(doc.content.uri)
file_data = response.content
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
else:
content_str = await content_from_doc(doc)
file_data = content_str.encode("utf-8")
mime_type = doc.mime_type or "text/plain"
file_extension = mimetypes.guess_extension(mime_type) or ".txt" file_extension = mimetypes.guess_extension(mime_type) or ".txt"
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}") filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
file_obj = io.BytesIO(file_data) file_obj = io.BytesIO(file_data)
file_obj.name = filename file_obj.name = filename
upload_file = UploadFile(file=file_obj, filename=filename) upload_file = UploadFile(file=file_obj, filename=filename)
created_file = await self.files_api.openai_upload_file( try:
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS created_file = await self.files_api.openai_upload_file(
) file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
except Exception as e:
log.error(f"Failed to upload file for document {doc.document_id}: {e}")
continue
chunking_strategy = VectorStoreChunkingStrategyStatic( chunking_strategy = VectorStoreChunkingStrategyStatic(
static=VectorStoreChunkingStrategyStaticConfig( static=VectorStoreChunkingStrategyStaticConfig(
max_chunk_size_tokens=chunk_size_in_tokens, max_chunk_size_tokens=chunk_size_in_tokens,
chunk_overlap_tokens=chunk_size_in_tokens // 4, chunk_overlap_tokens=chunk_size_in_tokens // 4,
)
) )
)
await self.vector_io_api.openai_attach_file_to_vector_store( try:
vector_store_id=vector_db_id, await self.vector_io_api.openai_attach_file_to_vector_store(
file_id=created_file.id, vector_store_id=vector_db_id,
attributes=doc.metadata, file_id=created_file.id,
chunking_strategy=chunking_strategy, attributes=doc.metadata,
) chunking_strategy=chunking_strategy,
)
except Exception as e:
log.error(
f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
)
continue
except Exception as e:
log.error(f"Unexpected error processing document {doc.document_id}: {e}")
continue
async def query( async def query(
self, self,
@ -274,7 +318,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
if query_config: if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config) query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else: else:
# handle someone passing an empty dict
query_config = RAGQueryConfig() query_config = RAGQueryConfig()
query = kwargs["query"] query = kwargs["query"]
@ -285,6 +328,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
) )
return ToolInvocationResult( return ToolInvocationResult(
content=result.content, content=result.content or [],
metadata=result.metadata, metadata=result.metadata,
) )

View file

@ -183,6 +183,110 @@ def test_vector_db_insert_from_url_and_query(
assert any("llama2" in chunk.content.lower() for chunk in response2.chunks) assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
def test_rag_tool_openai_apis(client_with_empty_registry, embedding_model_id, embedding_dimension):
    """Insert text, URL-string and data-URL documents via the RAG tool, then check files, vector store and query."""
    client = client_with_empty_registry

    client.vector_dbs.register(
        vector_db_id="test_openai_vector_db",
        embedding_model=embedding_model_id,
        embedding_dimension=embedding_dimension,
    )

    # Use the identifier reported back by the registry rather than the one we requested.
    registered_ids = [db.identifier for db in client.vector_dbs.list()]
    store_id = registered_ids[0]

    # different document formats that should work with OpenAI APIs
    plain_text_doc = Document(
        document_id="text-doc",
        content="This is a plain text document about machine learning algorithms.",
        metadata={"type": "text", "category": "AI"},
    )
    url_doc = Document(
        document_id="url-doc",
        content="https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst",
        mime_type="text/plain",
        metadata={"type": "url", "source": "pytorch"},
    )
    data_url_doc = Document(
        document_id="data-url-doc",
        # base64 payload decodes to: "This is a data URL document about deep learning."
        content="data:text/plain;base64,VGhpcyBpcyBhIGRhdGEgVVJMIGRvY3VtZW50IGFib3V0IGRlZXAgbGVhcm5pbmcu",
        metadata={"type": "data_url", "encoding": "base64"},
    )
    documents = [plain_text_doc, url_doc, data_url_doc]
    expected_count = len(documents)

    client.tool_runtime.rag_tool.insert(
        documents=documents,
        vector_db_id=store_id,
        chunk_size_in_tokens=256,
    )

    # Every document should have produced an uploaded file ...
    uploaded = client.files.list()
    uploaded_count = len(uploaded.data)
    assert uploaded_count >= expected_count, f"Expected at least {expected_count} files, got {uploaded_count}"

    # ... and every file should be attached to the vector store.
    attached = client.vector_io.openai_list_files_in_vector_store(vector_store_id=store_id)
    assert len(attached.data) >= expected_count, f"Expected at least {expected_count} files in vector store"

    response = client.tool_runtime.rag_tool.query(
        vector_db_ids=[store_id],
        content="Tell me about machine learning and deep learning",
    )
    assert_valid_text_response(response)

    retrieved_text = " ".join(chunk.text for chunk in response.content).lower()
    assert any(phrase in retrieved_text for phrase in ("machine learning", "deep learning"))
def test_rag_tool_exception_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
    """Insert a batch containing a document pointing at an unreachable host and check the valid ones are queryable.

    NOTE(review): the "invalid" document's content is a plain string, not a URL object --
    confirm the server actually attempts (and fails) a fetch for it.
    """
    client = client_with_empty_registry

    client.vector_dbs.register(
        vector_db_id="test_exception_handling",
        embedding_model=embedding_model_id,
        embedding_dimension=embedding_dimension,
    )

    # Use the identifier reported back by the registry rather than the one we requested.
    registered_ids = [db.identifier for db in client.vector_dbs.list()]
    store_id = registered_ids[0]

    first_valid = Document(
        document_id="valid-doc",
        content="This is a valid document that should be processed successfully.",
        metadata={"status": "valid"},
    )
    unreachable = Document(
        document_id="invalid-url-doc",
        content="https://nonexistent-domain-12345.com/invalid.txt",
        metadata={"status": "invalid_url"},
    )
    second_valid = Document(
        document_id="another-valid-doc",
        content="This is another valid document for testing resilience.",
        metadata={"status": "valid"},
    )

    # The problematic document sits between the two valid ones on purpose.
    client.tool_runtime.rag_tool.insert(
        documents=[first_valid, unreachable, second_valid],
        vector_db_id=store_id,
        chunk_size_in_tokens=256,
    )

    response = client.tool_runtime.rag_tool.query(
        vector_db_ids=[store_id],
        content="valid document",
    )
    assert_valid_text_response(response)

    retrieved_text = " ".join(chunk.text for chunk in response.content).lower()
    assert "valid document" in retrieved_text
def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension): def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension):
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"] providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
assert len(providers) > 0 assert len(providers) > 0
@ -249,3 +353,107 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
"chunk_template": "This should raise a ValueError because it is missing the proper template variables", "chunk_template": "This should raise a ValueError because it is missing the proper template variables",
}, },
) )
def test_rag_tool_query_generation(client_with_empty_registry, embedding_model_id, embedding_dimension):
    """Insert an AI document and an unrelated one, then check that an AI question retrieves AI text."""
    client = client_with_empty_registry

    client.vector_dbs.register(
        vector_db_id="test_query_generation_db",
        embedding_model=embedding_model_id,
        embedding_dimension=embedding_dimension,
    )

    # Use the identifier reported back by the registry rather than the one we requested.
    registered_ids = [db.identifier for db in client.vector_dbs.list()]
    store_id = registered_ids[0]

    on_topic = Document(
        document_id="ai-doc",
        content="Artificial intelligence and machine learning are transforming technology.",
        metadata={"category": "AI"},
    )
    off_topic = Document(
        document_id="banana-doc",
        content="Don't bring a banana to a knife fight.",
        metadata={"category": "wisdom"},
    )

    client.tool_runtime.rag_tool.insert(
        documents=[on_topic, off_topic],
        vector_db_id=store_id,
        chunk_size_in_tokens=256,
    )

    response = client.tool_runtime.rag_tool.query(
        vector_db_ids=[store_id],
        content="Tell me about AI",
    )
    assert_valid_text_response(response)

    retrieved_text = " ".join(chunk.text for chunk in response.content).lower()
    assert any(phrase in retrieved_text for phrase in ("artificial intelligence", "machine learning"))
def test_rag_tool_pdf_data_url_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
    """Insert a PDF supplied as a base64 data URL and check it is stored as binary, attached and queryable."""
    import base64

    client = client_with_empty_registry

    client.vector_dbs.register(
        vector_db_id="test_pdf_data_url_db",
        embedding_model=embedding_model_id,
        embedding_dimension=embedding_dimension,
    )

    # Use the identifier reported back by the registry rather than the one we requested.
    registered_ids = [db.identifier for db in client.vector_dbs.list()]
    store_id = registered_ids[0]

    # Minimal single-page PDF fixture, embedded inline so the test needs no file on disk.
    sample_pdf = b"%PDF-1.3\n3 0 obj\n<</Type /Page\n/Parent 1 0 R\n/Resources 2 0 R\n/Contents 4 0 R>>\nendobj\n4 0 obj\n<</Filter /FlateDecode /Length 115>>\nstream\nx\x9c\x15\xcc1\x0e\x820\x18@\xe1\x9dS\xbcM]jk$\xd5\xd5(\x83!\x86\xa1\x17\xf8\xa3\xa5`LIh+\xd7W\xc6\xf7\r\xef\xc0\xbd\xd2\xaa\xb6,\xd5\xc5\xb1o\x0c\xa6VZ\xe3znn%\xf3o\xab\xb1\xe7\xa3:Y\xdc\x8bm\xeb\xf3&1\xc8\xd7\xd3\x97\xc82\xe6\x81\x87\xe42\xcb\x87Vb(\x12<\xdd<=}Jc\x0cL\x91\xee\xda$\xb5\xc3\xbd\xd7\xe9\x0f\x8d\x97 $\nendstream\nendobj\n1 0 obj\n<</Type /Pages\n/Kids [3 0 R ]\n/Count 1\n/MediaBox [0 0 595.28 841.89]\n>>\nendobj\n5 0 obj\n<</Type /Font\n/BaseFont /Helvetica\n/Subtype /Type1\n/Encoding /WinAnsiEncoding\n>>\nendobj\n2 0 obj\n<<\n/ProcSet [/PDF /Text /ImageB /ImageC /ImageI]\n/Font <<\n/F1 5 0 R\n>>\n/XObject <<\n>>\n>>\nendobj\n6 0 obj\n<<\n/Producer (PyFPDF 1.7.2 http://pyfpdf.googlecode.com/)\n/Title (This is a sample title.)\n/Author (Llama Stack Developers)\n/CreationDate (D:20250312165548)\n>>\nendobj\n7 0 obj\n<<\n/Type /Catalog\n/Pages 1 0 R\n/OpenAction [3 0 R /FitH null]\n/PageLayout /OneColumn\n>>\nendobj\nxref\n0 8\n0000000000 65535 f \n0000000272 00000 n \n0000000455 00000 n \n0000000009 00000 n \n0000000087 00000 n \n0000000359 00000 n \n0000000559 00000 n \n0000000734 00000 n \ntrailer\n<<\n/Size 8\n/Root 7 0 R\n/Info 6 0 R\n>>\nstartxref\n837\n%%EOF\n"

    encoded_pdf = base64.b64encode(sample_pdf).decode("utf-8")
    pdf_document = Document(
        document_id="test-pdf-data-url",
        content=f"data:application/pdf;base64,{encoded_pdf}",
        metadata={"type": "pdf", "source": "data_url"},
    )

    client.tool_runtime.rag_tool.insert(
        documents=[pdf_document],
        vector_db_id=store_id,
        chunk_size_in_tokens=256,
    )

    uploaded = client.files.list()
    assert len(uploaded.data) >= 1, "PDF should have been uploaded to Files API"

    # The uploaded file is located through the document id embedded in its filename.
    pdf_file = next(
        (entry for entry in uploaded.data if entry.filename and "test-pdf-data-url" in entry.filename),
        None,
    )
    assert pdf_file is not None, "PDF file should be found in Files API"

    # The stored size must equal the original, i.e. the base64 payload was decoded to binary.
    original_size = len(sample_pdf)
    assert pdf_file.bytes == original_size, f"File size should match original PDF ({original_size} bytes)"

    stored_bytes = client.files.retrieve_content(pdf_file.id)
    assert stored_bytes.startswith(b"%PDF-"), "Retrieved file should be a valid PDF"

    attached = client.vector_io.openai_list_files_in_vector_store(vector_store_id=store_id)
    assert len(attached.data) >= 1, "PDF should be attached to vector store"

    response = client.tool_runtime.rag_tool.query(
        vector_db_ids=[store_id],
        content="sample title",
    )
    assert_valid_text_response(response)

    retrieved_text = " ".join(chunk.text for chunk in response.content).lower()
    assert any(phrase in retrieved_text for phrase in ("sample title", "title"))

View file

@ -178,3 +178,41 @@ def test_content_from_data_and_mime_type_both_encodings_fail():
# Should raise an exception instead of returning empty string # Should raise an exception instead of returning empty string
with pytest.raises(UnicodeDecodeError): with pytest.raises(UnicodeDecodeError):
content_from_data_and_mime_type(data, mime_type) content_from_data_and_mime_type(data, mime_type)
async def test_memory_tool_error_handling():
    """Check that ``insert`` skips a document whose URL fetch raises and still processes the others."""
    from llama_stack.providers.inline.tool_runtime.rag.config import RagToolRuntimeConfig
    from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl

    tool = MemoryToolRuntimeImpl(
        config=RagToolRuntimeConfig(),
        vector_io_api=AsyncMock(),
        inference_api=AsyncMock(),
        files_api=AsyncMock(),
    )

    # One upload result per document that is expected to succeed.
    upload_results = []
    for file_id in ("file_good1", "file_good2"):
        created = MagicMock()
        created.id = file_id
        upload_results.append(created)
    tool.files_api.openai_upload_file.side_effect = upload_results

    # The failing URL document sits between two plain-text documents.
    batch = [
        RAGDocument(document_id="good_doc", content="Good content", metadata={}),
        RAGDocument(document_id="bad_url_doc", content=URL(uri="https://bad.url"), metadata={}),
        RAGDocument(document_id="another_good_doc", content="Another good content", metadata={}),
    ]

    with patch("httpx.AsyncClient") as client_cls:
        # Make the client yielded by `async with httpx.AsyncClient()` fail on every GET.
        failing_client = AsyncMock()
        failing_client.get.side_effect = Exception("Bad URL")
        client_cls.return_value.__aenter__.return_value = failing_client

        # won't raise exception despite one document failing
        await tool.insert(batch, "vector_store_123")

    # processed 2 documents successfully, skipped 1
    assert tool.files_api.openai_upload_file.call_count == 2
    assert tool.vector_io_api.openai_attach_file_to_vector_store.call_count == 2