From 1d0e91d802b54dc4ed584c87162b906700302437 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Thu, 12 Sep 2024 13:00:21 -0700 Subject: [PATCH] Support `data:` in URL for memory. Add ootb support for pdfs (#67) * support data: in URL for memory. Add ootb support for pdfs * moved utility to common and updated data_url parsing logic --------- Co-authored-by: Hardik Shah --- llama_toolchain/agentic_system/client.py | 7 +- llama_toolchain/memory/api/api.py | 3 +- llama_toolchain/memory/client.py | 16 ++++- llama_toolchain/memory/common/file_utils.py | 26 +++++++ llama_toolchain/memory/common/vector_store.py | 70 +++++++++++++++++-- llama_toolchain/memory/providers.py | 2 + 6 files changed, 112 insertions(+), 12 deletions(-) create mode 100644 llama_toolchain/memory/common/file_utils.py diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py index b47e402f0..e30e90376 100644 --- a/llama_toolchain/agentic_system/client.py +++ b/llama_toolchain/agentic_system/client.py @@ -83,13 +83,12 @@ class AgenticSystemClient(AgenticSystem): if line.startswith("data:"): data = line[len("data: ") :] try: - if "error" in data: + jdata = json.loads(data) + if "error" in jdata: cprint(data, "red") continue - yield AgenticSystemTurnResponseStreamChunk( - **json.loads(data) - ) + yield AgenticSystemTurnResponseStreamChunk(**jdata) except Exception as e: print(data) print(f"Error with parsing or validation: {e}") diff --git a/llama_toolchain/memory/api/api.py b/llama_toolchain/memory/api/api.py index 70c7aa7ec..a26ff67ea 100644 --- a/llama_toolchain/memory/api/api.py +++ b/llama_toolchain/memory/api/api.py @@ -8,7 +8,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - from typing import List, Optional, Protocol from llama_models.schema_utils import json_schema_type, webmethod @@ -23,7 +22,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403 class MemoryBankDocument(BaseModel): document_id: str content: InterleavedTextMedia | URL - mime_type: str + mime_type: str | None = None metadata: Dict[str, Any] = Field(default_factory=dict) diff --git a/llama_toolchain/memory/client.py b/llama_toolchain/memory/client.py index 4401276fa..5f74219da 100644 --- a/llama_toolchain/memory/client.py +++ b/llama_toolchain/memory/client.py @@ -5,15 +5,19 @@ # the root directory of this source tree. import asyncio +import json +from pathlib import Path from typing import Any, Dict, List, Optional import fire import httpx +from termcolor import cprint from llama_toolchain.core.datatypes import RemoteProviderConfig from .api import * # noqa: F403 +from .common.file_utils import data_url_from_file async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory: @@ -120,7 +124,7 @@ async def run_main(host: str, port: int, stream: bool): overlap_size_in_tokens=64, ), ) - print(bank) + cprint(json.dumps(bank.dict(), indent=4), "green") retrieved_bank = await client.get_memory_bank(bank.bank_id) assert retrieved_bank is not None @@ -145,6 +149,16 @@ async def run_main(host: str, port: int, stream: bool): for i, url in enumerate(urls) ] + this_dir = os.path.dirname(__file__) + files = [Path(this_dir).parent.parent / "CONTRIBUTING.md"] + documents += [ + MemoryBankDocument( + document_id=f"num-{i}", + content=data_url_from_file(path), + ) + for i, path in enumerate(files) + ] + # insert some documents await client.insert_documents( bank_id=bank.bank_id, diff --git a/llama_toolchain/memory/common/file_utils.py b/llama_toolchain/memory/common/file_utils.py new file mode 100644 index 000000000..bc4462fa0 --- /dev/null +++ b/llama_toolchain/memory/common/file_utils.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import base64 +import mimetypes +import os + +from llama_models.llama3.api.datatypes import URL + + +def data_url_from_file(file_path: str) -> URL: + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as file: + file_content = file.read() + + base64_content = base64.b64encode(file_content).decode("utf-8") + mime_type, _ = mimetypes.guess_type(file_path) + + data_url = f"data:{mime_type};base64,{base64_content}" + + return URL(uri=data_url) diff --git a/llama_toolchain/memory/common/vector_store.py b/llama_toolchain/memory/common/vector_store.py index 154deea18..baa3fbf21 100644 --- a/llama_toolchain/memory/common/vector_store.py +++ b/llama_toolchain/memory/common/vector_store.py @@ -3,21 +3,25 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - +import base64 +import io +import re from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Dict, List, Optional +from urllib.parse import unquote +import chardet import httpx import numpy as np from numpy.typing import NDArray +from pypdf import PdfReader from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.tokenizer import Tokenizer from llama_toolchain.memory.api import * # noqa: F403 - ALL_MINILM_L6_V2_DIMENSION = 384 EMBEDDING_MODEL = None @@ -36,11 +40,67 @@ def get_embedding_model() -> "SentenceTransformer": return EMBEDDING_MODEL +def parse_data_url(data_url: str): + data_url_pattern = re.compile( + r"^" + r"data:" + r"(?P[\w/\-+.]+)" + r"(?P;charset=(?P[\w-]+))?" + r"(?P;base64)?" + r",(?P.*)" + r"$", + re.DOTALL, + ) + match = data_url_pattern.match(data_url) + if not match: + raise ValueError("Invalid Data URL format") + + parts = match.groupdict() + parts["is_base64"] = bool(parts["base64"]) + return parts + + +def content_from_data(data_url: str) -> str: + parts = parse_data_url(data_url) + data = parts["data"] + + if parts["is_base64"]: + data = base64.b64decode(data) + else: + data = unquote(data) + encoding = parts["encoding"] or "utf-8" + data = data.encode(encoding) + + encoding = parts["encoding"] + if not encoding: + detected = chardet.detect(data) + encoding = detected["encoding"] + + mime_type = parts["mimetype"] + mime_category = mime_type.split("/")[0] + if mime_category == "text": + # For text-based files (including CSV, MD) + return data.decode(encoding) + + elif mime_type == "application/pdf": + # For PDF and DOC/DOCX files, we can't reliably convert to string) + pdf_bytes = io.BytesIO(data) + pdf_reader = PdfReader(pdf_bytes) + return "\n".join([page.extract_text() for page in pdf_reader.pages]) + + else: + cprint("Could not extract content from data_url properly.", color="red") + return "" + + async def content_from_doc(doc: MemoryBankDocument) -> str: if isinstance(doc.content, URL): - async with httpx.AsyncClient() as client: - r = await client.get(doc.content.uri) - return r.text + if doc.content.uri.startswith("data:"): + return content_from_data(doc.content.uri) + else: + async with httpx.AsyncClient() as client: + r = await client.get(doc.content.uri) + return r.text return interleaved_text_media_as_str(doc.content) diff --git a/llama_toolchain/memory/providers.py b/llama_toolchain/memory/providers.py index 525f947a0..adfff2e71 100644 --- a/llama_toolchain/memory/providers.py +++ b/llama_toolchain/memory/providers.py @@ -10,6 +10,8 @@ from llama_toolchain.core.datatypes import * # noqa: F403 EMBEDDING_DEPS = [ "blobfile", + "chardet", + "PdfReader", "sentence-transformers", ]