Support data: in URL for memory. Add ootb support for pdfs (#67)

* support data: in URL for memory. Add ootb support for pdfs

* moved utility to common and updated data_url parsing logic

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
This commit is contained in:
Hardik Shah 2024-09-12 13:00:21 -07:00 committed by GitHub
parent 736092f6bc
commit 1d0e91d802
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 112 additions and 12 deletions

View file

@ -83,13 +83,12 @@ class AgenticSystemClient(AgenticSystem):
if line.startswith("data:"): if line.startswith("data:"):
data = line[len("data: ") :] data = line[len("data: ") :]
try: try:
if "error" in data: jdata = json.loads(data)
if "error" in jdata:
cprint(data, "red") cprint(data, "red")
continue continue
yield AgenticSystemTurnResponseStreamChunk( yield AgenticSystemTurnResponseStreamChunk(**jdata)
**json.loads(data)
)
except Exception as e: except Exception as e:
print(data) print(data)
print(f"Error with parsing or validation: {e}") print(f"Error with parsing or validation: {e}")

View file

@ -8,7 +8,6 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import List, Optional, Protocol from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
@ -23,7 +22,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
class MemoryBankDocument(BaseModel): class MemoryBankDocument(BaseModel):
document_id: str document_id: str
content: InterleavedTextMedia | URL content: InterleavedTextMedia | URL
mime_type: str mime_type: str | None = None
metadata: Dict[str, Any] = Field(default_factory=dict) metadata: Dict[str, Any] = Field(default_factory=dict)

View file

@ -5,15 +5,19 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import json
from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import fire import fire
import httpx import httpx
from termcolor import cprint
from llama_toolchain.core.datatypes import RemoteProviderConfig from llama_toolchain.core.datatypes import RemoteProviderConfig
from .api import * # noqa: F403 from .api import * # noqa: F403
from .common.file_utils import data_url_from_file
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory: async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory:
@ -120,7 +124,7 @@ async def run_main(host: str, port: int, stream: bool):
overlap_size_in_tokens=64, overlap_size_in_tokens=64,
), ),
) )
print(bank) cprint(json.dumps(bank.dict(), indent=4), "green")
retrieved_bank = await client.get_memory_bank(bank.bank_id) retrieved_bank = await client.get_memory_bank(bank.bank_id)
assert retrieved_bank is not None assert retrieved_bank is not None
@ -145,6 +149,16 @@ async def run_main(host: str, port: int, stream: bool):
for i, url in enumerate(urls) for i, url in enumerate(urls)
] ]
this_dir = os.path.dirname(__file__)
files = [Path(this_dir).parent.parent / "CONTRIBUTING.md"]
documents += [
MemoryBankDocument(
document_id=f"num-{i}",
content=data_url_from_file(path),
)
for i, path in enumerate(files)
]
# insert some documents # insert some documents
await client.insert_documents( await client.insert_documents(
bank_id=bank.bank_id, bank_id=bank.bank_id,

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import mimetypes
import os
from llama_models.llama3.api.datatypes import URL
def data_url_from_file(file_path: str) -> URL:
    """Read a local file and encode its contents as a base64 ``data:`` URL.

    Args:
        file_path: Path to the file on disk.

    Returns:
        A URL whose ``uri`` is ``data:<mime_type>;base64,<payload>``.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")
    mime_type, _ = mimetypes.guess_type(file_path)
    # guess_type() returns None for unknown extensions; fall back to the
    # generic binary type so we never emit a literal "data:None;base64,..." URL.
    if mime_type is None:
        mime_type = "application/octet-stream"

    data_url = f"data:{mime_type};base64,{base64_content}"

    return URL(uri=data_url)

View file

@ -3,21 +3,25 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import base64
import io
import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from urllib.parse import unquote
import chardet
import httpx import httpx
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
from pypdf import PdfReader
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_toolchain.memory.api import * # noqa: F403 from llama_toolchain.memory.api import * # noqa: F403
ALL_MINILM_L6_V2_DIMENSION = 384 ALL_MINILM_L6_V2_DIMENSION = 384
EMBEDDING_MODEL = None EMBEDDING_MODEL = None
@ -36,8 +40,64 @@ def get_embedding_model() -> "SentenceTransformer":
return EMBEDDING_MODEL return EMBEDDING_MODEL
def parse_data_url(data_url: str):
    """Split an RFC 2397 ``data:`` URL into its components.

    Returns a dict with the regex groups ``mimetype``, ``charset``,
    ``encoding``, ``base64`` and ``data``, plus a derived boolean
    ``is_base64`` indicating whether the payload is base64-encoded.

    Raises:
        ValueError: If the string is not a well-formed data URL.
    """
    pattern = (
        r"^data:"
        r"(?P<mimetype>[\w/\-+.]+)"
        r"(?P<charset>;charset=(?P<encoding>[\w-]+))?"
        r"(?P<base64>;base64)?"
        r",(?P<data>.*)$"
    )
    # DOTALL so a payload containing newlines still matches <data>.
    match = re.match(pattern, data_url, flags=re.DOTALL)
    if match is None:
        raise ValueError("Invalid Data URL format")

    parts = match.groupdict()
    parts["is_base64"] = parts["base64"] is not None
    return parts
def content_from_data(data_url: str) -> str:
    """Decode a ``data:`` URL into a text string.

    ``text/*`` payloads are decoded with the declared (or sniffed) charset;
    ``application/pdf`` payloads are parsed page-by-page with pypdf. Any
    other mime type is unsupported and yields an empty string.
    """
    parts = parse_data_url(data_url)
    data = parts["data"]

    # Normalize the payload to raw bytes regardless of transfer encoding.
    if parts["is_base64"]:
        data = base64.b64decode(data)
    else:
        data = unquote(data)
        encoding = parts["encoding"] or "utf-8"
        data = data.encode(encoding)

    encoding = parts["encoding"]
    if not encoding:
        # No charset declared in the URL; sniff it from the raw bytes.
        # detect() may report encoding=None (e.g. empty payload), so keep a
        # utf-8 fallback — decode(None) would raise TypeError.
        detected = chardet.detect(data)
        encoding = detected["encoding"] or "utf-8"

    mime_type = parts["mimetype"]
    mime_category = mime_type.split("/")[0]
    if mime_category == "text":
        # For text-based files (including CSV, MD)
        return data.decode(encoding)
    elif mime_type == "application/pdf":
        # PDFs have no single text encoding; extract per-page text instead.
        # extract_text() can return None for a page, hence the "or ''".
        pdf_bytes = io.BytesIO(data)
        pdf_reader = PdfReader(pdf_bytes)
        return "\n".join([page.extract_text() or "" for page in pdf_reader.pages])
    else:
        # NOTE(review): the original used cprint() here, but termcolor is not
        # among this module's visible imports, so that path would raise
        # NameError; plain print keeps the warning without the crash.
        print("Could not extract content from data_url properly.")
        return ""
async def content_from_doc(doc: MemoryBankDocument) -> str: async def content_from_doc(doc: MemoryBankDocument) -> str:
if isinstance(doc.content, URL): if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
return content_from_data(doc.content.uri)
else:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri) r = await client.get(doc.content.uri)
return r.text return r.text

View file

@ -10,6 +10,8 @@ from llama_toolchain.core.datatypes import * # noqa: F403
EMBEDDING_DEPS = [ EMBEDDING_DEPS = [
"blobfile", "blobfile",
"chardet",
"PdfReader",
"sentence-transformers", "sentence-transformers",
] ]