Moved utility to common and updated data_url parsing logic

This commit is contained in:
Hardik Shah 2024-09-12 11:58:04 -07:00
parent 5f49dce839
commit 487e16dc3f
4 changed files with 60 additions and 30 deletions

View file

@ -8,9 +8,6 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import base64
import mimetypes
import os
from typing import List, Optional, Protocol from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
@ -29,21 +26,6 @@ class MemoryBankDocument(BaseModel):
metadata: Dict[str, Any] = Field(default_factory=dict) metadata: Dict[str, Any] = Field(default_factory=dict)
def data_url_from_file(file_path: str) -> URL:
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, "rb") as file:
file_content = file.read()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type, _ = mimetypes.guess_type(file_path)
data_url = f"data:{mime_type};base64,{base64_content}"
return URL(uri=data_url)
@json_schema_type @json_schema_type
class MemoryBankType(Enum): class MemoryBankType(Enum):
vector = "vector" vector = "vector"

View file

@ -17,6 +17,7 @@ from termcolor import cprint
from llama_toolchain.core.datatypes import RemoteProviderConfig from llama_toolchain.core.datatypes import RemoteProviderConfig
from .api import * # noqa: F403 from .api import * # noqa: F403
from .common.file_utils import data_url_from_file
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory: async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory:

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import mimetypes
import os
from llama_models.llama3.api.datatypes import URL
def data_url_from_file(file_path: str) -> URL:
    """Read a local file and return its contents as a base64 ``data:`` URL.

    Args:
        file_path: Path to an existing file on disk.

    Returns:
        URL: a data URL of the form ``data:<mime_type>;base64,<payload>``.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")
    mime_type, _ = mimetypes.guess_type(file_path)
    # guess_type() returns (None, None) for unrecognized extensions; fall back
    # to the generic binary MIME type instead of embedding the literal "None".
    if mime_type is None:
        mime_type = "application/octet-stream"

    data_url = f"data:{mime_type};base64,{base64_content}"
    return URL(uri=data_url)

View file

@ -3,11 +3,13 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import base64
import io import io
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from urllib.parse import unquote
import chardet import chardet
import httpx import httpx
@ -38,28 +40,47 @@ def get_embedding_model() -> "SentenceTransformer":
return EMBEDDING_MODEL return EMBEDDING_MODEL
def parse_data_url(data_url: str):
    """Parse an RFC 2397 ``data:`` URL into its components.

    Args:
        data_url: A data URL string, e.g. ``data:text/plain;base64,SGk=``.

    Returns:
        dict with keys ``mimetype``, ``charset``, ``encoding``, ``base64``,
        ``data`` (the raw, still-encoded payload), plus ``is_base64`` (bool)
        indicating whether ``data`` is base64-encoded.

    Raises:
        ValueError: if the string does not match the data-URL grammar.
    """
    data_url_pattern = re.compile(
        r"^"
        r"data:"
        r"(?P<mimetype>[\w/\-+.]+)"
        r"(?P<charset>;charset=(?P<encoding>[\w-]+))?"
        r"(?P<base64>;base64)?"
        r",(?P<data>.*)"
        r"$",
        # DOTALL lets the payload span newlines (base64 payloads may be wrapped).
        re.DOTALL,
    )
    match = data_url_pattern.match(data_url)
    if not match:
        raise ValueError("Invalid Data URL format")

    parts = match.groupdict()
    parts["is_base64"] = bool(parts["base64"])
    return parts
if ";base64," in data_url:
def content_from_data(data_url: str) -> str:
parts = parse_data_url(data_url)
data = parts["data"]
if parts["is_base64"]:
data = base64.b64decode(data) data = base64.b64decode(data)
else: else:
data = data.encode("utf-8") data = unquote(data)
encoding = parts["encoding"] or "utf-8"
data = data.encode(encoding)
encoding = parts["encoding"]
if not encoding:
detected = chardet.detect(data)
encoding = detected["encoding"]
mime_type = parts["mimetype"]
mime_category = mime_type.split("/")[0] mime_category = mime_type.split("/")[0]
if mime_category == "text": if mime_category == "text":
# For text-based files (including CSV, MD) # For text-based files (including CSV, MD)
if charset: return data.decode(encoding)
return data.decode(charset)
else:
# Try to detect encoding if charset is not specified
detected = chardet.detect(data)
return data.decode(detected["encoding"])
elif mime_type == "application/pdf": elif mime_type == "application/pdf":
# For PDF and DOC/DOCX files, we can't reliably convert to string) # For PDF and DOC/DOCX files, we can't reliably convert to string)