mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
Merge branch 'main' into chroma
This commit is contained in:
commit
60318b659d
5 changed files with 355 additions and 45 deletions
|
@ -93,10 +93,31 @@ chunks_response = client.vector_io.query(
|
||||||
|
|
||||||
### Using the RAG Tool
|
### Using the RAG Tool
|
||||||
|
|
||||||
|
> **⚠️ DEPRECATION NOTICE**: The RAG Tool is being deprecated in favor of directly using the OpenAI-compatible Search
|
||||||
|
> API. We recommend migrating to the OpenAI APIs for better compatibility and future support.
|
||||||
|
|
||||||
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc.
|
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc.
|
||||||
and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the
|
and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the
|
||||||
[appendix](#more-ragdocument-examples).
|
[appendix](#more-ragdocument-examples).
|
||||||
|
|
||||||
|
#### OpenAI API Integration & Migration
|
||||||
|
|
||||||
|
The RAG tool has been updated to use OpenAI-compatible APIs. This provides several benefits:
|
||||||
|
|
||||||
|
- **Files API Integration**: Documents are now uploaded using OpenAI's file upload endpoints
|
||||||
|
- **Vector Stores API**: Vector storage operations use OpenAI's vector store format with configurable chunking strategies
|
||||||
|
- **Error Resilience:** When processing multiple documents, individual failures are logged but don't crash the operation. Failed documents are skipped while successful ones continue processing.
|
||||||
|
|
||||||
|
**Migration Path:**
|
||||||
|
We recommend migrating to the OpenAI-compatible Search API for:
|
||||||
|
1. **Better OpenAI Ecosystem Integration**: Direct compatibility with OpenAI tools and workflows including the Responses API
|
||||||
|
2**Future-Proof**: Continued support and feature development
|
||||||
|
3**Full OpenAI Compatibility**: Vector Stores, Files, and Search APIs are fully compatible with OpenAI's Responses API
|
||||||
|
|
||||||
|
The OpenAI APIs are used under the hood, so you can continue to use your existing RAG Tool code with minimal changes.
|
||||||
|
However, we recommend updating your code to use the new OpenAI-compatible APIs for better long-term support. If any
|
||||||
|
documents fail to process, they will be logged in the response but will not cause the entire operation to fail.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from llama_stack_client import RAGDocument
|
from llama_stack_client import RAGDocument
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
from jinja2 import Template
|
from jinja2 import Template
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import InterleavedContent
|
from llama_stack.apis.common.content_types import InterleavedContent
|
||||||
from llama_stack.apis.inference import UserMessage
|
from llama_stack.apis.inference import OpenAIUserMessageParam
|
||||||
from llama_stack.apis.tools.rag_tool import (
|
from llama_stack.apis.tools.rag_tool import (
|
||||||
DefaultRAGQueryGeneratorConfig,
|
DefaultRAGQueryGeneratorConfig,
|
||||||
LLMRAGQueryGeneratorConfig,
|
LLMRAGQueryGeneratorConfig,
|
||||||
|
@ -61,16 +61,16 @@ async def llm_rag_query_generator(
|
||||||
messages = [interleaved_content_as_str(content)]
|
messages = [interleaved_content_as_str(content)]
|
||||||
|
|
||||||
template = Template(config.template)
|
template = Template(config.template)
|
||||||
content = template.render({"messages": messages})
|
rendered_content: str = template.render({"messages": messages})
|
||||||
|
|
||||||
model = config.model
|
model = config.model
|
||||||
message = UserMessage(content=content)
|
message = OpenAIUserMessageParam(content=rendered_content)
|
||||||
response = await inference_api.chat_completion(
|
response = await inference_api.openai_chat_completion(
|
||||||
model_id=model,
|
model=model,
|
||||||
messages=[message],
|
messages=[message],
|
||||||
stream=False,
|
stream=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
query = response.completion_message.content
|
query = response.choices[0].message.content
|
||||||
|
|
||||||
return query
|
return query
|
||||||
|
|
|
@ -45,10 +45,7 @@ from llama_stack.apis.vector_io import (
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import parse_data_url
|
||||||
content_from_doc,
|
|
||||||
parse_data_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .config import RagToolRuntimeConfig
|
from .config import RagToolRuntimeConfig
|
||||||
from .context_retriever import generate_rag_query
|
from .context_retriever import generate_rag_query
|
||||||
|
@ -60,6 +57,47 @@ def make_random_string(length: int = 8):
|
||||||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||||
|
|
||||||
|
|
||||||
|
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
|
||||||
|
"""Get raw binary data and mime type from a RAGDocument for file upload."""
|
||||||
|
if isinstance(doc.content, URL):
|
||||||
|
if doc.content.uri.startswith("data:"):
|
||||||
|
parts = parse_data_url(doc.content.uri)
|
||||||
|
mime_type = parts["mimetype"]
|
||||||
|
data = parts["data"]
|
||||||
|
|
||||||
|
if parts["is_base64"]:
|
||||||
|
file_data = base64.b64decode(data)
|
||||||
|
else:
|
||||||
|
file_data = data.encode("utf-8")
|
||||||
|
|
||||||
|
return file_data, mime_type
|
||||||
|
else:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
r = await client.get(doc.content.uri)
|
||||||
|
r.raise_for_status()
|
||||||
|
mime_type = r.headers.get("content-type", "application/octet-stream")
|
||||||
|
return r.content, mime_type
|
||||||
|
else:
|
||||||
|
if isinstance(doc.content, str):
|
||||||
|
content_str = doc.content
|
||||||
|
else:
|
||||||
|
content_str = interleaved_content_as_str(doc.content)
|
||||||
|
|
||||||
|
if content_str.startswith("data:"):
|
||||||
|
parts = parse_data_url(content_str)
|
||||||
|
mime_type = parts["mimetype"]
|
||||||
|
data = parts["data"]
|
||||||
|
|
||||||
|
if parts["is_base64"]:
|
||||||
|
file_data = base64.b64decode(data)
|
||||||
|
else:
|
||||||
|
file_data = data.encode("utf-8")
|
||||||
|
|
||||||
|
return file_data, mime_type
|
||||||
|
else:
|
||||||
|
return content_str.encode("utf-8"), "text/plain"
|
||||||
|
|
||||||
|
|
||||||
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -95,20 +133,12 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
||||||
return
|
return
|
||||||
|
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
if isinstance(doc.content, URL):
|
try:
|
||||||
if doc.content.uri.startswith("data:"):
|
try:
|
||||||
parts = parse_data_url(doc.content.uri)
|
file_data, mime_type = await raw_data_from_doc(doc)
|
||||||
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
|
except Exception as e:
|
||||||
mime_type = parts["mimetype"]
|
log.error(f"Failed to extract content from document {doc.document_id}: {e}")
|
||||||
else:
|
continue
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.get(doc.content.uri)
|
|
||||||
file_data = response.content
|
|
||||||
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
|
|
||||||
else:
|
|
||||||
content_str = await content_from_doc(doc)
|
|
||||||
file_data = content_str.encode("utf-8")
|
|
||||||
mime_type = doc.mime_type or "text/plain"
|
|
||||||
|
|
||||||
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
|
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
|
||||||
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
|
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
|
||||||
|
@ -118,9 +148,13 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
||||||
|
|
||||||
upload_file = UploadFile(file=file_obj, filename=filename)
|
upload_file = UploadFile(file=file_obj, filename=filename)
|
||||||
|
|
||||||
|
try:
|
||||||
created_file = await self.files_api.openai_upload_file(
|
created_file = await self.files_api.openai_upload_file(
|
||||||
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
|
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Failed to upload file for document {doc.document_id}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
chunking_strategy = VectorStoreChunkingStrategyStatic(
|
chunking_strategy = VectorStoreChunkingStrategyStatic(
|
||||||
static=VectorStoreChunkingStrategyStaticConfig(
|
static=VectorStoreChunkingStrategyStaticConfig(
|
||||||
|
@ -129,12 +163,22 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
await self.vector_io_api.openai_attach_file_to_vector_store(
|
await self.vector_io_api.openai_attach_file_to_vector_store(
|
||||||
vector_store_id=vector_db_id,
|
vector_store_id=vector_db_id,
|
||||||
file_id=created_file.id,
|
file_id=created_file.id,
|
||||||
attributes=doc.metadata,
|
attributes=doc.metadata,
|
||||||
chunking_strategy=chunking_strategy,
|
chunking_strategy=chunking_strategy,
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(
|
||||||
|
f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Unexpected error processing document {doc.document_id}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
self,
|
self,
|
||||||
|
@ -274,7 +318,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
||||||
if query_config:
|
if query_config:
|
||||||
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
|
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
|
||||||
else:
|
else:
|
||||||
# handle someone passing an empty dict
|
|
||||||
query_config = RAGQueryConfig()
|
query_config = RAGQueryConfig()
|
||||||
|
|
||||||
query = kwargs["query"]
|
query = kwargs["query"]
|
||||||
|
@ -285,6 +328,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
||||||
)
|
)
|
||||||
|
|
||||||
return ToolInvocationResult(
|
return ToolInvocationResult(
|
||||||
content=result.content,
|
content=result.content or [],
|
||||||
metadata=result.metadata,
|
metadata=result.metadata,
|
||||||
)
|
)
|
||||||
|
|
|
@ -183,6 +183,110 @@ def test_vector_db_insert_from_url_and_query(
|
||||||
assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
|
assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_tool_openai_apis(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||||
|
vector_db_id = "test_openai_vector_db"
|
||||||
|
|
||||||
|
client_with_empty_registry.vector_dbs.register(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
embedding_model=embedding_model_id,
|
||||||
|
embedding_dimension=embedding_dimension,
|
||||||
|
)
|
||||||
|
|
||||||
|
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
|
actual_vector_db_id = available_vector_dbs[0]
|
||||||
|
|
||||||
|
# different document formats that should work with OpenAI APIs
|
||||||
|
documents = [
|
||||||
|
Document(
|
||||||
|
document_id="text-doc",
|
||||||
|
content="This is a plain text document about machine learning algorithms.",
|
||||||
|
metadata={"type": "text", "category": "AI"},
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
document_id="url-doc",
|
||||||
|
content="https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst",
|
||||||
|
mime_type="text/plain",
|
||||||
|
metadata={"type": "url", "source": "pytorch"},
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
document_id="data-url-doc",
|
||||||
|
content="data:text/plain;base64,VGhpcyBpcyBhIGRhdGEgVVJMIGRvY3VtZW50IGFib3V0IGRlZXAgbGVhcm5pbmcu", # "This is a data URL document about deep learning."
|
||||||
|
metadata={"type": "data_url", "encoding": "base64"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||||
|
documents=documents,
|
||||||
|
vector_db_id=actual_vector_db_id,
|
||||||
|
chunk_size_in_tokens=256,
|
||||||
|
)
|
||||||
|
|
||||||
|
files_list = client_with_empty_registry.files.list()
|
||||||
|
assert len(files_list.data) >= len(documents), (
|
||||||
|
f"Expected at least {len(documents)} files, got {len(files_list.data)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store(
|
||||||
|
vector_store_id=actual_vector_db_id
|
||||||
|
)
|
||||||
|
assert len(vector_store_files.data) >= len(documents), f"Expected at least {len(documents)} files in vector store"
|
||||||
|
|
||||||
|
response = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||||
|
vector_db_ids=[actual_vector_db_id],
|
||||||
|
content="Tell me about machine learning and deep learning",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_valid_text_response(response)
|
||||||
|
content_text = " ".join([chunk.text for chunk in response.content]).lower()
|
||||||
|
assert "machine learning" in content_text or "deep learning" in content_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_tool_exception_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||||
|
vector_db_id = "test_exception_handling"
|
||||||
|
|
||||||
|
client_with_empty_registry.vector_dbs.register(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
embedding_model=embedding_model_id,
|
||||||
|
embedding_dimension=embedding_dimension,
|
||||||
|
)
|
||||||
|
|
||||||
|
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
|
actual_vector_db_id = available_vector_dbs[0]
|
||||||
|
|
||||||
|
documents = [
|
||||||
|
Document(
|
||||||
|
document_id="valid-doc",
|
||||||
|
content="This is a valid document that should be processed successfully.",
|
||||||
|
metadata={"status": "valid"},
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
document_id="invalid-url-doc",
|
||||||
|
content="https://nonexistent-domain-12345.com/invalid.txt",
|
||||||
|
metadata={"status": "invalid_url"},
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
document_id="another-valid-doc",
|
||||||
|
content="This is another valid document for testing resilience.",
|
||||||
|
metadata={"status": "valid"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||||
|
documents=documents,
|
||||||
|
vector_db_id=actual_vector_db_id,
|
||||||
|
chunk_size_in_tokens=256,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||||
|
vector_db_ids=[actual_vector_db_id],
|
||||||
|
content="valid document",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_valid_text_response(response)
|
||||||
|
content_text = " ".join([chunk.text for chunk in response.content]).lower()
|
||||||
|
assert "valid document" in content_text
|
||||||
|
|
||||||
|
|
||||||
def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||||
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
|
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
|
||||||
assert len(providers) > 0
|
assert len(providers) > 0
|
||||||
|
@ -249,3 +353,107 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
|
||||||
"chunk_template": "This should raise a ValueError because it is missing the proper template variables",
|
"chunk_template": "This should raise a ValueError because it is missing the proper template variables",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_tool_query_generation(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||||
|
vector_db_id = "test_query_generation_db"
|
||||||
|
|
||||||
|
client_with_empty_registry.vector_dbs.register(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
embedding_model=embedding_model_id,
|
||||||
|
embedding_dimension=embedding_dimension,
|
||||||
|
)
|
||||||
|
|
||||||
|
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
|
actual_vector_db_id = available_vector_dbs[0]
|
||||||
|
|
||||||
|
documents = [
|
||||||
|
Document(
|
||||||
|
document_id="ai-doc",
|
||||||
|
content="Artificial intelligence and machine learning are transforming technology.",
|
||||||
|
metadata={"category": "AI"},
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
document_id="banana-doc",
|
||||||
|
content="Don't bring a banana to a knife fight.",
|
||||||
|
metadata={"category": "wisdom"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||||
|
documents=documents,
|
||||||
|
vector_db_id=actual_vector_db_id,
|
||||||
|
chunk_size_in_tokens=256,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||||
|
vector_db_ids=[actual_vector_db_id],
|
||||||
|
content="Tell me about AI",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_valid_text_response(response)
|
||||||
|
content_text = " ".join([chunk.text for chunk in response.content]).lower()
|
||||||
|
assert "artificial intelligence" in content_text or "machine learning" in content_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_tool_pdf_data_url_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||||
|
vector_db_id = "test_pdf_data_url_db"
|
||||||
|
|
||||||
|
client_with_empty_registry.vector_dbs.register(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
embedding_model=embedding_model_id,
|
||||||
|
embedding_dimension=embedding_dimension,
|
||||||
|
)
|
||||||
|
|
||||||
|
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
|
actual_vector_db_id = available_vector_dbs[0]
|
||||||
|
|
||||||
|
sample_pdf = b"%PDF-1.3\n3 0 obj\n<</Type /Page\n/Parent 1 0 R\n/Resources 2 0 R\n/Contents 4 0 R>>\nendobj\n4 0 obj\n<</Filter /FlateDecode /Length 115>>\nstream\nx\x9c\x15\xcc1\x0e\x820\x18@\xe1\x9dS\xbcM]jk$\xd5\xd5(\x83!\x86\xa1\x17\xf8\xa3\xa5`LIh+\xd7W\xc6\xf7\r\xef\xc0\xbd\xd2\xaa\xb6,\xd5\xc5\xb1o\x0c\xa6VZ\xe3znn%\xf3o\xab\xb1\xe7\xa3:Y\xdc\x8bm\xeb\xf3&1\xc8\xd7\xd3\x97\xc82\xe6\x81\x87\xe42\xcb\x87Vb(\x12<\xdd<=}Jc\x0cL\x91\xee\xda$\xb5\xc3\xbd\xd7\xe9\x0f\x8d\x97 $\nendstream\nendobj\n1 0 obj\n<</Type /Pages\n/Kids [3 0 R ]\n/Count 1\n/MediaBox [0 0 595.28 841.89]\n>>\nendobj\n5 0 obj\n<</Type /Font\n/BaseFont /Helvetica\n/Subtype /Type1\n/Encoding /WinAnsiEncoding\n>>\nendobj\n2 0 obj\n<<\n/ProcSet [/PDF /Text /ImageB /ImageC /ImageI]\n/Font <<\n/F1 5 0 R\n>>\n/XObject <<\n>>\n>>\nendobj\n6 0 obj\n<<\n/Producer (PyFPDF 1.7.2 http://pyfpdf.googlecode.com/)\n/Title (This is a sample title.)\n/Author (Llama Stack Developers)\n/CreationDate (D:20250312165548)\n>>\nendobj\n7 0 obj\n<<\n/Type /Catalog\n/Pages 1 0 R\n/OpenAction [3 0 R /FitH null]\n/PageLayout /OneColumn\n>>\nendobj\nxref\n0 8\n0000000000 65535 f \n0000000272 00000 n \n0000000455 00000 n \n0000000009 00000 n \n0000000087 00000 n \n0000000359 00000 n \n0000000559 00000 n \n0000000734 00000 n \ntrailer\n<<\n/Size 8\n/Root 7 0 R\n/Info 6 0 R\n>>\nstartxref\n837\n%%EOF\n"
|
||||||
|
|
||||||
|
import base64
|
||||||
|
|
||||||
|
pdf_base64 = base64.b64encode(sample_pdf).decode("utf-8")
|
||||||
|
pdf_data_url = f"data:application/pdf;base64,{pdf_base64}"
|
||||||
|
|
||||||
|
documents = [
|
||||||
|
Document(
|
||||||
|
document_id="test-pdf-data-url",
|
||||||
|
content=pdf_data_url,
|
||||||
|
metadata={"type": "pdf", "source": "data_url"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||||
|
documents=documents,
|
||||||
|
vector_db_id=actual_vector_db_id,
|
||||||
|
chunk_size_in_tokens=256,
|
||||||
|
)
|
||||||
|
|
||||||
|
files_list = client_with_empty_registry.files.list()
|
||||||
|
assert len(files_list.data) >= 1, "PDF should have been uploaded to Files API"
|
||||||
|
|
||||||
|
pdf_file = None
|
||||||
|
for file in files_list.data:
|
||||||
|
if file.filename and "test-pdf-data-url" in file.filename:
|
||||||
|
pdf_file = file
|
||||||
|
break
|
||||||
|
|
||||||
|
assert pdf_file is not None, "PDF file should be found in Files API"
|
||||||
|
assert pdf_file.bytes == len(sample_pdf), f"File size should match original PDF ({len(sample_pdf)} bytes)"
|
||||||
|
|
||||||
|
file_content = client_with_empty_registry.files.retrieve_content(pdf_file.id)
|
||||||
|
assert file_content.startswith(b"%PDF-"), "Retrieved file should be a valid PDF"
|
||||||
|
|
||||||
|
vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store(
|
||||||
|
vector_store_id=actual_vector_db_id
|
||||||
|
)
|
||||||
|
assert len(vector_store_files.data) >= 1, "PDF should be attached to vector store"
|
||||||
|
|
||||||
|
response = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||||
|
vector_db_ids=[actual_vector_db_id],
|
||||||
|
content="sample title",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_valid_text_response(response)
|
||||||
|
content_text = " ".join([chunk.text for chunk in response.content]).lower()
|
||||||
|
assert "sample title" in content_text or "title" in content_text
|
||||||
|
|
|
@ -178,3 +178,41 @@ def test_content_from_data_and_mime_type_both_encodings_fail():
|
||||||
# Should raise an exception instead of returning empty string
|
# Should raise an exception instead of returning empty string
|
||||||
with pytest.raises(UnicodeDecodeError):
|
with pytest.raises(UnicodeDecodeError):
|
||||||
content_from_data_and_mime_type(data, mime_type)
|
content_from_data_and_mime_type(data, mime_type)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_memory_tool_error_handling():
|
||||||
|
"""Test that memory tool handles various failures gracefully without crashing."""
|
||||||
|
from llama_stack.providers.inline.tool_runtime.rag.config import RagToolRuntimeConfig
|
||||||
|
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
|
||||||
|
|
||||||
|
config = RagToolRuntimeConfig()
|
||||||
|
memory_tool = MemoryToolRuntimeImpl(
|
||||||
|
config=config,
|
||||||
|
vector_io_api=AsyncMock(),
|
||||||
|
inference_api=AsyncMock(),
|
||||||
|
files_api=AsyncMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
docs = [
|
||||||
|
RAGDocument(document_id="good_doc", content="Good content", metadata={}),
|
||||||
|
RAGDocument(document_id="bad_url_doc", content=URL(uri="https://bad.url"), metadata={}),
|
||||||
|
RAGDocument(document_id="another_good_doc", content="Another good content", metadata={}),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_file1 = MagicMock()
|
||||||
|
mock_file1.id = "file_good1"
|
||||||
|
mock_file2 = MagicMock()
|
||||||
|
mock_file2.id = "file_good2"
|
||||||
|
memory_tool.files_api.openai_upload_file.side_effect = [mock_file1, mock_file2]
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client:
|
||||||
|
mock_instance = AsyncMock()
|
||||||
|
mock_instance.get.side_effect = Exception("Bad URL")
|
||||||
|
mock_client.return_value.__aenter__.return_value = mock_instance
|
||||||
|
|
||||||
|
# won't raise exception despite one document failing
|
||||||
|
await memory_tool.insert(docs, "vector_store_123")
|
||||||
|
|
||||||
|
# processed 2 documents successfully, skipped 1
|
||||||
|
assert memory_tool.files_api.openai_upload_file.call_count == 2
|
||||||
|
assert memory_tool.vector_io_api.openai_attach_file_to_vector_store.call_count == 2
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue