mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-06 04:34:57 +00:00
support data: in URL for memory. Add ootb support for pdfs
This commit is contained in:
parent
a11d92601b
commit
5f49dce839
5 changed files with 82 additions and 12 deletions
|
@ -3,21 +3,23 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import io
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import chardet
|
||||
import httpx
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from pypdf import PdfReader
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_toolchain.memory.api import * # noqa: F403
|
||||
|
||||
|
||||
ALL_MINILM_L6_V2_DIMENSION = 384
|
||||
|
||||
EMBEDDING_MODEL = None
|
||||
|
@ -36,11 +38,48 @@ def get_embedding_model() -> "SentenceTransformer":
|
|||
return EMBEDDING_MODEL
|
||||
|
||||
|
||||
def content_from_data(data_url: str) -> str:
    """Extract the textual content embedded in a ``data:`` URL.

    Supported payloads:

    * ``text/*`` -- decoded with the declared ``charset``; when none is
      declared the encoding is guessed with ``chardet`` (falling back to
      UTF-8 if no guess is available).
    * ``application/pdf`` -- the text of every page, joined by newlines.

    Any other MIME type prints a warning and yields an empty string.

    Args:
        data_url: A URL of the form
            ``data:<mime>[;charset=<enc>][;base64],<payload>``.

    Returns:
        The decoded text, or ``""`` when the MIME type is unsupported.

    Raises:
        ValueError: If ``data_url`` does not look like a data URL.
    """
    # Imported locally so this function does not rely on module-level imports
    # that the surrounding file does not provide.
    import base64
    from urllib.parse import unquote_to_bytes

    # The ";base64" marker is captured from the header itself so that the same
    # text appearing inside a plain payload is not mistaken for the flag.
    # ".*" with DOTALL accepts empty and multi-line payloads.
    match = re.match(
        r"data:([^;,]+)(?:;charset=([^;,]+))?(;base64)?,(.*)",
        data_url,
        re.DOTALL,
    )
    if not match:
        raise ValueError("Invalid Data URL format")

    mime_type, charset, base64_flag, payload = match.groups()
    # MIME types are case-insensitive.
    mime_type = mime_type.lower()

    if base64_flag:
        data = base64.b64decode(payload)
    else:
        # Non-base64 payloads are percent-encoded; literal non-ASCII
        # characters are encoded as UTF-8 by unquote_to_bytes.
        data = unquote_to_bytes(payload)

    if mime_type.split("/")[0] == "text":
        # Text-based files (including CSV, MD).
        if not charset:
            # chardet may report no encoding (e.g. for empty input).
            detected = chardet.detect(data)
            charset = detected["encoding"] or "utf-8"
        return data.decode(charset)

    if mime_type == "application/pdf":
        # Pull the text layer out of each page of the PDF.
        pdf_reader = PdfReader(io.BytesIO(data))
        return "\n".join([page.extract_text() for page in pdf_reader.pages])

    # NOTE(review): cprint is not imported in the visible part of this file --
    # confirm it is in scope (presumably termcolor.cprint).
    cprint("Could not extract content from data_url properly.", color="red")
    return ""
|
||||
|
||||
|
||||
async def content_from_doc(doc: MemoryBankDocument) -> str:
    """Return the text content of a memory bank document.

    When ``doc.content`` is a ``URL``:

    * a ``data:`` URI is decoded in-process by ``content_from_data``;
    * any other URI is fetched with ``httpx`` and the response text returned.

    Any other content is converted with ``interleaved_text_media_as_str``.

    Args:
        doc: The document whose content should be resolved to a string.

    Returns:
        The document content as a string.
    """
    if isinstance(doc.content, URL):
        uri = doc.content.uri

        # data: URIs carry their payload inline, so no network request is made.
        if uri.startswith("data:"):
            return content_from_data(uri)

        # NOTE(review): the HTTP status is not checked, so an error page body
        # would be returned as document content -- confirm this is intended.
        async with httpx.AsyncClient() as client:
            r = await client.get(uri)
            return r.text

    return interleaved_text_media_as_str(doc.content)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue