Support data: URLs for memory. Add out-of-the-box (ootb) support for PDFs

This commit is contained in:
Hardik Shah 2024-09-12 10:54:55 -07:00
parent a11d92601b
commit 5f49dce839
5 changed files with 82 additions and 12 deletions

View file

@ -3,21 +3,23 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import io
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from urllib.parse import unquote_to_bytes

import chardet
import httpx
import numpy as np
from numpy.typing import NDArray
from pypdf import PdfReader

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_toolchain.memory.api import *  # noqa: F403
# Presumably the vector size produced by the all-MiniLM-L6-v2 embedding
# model -- TODO confirm against the code that consumes this constant.
ALL_MINILM_L6_V2_DIMENSION = 384
# Module-level slot returned by get_embedding_model(); starts out unset (None).
EMBEDDING_MODEL = None
@ -36,11 +38,48 @@ def get_embedding_model() -> "SentenceTransformer":
return EMBEDDING_MODEL
def content_from_data(data_url: str) -> str:
    """Extract textual content from an RFC 2397 ``data:`` URL.

    Supported payloads are ``text/*`` (decoded with the declared charset, or
    a detected one when none is declared) and ``application/pdf`` (text of
    every page joined with newlines).

    Args:
        data_url: URL of the form
            ``data:<mime>[;charset=<cs>][;base64],<payload>``.

    Returns:
        The decoded text, or ``""`` when the MIME type is not supported.

    Raises:
        ValueError: If ``data_url`` does not match the expected format.
    """
    # DOTALL so a payload containing line breaks is captured in full instead
    # of being silently truncated at the first newline.
    match = re.match(
        r"data:([^;,]+)(?:;charset=([^;,]+))?(;base64)?,(.+)",
        data_url,
        re.DOTALL,
    )
    if not match:
        raise ValueError("Invalid Data URL format")

    mime_type, charset, base64_flag, payload = match.groups()
    # MIME types are case-insensitive.
    mime_type = mime_type.lower()

    # Decide on base64 from the parsed header only; searching the whole URL
    # would misfire when the payload itself contains ";base64,".
    if base64_flag:
        data = base64.b64decode(payload)
    else:
        # Non-base64 payloads are percent-encoded per RFC 2397.
        data = unquote_to_bytes(payload)

    if mime_type.split("/")[0] == "text":
        # Text-based files (including CSV, Markdown).
        if charset:
            return data.decode(charset)
        # No charset declared: detect it, falling back to UTF-8 because
        # chardet reports ``None`` when it cannot make a guess.
        detected = chardet.detect(data)
        return data.decode(detected["encoding"] or "utf-8")

    if mime_type == "application/pdf":
        # Pull the text layer out of each page.
        pdf_reader = PdfReader(io.BytesIO(data))
        return "\n".join(page.extract_text() for page in pdf_reader.pages)

    cprint("Could not extract content from data_url properly.", color="red")
    return ""
async def content_from_doc(doc: MemoryBankDocument) -> str:
    """Return the textual content of a memory-bank document.

    When ``doc.content`` is a ``URL``, a ``data:`` URI is decoded in-process
    via ``content_from_data``; any other URI is fetched with an HTTP GET and
    the response text is returned. Non-URL content is flattened with
    ``interleaved_text_media_as_str``.

    Args:
        doc: The document whose ``content`` should be resolved to text.

    Returns:
        The document content as a string.
    """
    if isinstance(doc.content, URL):
        uri = doc.content.uri
        # Check for inline data first: a data: URI is not fetchable over
        # HTTP, so it must never reach the httpx client.
        if uri.startswith("data:"):
            return content_from_data(uri)

        async with httpx.AsyncClient() as client:
            r = await client.get(uri)
            return r.text

    return interleaved_text_media_as_str(doc.content)