Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 00:34:44 +00:00)
[#342] RAG - fix PDF format in vector database
This commit is contained in:
parent 2fc1c16d58
commit da035d69cf

3 changed files with 94 additions and 4 deletions
BIN  llama_stack/providers/tests/memory/fixtures/dummy.pdf (new file)
     Binary file not shown.

 80  llama_stack/providers/tests/memory/test_vector_store.py (new file)
@@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import mimetypes
import os
from pathlib import Path

import pytest
from pytest_httpx import HTTPXMock

from llama_stack.apis.memory.memory import MemoryBankDocument, URL
from llama_stack.providers.utils.memory.vector_store import content_from_doc

DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"


def read_file(file_path: str) -> bytes:
    with open(file_path, "rb") as file:
        return file.read()


def data_url_from_file(file_path: str) -> str:
    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")
    mime_type, _ = mimetypes.guess_type(file_path)

    data_url = f"data:{mime_type};base64,{base64_content}"

    return data_url


# Requires pytest-httpx - pip install pytest-httpx
class TestVectorStore:
    @pytest.mark.asyncio
    async def test_returns_content_from_pdf_data_uri(self):
        data_uri = data_url_from_file(DUMMY_PDF_PATH)
        doc = MemoryBankDocument(
            document_id="dummy",
            content=data_uri,
            mime_type="application/pdf",
            metadata={},
        )
        content = await content_from_doc(doc)
        assert content == "Dummy PDF file"

    @pytest.mark.asyncio
    async def test_downloads_pdf_and_returns_content(self, httpx_mock: HTTPXMock):
        url = "https://example.com/dummy.pdf"
        httpx_mock.add_response(url=url, content=read_file(DUMMY_PDF_PATH))
        doc = MemoryBankDocument(
            document_id="dummy",
            content=url,
            mime_type="application/pdf",
            metadata={},
        )
        content = await content_from_doc(doc)
        assert content == "Dummy PDF file"

    @pytest.mark.asyncio
    async def test_downloads_pdf_and_returns_content_with_url_object(
        self, httpx_mock: HTTPXMock
    ):
        url = "https://example.com/dummy.pdf"
        httpx_mock.add_response(url=url, content=read_file(DUMMY_PDF_PATH))
        doc = MemoryBankDocument(
            document_id="dummy",
            content=URL(
                uri=url,
            ),
            mime_type="application/pdf",
            metadata={},
        )
        content = await content_from_doc(doc)
        assert content == "Dummy PDF file"
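The first test exercises the data-URI path: data_url_from_file simply base64-encodes the fixture bytes and prepends a MIME prefix. A minimal, stdlib-only sketch of that same encoding step, using a throwaway text payload and file name instead of the PDF fixture (both are illustrative and not part of the commit):

import base64
import mimetypes

payload = b"Dummy PDF file"  # illustrative bytes; the test encodes fixtures/dummy.pdf
file_name = "dummy.txt"      # illustrative name, used only to guess the MIME type

mime_type, _ = mimetypes.guess_type(file_name)  # "text/plain" for a .txt name
base64_content = base64.b64encode(payload).decode("utf-8")
data_url = f"data:{mime_type};base64,{base64_content}"

print(data_url)  # data:text/plain;base64,RHVtbXkgUERGIGZpbGU=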
llama_stack/providers/utils/memory/vector_store.py

@@ -45,6 +45,13 @@ def get_embedding_model(model: str) -> "SentenceTransformer":
     return loaded_model


+def parse_pdf(data: bytes) -> str:
+    # For PDF and DOC/DOCX files, we can't reliably convert to string
+    pdf_bytes = io.BytesIO(data)
+    pdf_reader = PdfReader(pdf_bytes)
+    return "\n".join([page.extract_text() for page in pdf_reader.pages])
+
+
 def parse_data_url(data_url: str):
     data_url_pattern = re.compile(
         r"^"
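The hunk above factors the PDF-extraction logic into a reusable parse_pdf helper. A rough sketch of calling that helper directly against the binary fixture added in this commit; it assumes llama_stack and its PDF dependency are installed, and the expected output is the string the new tests assert:

from pathlib import Path

from llama_stack.providers.utils.memory.vector_store import parse_pdf

# Path of the binary fixture added in this commit.
pdf_path = Path("llama_stack/providers/tests/memory/fixtures/dummy.pdf")

text = parse_pdf(pdf_path.read_bytes())
print(text)  # expected to print "Dummy PDF file" for this fixture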
@@ -88,10 +95,7 @@ def content_from_data(data_url: str) -> str:
         return data.decode(encoding)

     elif mime_type == "application/pdf":
-        # For PDF and DOC/DOCX files, we can't reliably convert to string)
-        pdf_bytes = io.BytesIO(data)
-        pdf_reader = PdfReader(pdf_bytes)
-        return "\n".join([page.extract_text() for page in pdf_reader.pages])
+        return parse_pdf(data)

     else:
         log.error("Could not extract content from data_url properly.")
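With the application/pdf branch above delegating to parse_pdf, content_from_data handles PDF data URLs end to end. A small sketch combining it with the encoding shown earlier; the fixture path is the one added in this commit, and the expected result is the string asserted by the tests:

import base64

from llama_stack.providers.utils.memory.vector_store import content_from_data

with open("llama_stack/providers/tests/memory/fixtures/dummy.pdf", "rb") as f:
    pdf_bytes = f.read()

data_url = "data:application/pdf;base64," + base64.b64encode(pdf_bytes).decode("utf-8")
print(content_from_data(data_url))  # expected: "Dummy PDF file"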
@@ -105,6 +109,9 @@ async def content_from_doc(doc: MemoryBankDocument) -> str:
         else:
             async with httpx.AsyncClient() as client:
                 r = await client.get(doc.content.uri)
-            return r.text
+            if doc.mime_type == "application/pdf":
+                return parse_pdf(r.content)
+            else:
+                return r.text

     pattern = re.compile("^(https?://|file://|data:)")
@@ -114,6 +121,9 @@ async def content_from_doc(doc: MemoryBankDocument) -> str:
         else:
             async with httpx.AsyncClient() as client:
                 r = await client.get(doc.content)
-            return r.text
+            if doc.mime_type == "application/pdf":
+                return parse_pdf(r.content)
+            else:
+                return r.text

     return interleaved_text_media_as_str(doc.content)
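Taken together, content_from_doc can now ingest PDF documents referenced by a URL string or a URL object, as well as by data URI. A hedged sketch of the intended call pattern, mirroring the new tests; the document_id and URL below are placeholders, and a real run needs the target URL to actually serve a PDF (the tests mock it with pytest-httpx instead):

import asyncio

from llama_stack.apis.memory.memory import MemoryBankDocument, URL
from llama_stack.providers.utils.memory.vector_store import content_from_doc


async def main() -> None:
    doc = MemoryBankDocument(
        document_id="dummy",                                # placeholder id, as in the tests
        content=URL(uri="https://example.com/dummy.pdf"),   # placeholder URL
        mime_type="application/pdf",
        metadata={},
    )
    # application/pdf responses now route through parse_pdf(); other types still return r.text.
    print(await content_from_doc(doc))


asyncio.run(main())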