diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index f4834969a..9d892c166 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -118,27 +118,25 @@ async def content_from_doc(doc: RAGDocument) -> str: if isinstance(doc.content, URL): if doc.content.uri.startswith("data:"): return content_from_data(doc.content.uri) - else: - async with httpx.AsyncClient() as client: - r = await client.get(doc.content.uri) - if doc.mime_type == "application/pdf": - return parse_pdf(r.content) - else: - return r.text - - pattern = re.compile("^(https?://|file://|data:)") - if pattern.match(doc.content): - if doc.content.startswith("data:"): - return content_from_data(doc.content) - else: + async with httpx.AsyncClient() as client: + r = await client.get(doc.content.uri) + if doc.mime_type == "application/pdf": + return parse_pdf(r.content) + return r.text + elif isinstance(doc.content, str): + pattern = re.compile("^(https?://|file://|data:)") + if pattern.match(doc.content): + if doc.content.startswith("data:"): + return content_from_data(doc.content) async with httpx.AsyncClient() as client: r = await client.get(doc.content) if doc.mime_type == "application/pdf": return parse_pdf(r.content) - else: - return r.text - - return interleaved_content_as_str(doc.content) + return r.text + return doc.content + else: + # will raise ValueError if the content is not List[InterleavedContent] or InterleavedContent + return interleaved_content_as_str(doc.content) def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> list[Chunk]: diff --git a/tests/unit/providers/utils/__init__.py b/tests/unit/providers/utils/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/providers/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/providers/utils/memory/__init__.py b/tests/unit/providers/utils/memory/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/providers/utils/memory/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/providers/utils/memory/test_vector_store.py b/tests/unit/providers/utils/memory/test_vector_store.py new file mode 100644 index 000000000..4a3c33a6b --- /dev/null +++ b/tests/unit/providers/utils/memory/test_vector_store.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from llama_stack.apis.common.content_types import URL, TextContentItem +from llama_stack.apis.tools import RAGDocument +from llama_stack.providers.utils.memory.vector_store import content_from_doc + + +@pytest.mark.asyncio +async def test_content_from_doc_with_url(): + """Test extracting content from RAGDocument with URL content.""" + mock_url = URL(uri="https://example.com") + mock_doc = RAGDocument(document_id="foo", content=mock_url) + + mock_response = MagicMock() + mock_response.text = "Sample content from URL" + + with patch("httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_client.return_value.__aenter__.return_value = mock_instance + + result = await content_from_doc(mock_doc) + + assert result == "Sample content from URL" + mock_instance.get.assert_called_once_with(mock_url.uri) + + +@pytest.mark.asyncio +async def test_content_from_doc_with_pdf_url(): + """Test extracting content from RAGDocument with URL pointing to a PDF.""" + mock_url = URL(uri="https://example.com/document.pdf") + mock_doc = RAGDocument(document_id="foo", content=mock_url, mime_type="application/pdf") + + mock_response = MagicMock() + mock_response.content = b"PDF binary data" + + with ( + patch("httpx.AsyncClient") as mock_client, + patch("llama_stack.providers.utils.memory.vector_store.parse_pdf") as mock_parse_pdf, + ): + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_client.return_value.__aenter__.return_value = mock_instance + mock_parse_pdf.return_value = "Extracted PDF content" + + result = await content_from_doc(mock_doc) + + assert result == "Extracted PDF content" + mock_instance.get.assert_called_once_with(mock_url.uri) + mock_parse_pdf.assert_called_once_with(b"PDF binary data") + + +@pytest.mark.asyncio +async def test_content_from_doc_with_data_url(): + """Test extracting content from RAGDocument with data URL content.""" + data_url = "data:text/plain;base64,SGVsbG8gV29ybGQ=" # "Hello World" base64 encoded + mock_url = URL(uri=data_url) + mock_doc = RAGDocument(document_id="foo", content=mock_url) + + with patch("llama_stack.providers.utils.memory.vector_store.content_from_data") as mock_content_from_data: + mock_content_from_data.return_value = "Hello World" + + result = await content_from_doc(mock_doc) + + assert result == "Hello World" + mock_content_from_data.assert_called_once_with(data_url) + + +@pytest.mark.asyncio +async def test_content_from_doc_with_string(): + """Test extracting content from RAGDocument with string content.""" + content_string = "This is plain text content" + mock_doc = RAGDocument(document_id="foo", content=content_string) + + result = await content_from_doc(mock_doc) + + assert result == content_string + + +@pytest.mark.asyncio +async def test_content_from_doc_with_string_url(): + """Test extracting content from RAGDocument with string URL content.""" + url_string = "https://example.com" + mock_doc = RAGDocument(document_id="foo", content=url_string) + + mock_response = MagicMock() + mock_response.text = "Sample content from URL string" + + with patch("httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_client.return_value.__aenter__.return_value = mock_instance + + result = await content_from_doc(mock_doc) + + assert result == "Sample content from URL string" + mock_instance.get.assert_called_once_with(url_string) + + +@pytest.mark.asyncio +async def test_content_from_doc_with_string_pdf_url(): + """Test extracting content from RAGDocument with string URL pointing to a PDF.""" + url_string = "https://example.com/document.pdf" + mock_doc = RAGDocument(document_id="foo", content=url_string, mime_type="application/pdf") + + mock_response = MagicMock() + mock_response.content = b"PDF binary data" + + with ( + patch("httpx.AsyncClient") as mock_client, + patch("llama_stack.providers.utils.memory.vector_store.parse_pdf") as mock_parse_pdf, + ): + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_client.return_value.__aenter__.return_value = mock_instance + mock_parse_pdf.return_value = "Extracted PDF content from string URL" + + result = await content_from_doc(mock_doc) + + assert result == "Extracted PDF content from string URL" + mock_instance.get.assert_called_once_with(url_string) + mock_parse_pdf.assert_called_once_with(b"PDF binary data") + + +@pytest.mark.asyncio +async def test_content_from_doc_with_interleaved_content(): + """Test extracting content from RAGDocument with InterleavedContent (the new case added in the commit).""" + interleaved_content = [TextContentItem(text="First item"), TextContentItem(text="Second item")] + mock_doc = RAGDocument(document_id="foo", content=interleaved_content) + + with patch("llama_stack.providers.utils.memory.vector_store.interleaved_content_as_str") as mock_interleaved: + mock_interleaved.return_value = "First item\nSecond item" + + result = await content_from_doc(mock_doc) + + assert result == "First item\nSecond item" + mock_interleaved.assert_called_once_with(interleaved_content)