mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
moved utility to common and updated data_url parsing logic
This commit is contained in:
parent
5f49dce839
commit
487e16dc3f
4 changed files with 60 additions and 30 deletions
|
@ -8,9 +8,6 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import List, Optional, Protocol
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
@ -29,21 +26,6 @@ class MemoryBankDocument(BaseModel):
|
|||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
def data_url_from_file(file_path: str) -> URL:
    """Read a file from disk and wrap its contents in a base64 ``data:`` URL.

    Args:
        file_path: Path of the file to embed.

    Returns:
        A ``URL`` whose ``uri`` has the form
        ``data:<mime type>;base64,<base64 payload>``.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")

    mime_type, _ = mimetypes.guess_type(file_path)
    if mime_type is None:
        # guess_type() returns None for unknown extensions; interpolating that
        # would yield the invalid URL "data:None;base64,...". Fall back to the
        # generic binary media type instead.
        mime_type = "application/octet-stream"

    data_url = f"data:{mime_type};base64,{base64_content}"

    return URL(uri=data_url)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryBankType(Enum):
|
||||
vector = "vector"
|
||||
|
|
|
@ -17,6 +17,7 @@ from termcolor import cprint
|
|||
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
||||
|
||||
from .api import * # noqa: F403
|
||||
from .common.file_utils import data_url_from_file
|
||||
|
||||
|
||||
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory:
|
||||
|
|
26
llama_toolchain/memory/common/file_utils.py
Normal file
26
llama_toolchain/memory/common/file_utils.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
|
||||
from llama_models.llama3.api.datatypes import URL
|
||||
|
||||
|
||||
def data_url_from_file(file_path: str) -> URL:
    """Read a file from disk and wrap its contents in a base64 ``data:`` URL.

    Args:
        file_path: Path of the file to embed.

    Returns:
        A ``URL`` whose ``uri`` has the form
        ``data:<mime type>;base64,<base64 payload>``.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")

    mime_type, _ = mimetypes.guess_type(file_path)
    if mime_type is None:
        # guess_type() returns None for unknown extensions; interpolating that
        # would yield the invalid URL "data:None;base64,...". Fall back to the
        # generic binary media type instead.
        mime_type = "application/octet-stream"

    data_url = f"data:{mime_type};base64,{base64_content}"

    return URL(uri=data_url)
|
|
@ -3,11 +3,13 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import base64
|
||||
import io
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import unquote
|
||||
|
||||
import chardet
|
||||
import httpx
|
||||
|
@ -38,28 +40,47 @@ def get_embedding_model() -> "SentenceTransformer":
|
|||
return EMBEDDING_MODEL
|
||||
|
||||
|
||||
def content_from_data(data_url: str) -> str:
|
||||
match = re.match(r"data:([^;,]+)(?:;charset=([^;,]+))?(?:;base64)?,(.+)", data_url)
|
||||
# Compiled once at import time. Shape accepted:
#   data:[<mimetype>][;charset=<encoding>][;base64],<data>
# re.DOTALL lets the payload contain newlines.
_DATA_URL_PATTERN = re.compile(
    r"^"
    r"data:"
    r"(?P<mimetype>[\w/\-+.]+)?"
    r"(?P<charset>;charset=(?P<encoding>[\w-]+))?"
    r"(?P<base64>;base64)?"
    r",(?P<data>.*)"
    r"$",
    re.DOTALL,
)


def parse_data_url(data_url: str) -> Dict[str, Any]:
    """Split a ``data:`` URL into its components.

    Args:
        data_url: A string of the form
            ``data:[<mimetype>][;charset=<encoding>][;base64],<data>``.

    Returns:
        A dict with these keys:

        * ``mimetype``: the media type; ``"text/plain"`` when the URL omits
          it (the RFC 2397 default).
        * ``charset``: the raw ``;charset=...`` segment, or ``None``.
        * ``encoding``: the charset name alone, or ``None``.
        * ``base64``: the raw ``;base64`` marker, or ``None``.
        * ``data``: everything after the first comma, still encoded.
        * ``is_base64``: ``True`` when the ``;base64`` marker is present.

    Raises:
        ValueError: If ``data_url`` does not match the expected shape.
    """
    match = _DATA_URL_PATTERN.match(data_url)
    if not match:
        raise ValueError("Invalid Data URL format")

    parts = match.groupdict()
    if parts["mimetype"] is None:
        # RFC 2397: an omitted media type defaults to text/plain.
        parts["mimetype"] = "text/plain"
    parts["is_base64"] = bool(parts["base64"])
    return parts
|
||||
|
||||
if ";base64," in data_url:
|
||||
|
||||
def content_from_data(data_url: str) -> str:
|
||||
parts = parse_data_url(data_url)
|
||||
data = parts["data"]
|
||||
|
||||
if parts["is_base64"]:
|
||||
data = base64.b64decode(data)
|
||||
else:
|
||||
data = data.encode("utf-8")
|
||||
data = unquote(data)
|
||||
encoding = parts["encoding"] or "utf-8"
|
||||
data = data.encode(encoding)
|
||||
|
||||
encoding = parts["encoding"]
|
||||
if not encoding:
|
||||
detected = chardet.detect(data)
|
||||
encoding = detected["encoding"]
|
||||
|
||||
mime_type = parts["mimetype"]
|
||||
mime_category = mime_type.split("/")[0]
|
||||
|
||||
if mime_category == "text":
|
||||
# For text-based files (including CSV, MD)
|
||||
if charset:
|
||||
return data.decode(charset)
|
||||
else:
|
||||
# Try to detect encoding if charset is not specified
|
||||
detected = chardet.detect(data)
|
||||
return data.decode(detected["encoding"])
|
||||
return data.decode(encoding)
|
||||
|
||||
elif mime_type == "application/pdf":
|
||||
# For PDF and DOC/DOCX files, we can't reliably convert to string
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue