fix localfs update

This commit is contained in:
Dinesh Yeduguru 2024-11-27 14:51:51 -08:00
parent 32fbe366d7
commit 3cb8b33290

View file

@@ -9,6 +9,8 @@ import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403
import base64
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -139,19 +141,31 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_impl.load() dataset_impl.load()
new_rows_df = pandas.DataFrame(rows) new_rows_df = pandas.DataFrame(rows)
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df) new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat( dataset_impl.df = pandas.concat(
[dataset_impl.df, new_rows_df], ignore_index=True [dataset_impl.df, new_rows_df], ignore_index=True
) )
url = str(dataset_info.dataset_def.url) url = str(dataset_info.dataset_def.url)
parsed_url = urlparse(url) parsed_url = urlparse(url)
if parsed_url.scheme == "file" or not parsed_url.scheme: if parsed_url.scheme == "file" or not parsed_url.scheme:
file_path = parsed_url.path file_path = parsed_url.path
os.makedirs(os.path.dirname(file_path), exist_ok=True)
dataset_impl.df.to_csv(file_path, index=False) dataset_impl.df.to_csv(file_path, index=False)
elif parsed_url.scheme == "data":
# For data URLs, we need to update the base64-encoded content
if not parsed_url.path.startswith("text/csv;base64,"):
raise ValueError("Data URL must be a base64-encoded CSV")
csv_buffer = dataset_impl.df.to_csv(index=False)
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode(
"utf-8"
)
dataset_info.dataset_def.url = URL(
uri=f"data:text/csv;base64,{base64_content}"
)
else: else:
raise ValueError( raise ValueError(
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// URLs are supported for writing." f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
) )