forked from phoenix-oss/llama-stack-mirror
fix: add check for interleavedContent (#1973)
# What does this PR do? Checks for RAGDocument of type InterleavedContent I noticed when stepping through the code that the supported types for `RAGDocument` included `InterleavedContent` as a content type. This type is not checked against before putting the `doc.content` is regex matched against. This would cause a runtime error. This change adds an explicit check for type. The only other part that I'm unclear on is how to handle the `ImageContent` type since this would always just return `<image>` which seems like an undesired behavior. Should the `InterleavedContent` type be removed from `RAGDocument` and replaced with `URI | str`? ## Test Plan [//]: # (## Documentation) --------- Signed-off-by: Kevin <kpostlet@redhat.com>
This commit is contained in:
parent
1a529705da
commit
a57985eeac
4 changed files with 170 additions and 17 deletions
|
@ -118,27 +118,25 @@ async def content_from_doc(doc: RAGDocument) -> str:
|
|||
if isinstance(doc.content, URL):
|
||||
if doc.content.uri.startswith("data:"):
|
||||
return content_from_data(doc.content.uri)
|
||||
else:
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(doc.content.uri)
|
||||
if doc.mime_type == "application/pdf":
|
||||
return parse_pdf(r.content)
|
||||
else:
|
||||
return r.text
|
||||
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
if pattern.match(doc.content):
|
||||
if doc.content.startswith("data:"):
|
||||
return content_from_data(doc.content)
|
||||
else:
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(doc.content.uri)
|
||||
if doc.mime_type == "application/pdf":
|
||||
return parse_pdf(r.content)
|
||||
return r.text
|
||||
elif isinstance(doc.content, str):
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
if pattern.match(doc.content):
|
||||
if doc.content.startswith("data:"):
|
||||
return content_from_data(doc.content)
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(doc.content)
|
||||
if doc.mime_type == "application/pdf":
|
||||
return parse_pdf(r.content)
|
||||
else:
|
||||
return r.text
|
||||
|
||||
return interleaved_content_as_str(doc.content)
|
||||
return r.text
|
||||
return doc.content
|
||||
else:
|
||||
# will raise ValueError if the content is not List[InterleavedContent] or InterleavedContent
|
||||
return interleaved_content_as_str(doc.content)
|
||||
|
||||
|
||||
def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> list[Chunk]:
|
||||
|
|
5
tests/unit/providers/utils/__init__.py
Normal file
5
tests/unit/providers/utils/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
5
tests/unit/providers/utils/memory/__init__.py
Normal file
5
tests/unit/providers/utils/memory/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
145
tests/unit/providers/utils/memory/test_vector_store.py
Normal file
145
tests/unit/providers/utils/memory/test_vector_store.py
Normal file
|
@ -0,0 +1,145 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, TextContentItem
|
||||
from llama_stack.apis.tools import RAGDocument
|
||||
from llama_stack.providers.utils.memory.vector_store import content_from_doc
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_from_doc_with_url():
|
||||
"""Test extracting content from RAGDocument with URL content."""
|
||||
mock_url = URL(uri="https://example.com")
|
||||
mock_doc = RAGDocument(document_id="foo", content=mock_url)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Sample content from URL"
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.get.return_value = mock_response
|
||||
mock_client.return_value.__aenter__.return_value = mock_instance
|
||||
|
||||
result = await content_from_doc(mock_doc)
|
||||
|
||||
assert result == "Sample content from URL"
|
||||
mock_instance.get.assert_called_once_with(mock_url.uri)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_from_doc_with_pdf_url():
|
||||
"""Test extracting content from RAGDocument with URL pointing to a PDF."""
|
||||
mock_url = URL(uri="https://example.com/document.pdf")
|
||||
mock_doc = RAGDocument(document_id="foo", content=mock_url, mime_type="application/pdf")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"PDF binary data"
|
||||
|
||||
with (
|
||||
patch("httpx.AsyncClient") as mock_client,
|
||||
patch("llama_stack.providers.utils.memory.vector_store.parse_pdf") as mock_parse_pdf,
|
||||
):
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.get.return_value = mock_response
|
||||
mock_client.return_value.__aenter__.return_value = mock_instance
|
||||
mock_parse_pdf.return_value = "Extracted PDF content"
|
||||
|
||||
result = await content_from_doc(mock_doc)
|
||||
|
||||
assert result == "Extracted PDF content"
|
||||
mock_instance.get.assert_called_once_with(mock_url.uri)
|
||||
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_from_doc_with_data_url():
|
||||
"""Test extracting content from RAGDocument with data URL content."""
|
||||
data_url = "data:text/plain;base64,SGVsbG8gV29ybGQ=" # "Hello World" base64 encoded
|
||||
mock_url = URL(uri=data_url)
|
||||
mock_doc = RAGDocument(document_id="foo", content=mock_url)
|
||||
|
||||
with patch("llama_stack.providers.utils.memory.vector_store.content_from_data") as mock_content_from_data:
|
||||
mock_content_from_data.return_value = "Hello World"
|
||||
|
||||
result = await content_from_doc(mock_doc)
|
||||
|
||||
assert result == "Hello World"
|
||||
mock_content_from_data.assert_called_once_with(data_url)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_from_doc_with_string():
|
||||
"""Test extracting content from RAGDocument with string content."""
|
||||
content_string = "This is plain text content"
|
||||
mock_doc = RAGDocument(document_id="foo", content=content_string)
|
||||
|
||||
result = await content_from_doc(mock_doc)
|
||||
|
||||
assert result == content_string
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_from_doc_with_string_url():
|
||||
"""Test extracting content from RAGDocument with string URL content."""
|
||||
url_string = "https://example.com"
|
||||
mock_doc = RAGDocument(document_id="foo", content=url_string)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Sample content from URL string"
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.get.return_value = mock_response
|
||||
mock_client.return_value.__aenter__.return_value = mock_instance
|
||||
|
||||
result = await content_from_doc(mock_doc)
|
||||
|
||||
assert result == "Sample content from URL string"
|
||||
mock_instance.get.assert_called_once_with(url_string)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_from_doc_with_string_pdf_url():
|
||||
"""Test extracting content from RAGDocument with string URL pointing to a PDF."""
|
||||
url_string = "https://example.com/document.pdf"
|
||||
mock_doc = RAGDocument(document_id="foo", content=url_string, mime_type="application/pdf")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"PDF binary data"
|
||||
|
||||
with (
|
||||
patch("httpx.AsyncClient") as mock_client,
|
||||
patch("llama_stack.providers.utils.memory.vector_store.parse_pdf") as mock_parse_pdf,
|
||||
):
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.get.return_value = mock_response
|
||||
mock_client.return_value.__aenter__.return_value = mock_instance
|
||||
mock_parse_pdf.return_value = "Extracted PDF content from string URL"
|
||||
|
||||
result = await content_from_doc(mock_doc)
|
||||
|
||||
assert result == "Extracted PDF content from string URL"
|
||||
mock_instance.get.assert_called_once_with(url_string)
|
||||
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_from_doc_with_interleaved_content():
|
||||
"""Test extracting content from RAGDocument with InterleavedContent (the new case added in the commit)."""
|
||||
interleaved_content = [TextContentItem(text="First item"), TextContentItem(text="Second item")]
|
||||
mock_doc = RAGDocument(document_id="foo", content=interleaved_content)
|
||||
|
||||
with patch("llama_stack.providers.utils.memory.vector_store.interleaved_content_as_str") as mock_interleaved:
|
||||
mock_interleaved.return_value = "First item\nSecond item"
|
||||
|
||||
result = await content_from_doc(mock_doc)
|
||||
|
||||
assert result == "First item\nSecond item"
|
||||
mock_interleaved.assert_called_once_with(interleaved_content)
|
Loading…
Add table
Add a link
Reference in a new issue