forked from phoenix-oss/llama-stack-mirror

Support data: in URL for memory. Add ootb support for pdfs (#67)

* support data: in URL for memory. Add ootb support for pdfs
* moved utility to common and updated data_url parsing logic

Co-authored-by: Hardik Shah <hjshah@fb.com>

This commit is contained in:
parent 736092f6bc
commit 1d0e91d802

6 changed files with 112 additions and 12 deletions
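A `data:` URL (RFC 2397) inlines the document payload in the URI itself, in the shape `data:<mime-type>[;charset=<encoding>][;base64],<data>`; with this commit, memory bank documents can carry such URLs, and PDFs are decoded out of the box. A minimal sketch of building one by hand (illustrative values only, not part of the diff):

    import base64

    # Inline a small text document as a data: URL.
    payload = base64.b64encode(b"Llama Stack memory notes").decode("utf-8")
    uri = f"data:text/plain;base64,{payload}"
    # -> "data:text/plain;base64,TGxhbWEgU3RhY2sgbWVtb3J5IG5vdGVz"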
@@ -83,13 +83,12 @@ class AgenticSystemClient(AgenticSystem):
                if line.startswith("data:"):
                    data = line[len("data: ") :]
                    try:
-                        if "error" in data:
+                        jdata = json.loads(data)
+                        if "error" in jdata:
                            cprint(data, "red")
                            continue
 
-                        yield AgenticSystemTurnResponseStreamChunk(
-                            **json.loads(data)
-                        )
+                        yield AgenticSystemTurnResponseStreamChunk(**jdata)
                    except Exception as e:
                        print(data)
                        print(f"Error with parsing or validation: {e}")
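The hunk above also fixes a subtle bug: `"error" in data` is a substring test on the raw JSON string, so any event whose payload merely contains the characters "error" would be dropped, whereas `"error" in jdata` is a key lookup on the parsed object. A standalone illustration (invented payload):

    import json

    line = 'data: {"event": {"payload": "no errors so far"}}'
    data = line[len("data: "):]

    # Substring test on the raw string: True even though nothing failed.
    print("error" in data)   # True -> chunk would be wrongly skipped

    # Key lookup on the parsed dict: only True for a real error payload.
    jdata = json.loads(data)
    print("error" in jdata)  # False -> chunk is processed normally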
@@ -8,7 +8,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
 from typing import List, Optional, Protocol
 
 from llama_models.schema_utils import json_schema_type, webmethod
@@ -23,7 +22,7 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 class MemoryBankDocument(BaseModel):
     document_id: str
     content: InterleavedTextMedia | URL
-    mime_type: str
+    mime_type: str | None = None
     metadata: Dict[str, Any] = Field(default_factory=dict)
 
 
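With `mime_type` now optional, callers can omit it when the content is a `data:` URL, since the mime type travels inside the URL itself and is recovered by the parsing logic further down. A sketch, with an abbreviated, hypothetical payload:

    doc = MemoryBankDocument(
        document_id="pdf-doc-0",
        # no mime_type needed: "application/pdf" is embedded in the URI
        content=URL(uri="data:application/pdf;base64,JVBERi0xLjQK..."),
    )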
@@ -5,15 +5,19 @@
 # the root directory of this source tree.
 
 import asyncio
+import json
+from pathlib import Path
 
 from typing import Any, Dict, List, Optional
 
 import fire
 import httpx
+from termcolor import cprint
 
 from llama_toolchain.core.datatypes import RemoteProviderConfig
 
 from .api import *  # noqa: F403
+from .common.file_utils import data_url_from_file
 
 
 async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory:

@@ -120,7 +124,7 @@ async def run_main(host: str, port: int, stream: bool):
             overlap_size_in_tokens=64,
         ),
     )
-    print(bank)
+    cprint(json.dumps(bank.dict(), indent=4), "green")
 
     retrieved_bank = await client.get_memory_bank(bank.bank_id)
     assert retrieved_bank is not None

@@ -145,6 +149,16 @@ async def run_main(host: str, port: int, stream: bool):
         for i, url in enumerate(urls)
     ]
 
+    this_dir = os.path.dirname(__file__)
+    files = [Path(this_dir).parent.parent / "CONTRIBUTING.md"]
+    documents += [
+        MemoryBankDocument(
+            document_id=f"num-{i}",
+            content=data_url_from_file(path),
+        )
+        for i, path in enumerate(files)
+    ]
+
     # insert some documents
     await client.insert_documents(
         bank_id=bank.bank_id,
llama_toolchain/memory/common/file_utils.py (new file, 26 lines)

@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import base64
+import mimetypes
+import os
+
+from llama_models.llama3.api.datatypes import URL
+
+
+def data_url_from_file(file_path: str) -> URL:
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    with open(file_path, "rb") as file:
+        file_content = file.read()
+
+    base64_content = base64.b64encode(file_content).decode("utf-8")
+    mime_type, _ = mimetypes.guess_type(file_path)
+
+    data_url = f"data:{mime_type};base64,{base64_content}"
+
+    return URL(uri=data_url)
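A usage sketch for the new helper (hypothetical path; `mimetypes.guess_type` infers the mime type from the file extension):

    url = data_url_from_file("notes/report.pdf")
    # url.uri == "data:application/pdf;base64,<base64 of the file bytes>"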
@@ -3,21 +3,25 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import base64
+import io
+import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
+from urllib.parse import unquote
 
+import chardet
 import httpx
 import numpy as np
 from numpy.typing import NDArray
+from pypdf import PdfReader
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_toolchain.memory.api import *  # noqa: F403
 
 
 ALL_MINILM_L6_V2_DIMENSION = 384
 
 EMBEDDING_MODEL = None

@@ -36,11 +40,67 @@ def get_embedding_model() -> "SentenceTransformer":
     return EMBEDDING_MODEL
 
 
+def parse_data_url(data_url: str):
+    data_url_pattern = re.compile(
+        r"^"
+        r"data:"
+        r"(?P<mimetype>[\w/\-+.]+)"
+        r"(?P<charset>;charset=(?P<encoding>[\w-]+))?"
+        r"(?P<base64>;base64)?"
+        r",(?P<data>.*)"
+        r"$",
+        re.DOTALL,
+    )
+    match = data_url_pattern.match(data_url)
+    if not match:
+        raise ValueError("Invalid Data URL format")
+
+    parts = match.groupdict()
+    parts["is_base64"] = bool(parts["base64"])
+    return parts
+
+
+def content_from_data(data_url: str) -> str:
+    parts = parse_data_url(data_url)
+    data = parts["data"]
+
+    if parts["is_base64"]:
+        data = base64.b64decode(data)
+    else:
+        data = unquote(data)
+        encoding = parts["encoding"] or "utf-8"
+        data = data.encode(encoding)
+
+    encoding = parts["encoding"]
+    if not encoding:
+        detected = chardet.detect(data)
+        encoding = detected["encoding"]
+
+    mime_type = parts["mimetype"]
+    mime_category = mime_type.split("/")[0]
+    if mime_category == "text":
+        # For text-based files (including CSV, MD)
+        return data.decode(encoding)
+
+    elif mime_type == "application/pdf":
+        # For PDF files, extract text page by page (binary content,
+        # so there is no direct string conversion)
+        pdf_bytes = io.BytesIO(data)
+        pdf_reader = PdfReader(pdf_bytes)
+        return "\n".join([page.extract_text() for page in pdf_reader.pages])
+
+    else:
+        cprint("Could not extract content from data_url properly.", color="red")
+        return ""
+
+
 async def content_from_doc(doc: MemoryBankDocument) -> str:
     if isinstance(doc.content, URL):
-        async with httpx.AsyncClient() as client:
-            r = await client.get(doc.content.uri)
-            return r.text
+        if doc.content.uri.startswith("data:"):
+            return content_from_data(doc.content.uri)
+        else:
+            async with httpx.AsyncClient() as client:
+                r = await client.get(doc.content.uri)
+                return r.text
 
     return interleaved_text_media_as_str(doc.content)
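To make the new parsing path concrete, here is what `parse_data_url` and `content_from_data` (as defined above) produce for a small base64 text URL; `aGVsbG8gd29ybGQ=` is the base64 encoding of `hello world`:

    url = "data:text/plain;charset=utf-8;base64,aGVsbG8gd29ybGQ="
    parts = parse_data_url(url)
    # parts["mimetype"]  == "text/plain"
    # parts["encoding"]  == "utf-8"
    # parts["is_base64"] is True

    print(content_from_data(url))  # -> "hello world"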
@@ -10,6 +10,8 @@ from llama_toolchain.core.datatypes import *  # noqa: F403
 
 EMBEDDING_DEPS = [
     "blobfile",
+    "chardet",
+    "PdfReader",
     "sentence-transformers",
 ]
 