From 487e16dc3f1e6cf7e52ac2ff690f7e30de16deef Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Thu, 12 Sep 2024 11:58:04 -0700 Subject: [PATCH] moved utility to common and updated data_url parsing logic --- llama_toolchain/memory/api/api.py | 18 -------- llama_toolchain/memory/client.py | 1 + llama_toolchain/memory/common/file_utils.py | 26 +++++++++++ llama_toolchain/memory/common/vector_store.py | 45 ++++++++++++++----- 4 files changed, 60 insertions(+), 30 deletions(-) create mode 100644 llama_toolchain/memory/common/file_utils.py diff --git a/llama_toolchain/memory/api/api.py b/llama_toolchain/memory/api/api.py index a21484fc0..a26ff67ea 100644 --- a/llama_toolchain/memory/api/api.py +++ b/llama_toolchain/memory/api/api.py @@ -8,9 +8,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import base64 -import mimetypes -import os from typing import List, Optional, Protocol from llama_models.schema_utils import json_schema_type, webmethod @@ -29,21 +26,6 @@ class MemoryBankDocument(BaseModel): metadata: Dict[str, Any] = Field(default_factory=dict) -def data_url_from_file(file_path: str) -> URL: - if not os.path.exists(file_path): - raise FileNotFoundError(f"File not found: {file_path}") - - with open(file_path, "rb") as file: - file_content = file.read() - - base64_content = base64.b64encode(file_content).decode("utf-8") - mime_type, _ = mimetypes.guess_type(file_path) - - data_url = f"data:{mime_type};base64,{base64_content}" - - return URL(uri=data_url) - - @json_schema_type class MemoryBankType(Enum): vector = "vector" diff --git a/llama_toolchain/memory/client.py b/llama_toolchain/memory/client.py index 96263fed6..5f74219da 100644 --- a/llama_toolchain/memory/client.py +++ b/llama_toolchain/memory/client.py @@ -17,6 +17,7 @@ from termcolor import cprint from llama_toolchain.core.datatypes import RemoteProviderConfig from .api import * # noqa: F403 +from .common.file_utils import data_url_from_file async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory: diff --git a/llama_toolchain/memory/common/file_utils.py b/llama_toolchain/memory/common/file_utils.py new file mode 100644 index 000000000..bc4462fa0 --- /dev/null +++ b/llama_toolchain/memory/common/file_utils.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import base64 +import mimetypes +import os + +from llama_models.llama3.api.datatypes import URL + + +def data_url_from_file(file_path: str) -> URL: + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as file: + file_content = file.read() + + base64_content = base64.b64encode(file_content).decode("utf-8") + mime_type, _ = mimetypes.guess_type(file_path) + + data_url = f"data:{mime_type};base64,{base64_content}" + + return URL(uri=data_url) diff --git a/llama_toolchain/memory/common/vector_store.py b/llama_toolchain/memory/common/vector_store.py index a9f0b8020..baa3fbf21 100644 --- a/llama_toolchain/memory/common/vector_store.py +++ b/llama_toolchain/memory/common/vector_store.py @@ -3,11 +3,13 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import base64 import io import re from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Dict, List, Optional +from urllib.parse import unquote import chardet import httpx @@ -38,28 +40,47 @@ def get_embedding_model() -> "SentenceTransformer": return EMBEDDING_MODEL -def content_from_data(data_url: str) -> str: - match = re.match(r"data:([^;,]+)(?:;charset=([^;,]+))?(?:;base64)?,(.+)", data_url) +def parse_data_url(data_url: str): + data_url_pattern = re.compile( + r"^" + r"data:" + r"(?P[\w/\-+.]+)" + r"(?P;charset=(?P[\w-]+))?" + r"(?P;base64)?" + r",(?P.*)" + r"$", + re.DOTALL, + ) + match = data_url_pattern.match(data_url) if not match: raise ValueError("Invalid Data URL format") - mime_type, charset, data = match.groups() + parts = match.groupdict() + parts["is_base64"] = bool(parts["base64"]) + return parts - if ";base64," in data_url: + +def content_from_data(data_url: str) -> str: + parts = parse_data_url(data_url) + data = parts["data"] + + if parts["is_base64"]: data = base64.b64decode(data) else: - data = data.encode("utf-8") + data = unquote(data) + encoding = parts["encoding"] or "utf-8" + data = data.encode(encoding) + encoding = parts["encoding"] + if not encoding: + detected = chardet.detect(data) + encoding = detected["encoding"] + + mime_type = parts["mimetype"] mime_category = mime_type.split("/")[0] - if mime_category == "text": # For text-based files (including CSV, MD) - if charset: - return data.decode(charset) - else: - # Try to detect encoding if charset is not specified - detected = chardet.detect(data) - return data.decode(detected["encoding"]) + return data.decode(encoding) elif mime_type == "application/pdf": # For PDF and DOC/DOCX files, we can't reliably convert to string)