From 5f49dce8392c89c2ec09c6b04cddb40e55f49c3c Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Thu, 12 Sep 2024 10:54:55 -0700 Subject: [PATCH] support data: in URL for memory. Add ootb support for pdfs --- llama_toolchain/agentic_system/client.py | 7 ++- llama_toolchain/memory/api/api.py | 21 +++++++- llama_toolchain/memory/client.py | 15 +++++- llama_toolchain/memory/common/vector_store.py | 49 +++++++++++++++++-- llama_toolchain/memory/providers.py | 2 + 5 files changed, 82 insertions(+), 12 deletions(-) diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py index b47e402f0..e30e90376 100644 --- a/llama_toolchain/agentic_system/client.py +++ b/llama_toolchain/agentic_system/client.py @@ -83,13 +83,12 @@ class AgenticSystemClient(AgenticSystem): if line.startswith("data:"): data = line[len("data: ") :] try: - if "error" in data: + jdata = json.loads(data) + if "error" in jdata: cprint(data, "red") continue - yield AgenticSystemTurnResponseStreamChunk( - **json.loads(data) - ) + yield AgenticSystemTurnResponseStreamChunk(**jdata) except Exception as e: print(data) print(f"Error with parsing or validation: {e}") diff --git a/llama_toolchain/memory/api/api.py b/llama_toolchain/memory/api/api.py index 70c7aa7ec..a21484fc0 100644 --- a/llama_toolchain/memory/api/api.py +++ b/llama_toolchain/memory/api/api.py @@ -8,7 +8,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
- +import base64 +import mimetypes +import os from typing import List, Optional, Protocol from llama_models.schema_utils import json_schema_type, webmethod @@ -23,10 +25,25 @@ from llama_models.llama3.api.datatypes import * # noqa: F403 class MemoryBankDocument(BaseModel): document_id: str content: InterleavedTextMedia | URL - mime_type: str + mime_type: str | None = None metadata: Dict[str, Any] = Field(default_factory=dict) +def data_url_from_file(file_path: str) -> URL: + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as file: + file_content = file.read() + + base64_content = base64.b64encode(file_content).decode("utf-8") + mime_type, _ = mimetypes.guess_type(file_path) + + data_url = f"data:{mime_type or 'application/octet-stream'};base64,{base64_content}" + + return URL(uri=data_url) + + @json_schema_type class MemoryBankType(Enum): vector = "vector" diff --git a/llama_toolchain/memory/client.py b/llama_toolchain/memory/client.py index 4401276fa..96263fed6 100644 --- a/llama_toolchain/memory/client.py +++ b/llama_toolchain/memory/client.py @@ -5,11 +5,14 @@ # the root directory of this source tree. 
import asyncio +import json +from pathlib import Path from typing import Any, Dict, List, Optional import fire import httpx +from termcolor import cprint from llama_toolchain.core.datatypes import RemoteProviderConfig @@ -120,7 +123,7 @@ async def run_main(host: str, port: int, stream: bool): overlap_size_in_tokens=64, ), ) - print(bank) + cprint(json.dumps(bank.dict(), indent=4), "green") retrieved_bank = await client.get_memory_bank(bank.bank_id) assert retrieved_bank is not None @@ -145,6 +148,16 @@ for i, url in enumerate(urls) ] + this_dir = Path(__file__).parent + files = [Path(this_dir).parent.parent / "CONTRIBUTING.md"] + documents += [ + MemoryBankDocument( + document_id=f"num-{i}", + content=data_url_from_file(path), + ) + for i, path in enumerate(files) + ] + # insert some documents await client.insert_documents( bank_id=bank.bank_id, diff --git a/llama_toolchain/memory/common/vector_store.py b/llama_toolchain/memory/common/vector_store.py index 154deea18..a9f0b8020 100644 --- a/llama_toolchain/memory/common/vector_store.py +++ b/llama_toolchain/memory/common/vector_store.py @@ -3,21 +3,25 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
- +import base64 +import io +import re from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Dict, List, Optional +import chardet import httpx import numpy as np from numpy.typing import NDArray +from pypdf import PdfReader +from termcolor import cprint from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.tokenizer import Tokenizer from llama_toolchain.memory.api import * # noqa: F403 - ALL_MINILM_L6_V2_DIMENSION = 384 EMBEDDING_MODEL = None @@ -36,11 +40,48 @@ def get_embedding_model() -> "SentenceTransformer": return EMBEDDING_MODEL +def content_from_data(data_url: str) -> str: + match = re.match(r"data:([^;,]+)(?:;charset=([^;,]+))?(?:;base64)?,(.+)", data_url) + if not match: + raise ValueError("Invalid Data URL format") + + mime_type, charset, data = match.groups() + + if ";base64," in data_url: + data = base64.b64decode(data) + else: + data = data.encode("utf-8") + + mime_category = mime_type.split("/")[0] + + if mime_category == "text": + # For text-based files (including CSV, MD) + if charset: + return data.decode(charset) + else: + # Try to detect encoding if charset is not specified + detected = chardet.detect(data) + return data.decode(detected["encoding"]) + + elif mime_type == "application/pdf": + # For PDF and DOC/DOCX files, we can't reliably convert to string + pdf_bytes = io.BytesIO(data) + pdf_reader = PdfReader(pdf_bytes) + return "\n".join([page.extract_text() for page in pdf_reader.pages]) + + else: + cprint("Could not extract content from data_url properly.", color="red") + return "" + + async def content_from_doc(doc: MemoryBankDocument) -> str: if isinstance(doc.content, URL): - async with httpx.AsyncClient() as client: - r = await client.get(doc.content.uri) - return r.text + if doc.content.uri.startswith("data:"): + return content_from_data(doc.content.uri) + else: + async with httpx.AsyncClient() as client: + r = await client.get(doc.content.uri) + return r.text return 
interleaved_text_media_as_str(doc.content) diff --git a/llama_toolchain/memory/providers.py b/llama_toolchain/memory/providers.py index cc113d132..809aac60e 100644 --- a/llama_toolchain/memory/providers.py +++ b/llama_toolchain/memory/providers.py @@ -10,6 +10,8 @@ from llama_toolchain.core.datatypes import * # noqa: F403 EMBEDDING_DEPS = [ "blobfile", + "chardet", + "pypdf", "sentence-transformers", ]