diff --git a/llama_stack/providers/tests/memory/fixtures/dummy.pdf b/llama_stack/providers/tests/memory/fixtures/dummy.pdf
new file mode 100644
index 000000000..774c2ea70
Binary files /dev/null and b/llama_stack/providers/tests/memory/fixtures/dummy.pdf differ
diff --git a/llama_stack/providers/tests/memory/test_vector_store.py b/llama_stack/providers/tests/memory/test_vector_store.py
new file mode 100644
index 000000000..1ad7abf0c
--- /dev/null
+++ b/llama_stack/providers/tests/memory/test_vector_store.py
@@ -0,0 +1,77 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import base64
+import mimetypes
+import os
+from pathlib import Path
+from typing import Union
+
+import pytest
+
+from llama_stack.apis.memory.memory import MemoryBankDocument, URL
+from llama_stack.providers.utils.memory.vector_store import content_from_doc
+
+DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
+# Text that the tests expect to be extracted from the dummy PDF.
+DUMMY_PDF_TEXT = "Dummy PDF file"
+# The same fixture hosted on GitHub (the path embeds a fixed revision hash).
+DUMMY_PDF_URL = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
+
+
+def read_file(file_path: Union[str, Path]) -> bytes:
+    """Return the raw bytes stored at `file_path`."""
+    with open(file_path, "rb") as file:
+        return file.read()
+
+
+def data_url_from_file(file_path: Union[str, Path]) -> str:
+    """Encode the file at `file_path` as a base64 `data:` URL.
+
+    The MIME type is guessed from the file name. If it cannot be guessed,
+    `application/octet-stream` is used so the URL never contains "None".
+    """
+    base64_content = base64.b64encode(read_file(file_path)).decode("utf-8")
+    # guess_type only accepts path-like objects from Python 3.8 on; str() always works.
+    mime_type, _ = mimetypes.guess_type(str(file_path))
+    if mime_type is None:
+        mime_type = "application/octet-stream"
+
+    return f"data:{mime_type};base64,{base64_content}"
+
+
+def _pdf_document(content: Union[str, URL]) -> MemoryBankDocument:
+    """Build a PDF `MemoryBankDocument` whose content is `content`."""
+    return MemoryBankDocument(
+        document_id="dummy",
+        content=content,
+        mime_type="application/pdf",
+        metadata={},
+    )
+
+
+class TestVectorStore:
+    """Checks that `content_from_doc` extracts the text of PDF documents."""
+
+    @pytest.mark.asyncio
+    async def test_returns_content_from_pdf_data_uri(self):
+        doc = _pdf_document(data_url_from_file(DUMMY_PDF_PATH))
+        content = await content_from_doc(doc)
+        assert content == DUMMY_PDF_TEXT
+
+    @pytest.mark.asyncio
+    async def test_downloads_pdf_and_returns_content(self):
+        # Fetches DUMMY_PDF_URL over HTTPS, so this test needs network access.
+        doc = _pdf_document(DUMMY_PDF_URL)
+        content = await content_from_doc(doc)
+        assert content == DUMMY_PDF_TEXT
+
+    @pytest.mark.asyncio
+    async def test_downloads_pdf_and_returns_content_with_url_object(self):
+        # Fetches DUMMY_PDF_URL over HTTPS, so this test needs network access.
+        doc = _pdf_document(URL(uri=DUMMY_PDF_URL))
+        content = await content_from_doc(doc)
+        assert content == DUMMY_PDF_TEXT
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index 48cb8a99d..eb83aa671 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -45,6 +45,13 @@ def get_embedding_model(model: str) -> "SentenceTransformer":
     return loaded_model
 
 
+def parse_pdf(data: bytes) -> str:
+    """Return the text of every page in the PDF bytes `data`, joined by newlines."""
+    pdf_bytes = io.BytesIO(data)
+    pdf_reader = PdfReader(pdf_bytes)
+    return "\n".join([page.extract_text() for page in pdf_reader.pages])
+
+
 def parse_data_url(data_url: str):
     data_url_pattern = re.compile(
         r"^"
@@ -88,10 +95,7 @@ def content_from_data(data_url: str) -> str:
         return data.decode(encoding)
 
     elif mime_type == "application/pdf":
-        # For PDF and DOC/DOCX files, we can't reliably convert to string)
-        pdf_bytes = io.BytesIO(data)
-        pdf_reader = PdfReader(pdf_bytes)
-        return "\n".join([page.extract_text() for page in pdf_reader.pages])
+        return parse_pdf(data)
 
     else:
         log.error("Could not extract content from data_url properly.")
@@ -105,6 +109,9 @@ async def content_from_doc(doc: MemoryBankDocument) -> str:
         else:
             async with httpx.AsyncClient() as client:
                 r = await client.get(doc.content.uri)
+            if doc.mime_type == "application/pdf":
+                return parse_pdf(r.content)
+            else:
                 return r.text
 
     pattern = re.compile("^(https?://|file://|data:)")
@@ -114,6 +121,9 @@ async def content_from_doc(doc: MemoryBankDocument) -> str:
         else:
             async with httpx.AsyncClient() as client:
                 r = await client.get(doc.content)
+            if doc.mime_type == "application/pdf":
+                return parse_pdf(r.content)
+            else:
                 return r.text
 
     return interleaved_text_media_as_str(doc.content)