mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-10 04:08:31 +00:00)

commit 4fee3af91f ("unregister"), parent cf225c9710
4 changed files with 79 additions and 96 deletions
@@ -3,16 +3,10 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import base64
-import os
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
-from urllib.parse import urlparse
 
 import pandas
 
-from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
@@ -24,30 +18,7 @@ from .config import LocalFSDatasetIOConfig
 DATASETS_PREFIX = "localfs_datasets:"
 
 
-class BaseDataset(ABC):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-
-    @abstractmethod
-    def __len__(self) -> int:
-        raise NotImplementedError()
-
-    @abstractmethod
-    def __getitem__(self, idx):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def load(self):
-        raise NotImplementedError()
-
-
-@dataclass
-class DatasetInfo:
-    dataset_def: Dataset
-    dataset_impl: BaseDataset
-
-
-class PandasDataframeDataset(BaseDataset):
+class PandasDataframeDataset:
     def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.dataset_def = dataset_def
@@ -64,23 +35,19 @@ class PandasDataframeDataset(BaseDataset):
         else:
             return self.df.iloc[idx].to_dict()
 
-    def _validate_dataset_schema(self, df) -> pandas.DataFrame:
-        # note that we will drop any columns in dataset that are not in the schema
-        df = df[self.dataset_def.dataset_schema.keys()]
-        # check all columns in dataset schema are present
-        assert len(df.columns) == len(self.dataset_def.dataset_schema)
-        # TODO: type checking against column types in dataset schema
-        return df
-
     def load(self) -> None:
         if self.df is not None:
             return
 
-        df = get_dataframe_from_url(self.dataset_def.url)
-        if df is None:
-            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
+        if self.dataset_def.source.type == "uri":
+            self.df = get_dataframe_from_url(self.dataset_def.uri)
+        elif self.dataset_def.source.type == "rows":
+            self.df = pandas.DataFrame(self.dataset_def.source.rows)
+        else:
+            raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
 
-        self.df = self._validate_dataset_schema(df)
+        if self.df is None:
+            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
 
 
 class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
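For the new "rows" branch above, pandas builds the frame directly from the inline records, and `__len__`/`__getitem__` then operate on that frame. A minimal standalone sketch of that behaviour, using plain pandas outside the provider (the sample records are the ones used in the tests further down):

import pandas

# Inline records, as carried by a "rows" dataset source.
rows = [
    {"messages": [{"role": "user", "content": "Hello, world!"}], "answer": "Hello, world!"},
    {"messages": [{"role": "user", "content": "What is the capital of France?"}], "answer": "Paris"},
]

df = pandas.DataFrame(rows)  # one column per key, one row per record
assert len(df) == 2
assert df.iloc[0].to_dict()["answer"] == "Hello, world!"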
@@ -99,29 +66,21 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
 
         for dataset in stored_datasets:
             dataset = Dataset.model_validate_json(dataset)
-            dataset_impl = PandasDataframeDataset(dataset)
-            self.dataset_infos[dataset.identifier] = DatasetInfo(
-                dataset_def=dataset,
-                dataset_impl=dataset_impl,
-            )
+            self.dataset_infos[dataset.identifier] = dataset
 
     async def shutdown(self) -> None: ...
 
     async def register_dataset(
         self,
-        dataset: Dataset,
+        dataset_def: Dataset,
     ) -> None:
         # Store in kvstore
-        key = f"{DATASETS_PREFIX}{dataset.identifier}"
+        key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
         await self.kvstore.set(
             key=key,
-            value=dataset.json(),
+            value=dataset_def.model_dump_json(),
         )
-        dataset_impl = PandasDataframeDataset(dataset)
-        self.dataset_infos[dataset.identifier] = DatasetInfo(
-            dataset_def=dataset,
-            dataset_impl=dataset_impl,
-        )
+        self.dataset_infos[dataset_def.identifier] = dataset_def
 
     async def unregister_dataset(self, dataset_id: str) -> None:
         key = f"{DATASETS_PREFIX}{dataset_id}"
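The provider now persists the raw `Dataset` definition: `model_dump_json()` on registration, `Dataset.model_validate_json()` when stored datasets are read back at startup. A minimal sketch of that Pydantic v2 round-trip, using an illustrative stand-in model rather than the real `Dataset` schema:

from pydantic import BaseModel

class ToyDataset(BaseModel):  # stand-in; the real model lives in llama_stack.apis.datasets
    identifier: str
    provider_id: str

stored = ToyDataset(identifier="my-ds", provider_id="localfs").model_dump_json()
# ... later, e.g. when the provider restarts and re-reads its kvstore ...
restored = ToyDataset.model_validate_json(stored)
assert restored.identifier == "my-ds"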
@@ -134,51 +93,28 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         start_index: Optional[int] = None,
         limit: Optional[int] = None,
     ) -> IterrowsResponse:
-        dataset_info = self.dataset_infos.get(dataset_id)
-        dataset_info.dataset_impl.load()
+        dataset_def = self.dataset_infos[dataset_id]
+        dataset_impl = PandasDataframeDataset(dataset_def)
+        dataset_impl.load()
 
         start_index = start_index or 0
 
         if limit is None or limit == -1:
-            end = len(dataset_info.dataset_impl)
+            end = len(dataset_impl)
         else:
-            end = min(start_index + limit, len(dataset_info.dataset_impl))
+            end = min(start_index + limit, len(dataset_impl))
 
-        rows = dataset_info.dataset_impl[start_index:end]
+        rows = dataset_impl[start_index:end]
 
         return IterrowsResponse(
             data=rows,
-            next_index=end if end < len(dataset_info.dataset_impl) else None,
+            next_index=end if end < len(dataset_impl) else None,
         )
 
     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
-        dataset_info = self.dataset_infos.get(dataset_id)
-        if dataset_info is None:
-            raise ValueError(f"Dataset with id {dataset_id} not found")
-
-        dataset_impl = dataset_info.dataset_impl
+        dataset_def = self.dataset_infos[dataset_id]
+        dataset_impl = PandasDataframeDataset(dataset_def)
         dataset_impl.load()
 
         new_rows_df = pandas.DataFrame(rows)
-        new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
         dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
-
-        url = str(dataset_info.dataset_def.url.uri)
-        parsed_url = urlparse(url)
-
-        if parsed_url.scheme == "file" or not parsed_url.scheme:
-            file_path = parsed_url.path
-            os.makedirs(os.path.dirname(file_path), exist_ok=True)
-            dataset_impl.df.to_csv(file_path, index=False)
-        elif parsed_url.scheme == "data":
-            # For data URLs, we need to update the base64-encoded content
-            if not parsed_url.path.startswith("text/csv;base64,"):
-                raise ValueError("Data URL must be a base64-encoded CSV")
-
-            csv_buffer = dataset_impl.df.to_csv(index=False)
-            base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
-            dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
-        else:
-            raise ValueError(
-                f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
-            )
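`iterrows` now builds a `PandasDataframeDataset` on demand and windows it with `start_index`/`limit`, returning `next_index` only while rows remain. The same windowing arithmetic in isolation (`paginate` is an illustrative helper, not part of the provider):

from typing import Optional, Tuple

def paginate(total: int, start_index: int = 0, limit: Optional[int] = None) -> Tuple[int, int, Optional[int]]:
    # A limit of None or -1 means "read to the end", mirroring the provider.
    end = total if limit is None or limit == -1 else min(start_index + limit, total)
    next_index = end if end < total else None
    return start_index, end, next_index

assert paginate(total=5, start_index=0, limit=2) == (0, 2, 2)     # more rows remain
assert paginate(total=5, start_index=4, limit=2) == (4, 5, None)  # last page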
@@ -52,12 +52,11 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         self,
         dataset_def: Dataset,
     ) -> None:
-        print("register_dataset")
         # Store in kvstore
         key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
         await self.kvstore.set(
             key=key,
-            value=dataset_def.json(),
+            value=dataset_def.model_dump_json(),
         )
         self.dataset_infos[dataset_def.identifier] = dataset_def
 
@@ -13,7 +13,7 @@ import pytest
 
 
 @pytest.mark.parametrize(
-    "purpose, source, provider_id",
+    "purpose, source, provider_id, limit",
     [
         (
             "eval/messages-answer",
@@ -22,16 +22,48 @@ import pytest
                 "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
             },
             "huggingface",
+            10,
+        ),
+        (
+            "eval/messages-answer",
+            {
+                "type": "rows",
+                "rows": [
+                    {
+                        "messages": [{"role": "user", "content": "Hello, world!"}],
+                        "answer": "Hello, world!",
+                    },
+                    {
+                        "messages": [
+                            {
+                                "role": "user",
+                                "content": "What is the capital of France?",
+                            }
+                        ],
+                        "answer": "Paris",
+                    },
+                ],
+            },
+            "localfs",
+            2,
         ),
     ],
 )
-def test_register_dataset(llama_stack_client, purpose, source, provider_id):
+def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id, limit):
     dataset = llama_stack_client.datasets.register(
         purpose=purpose,
         source=source,
     )
     assert dataset.identifier is not None
     assert dataset.provider_id == provider_id
-    iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=10)
-    assert len(iterrow_response.data) == 10
-    assert iterrow_response.next_index is not None
+    iterrow_response = llama_stack_client.datasets.iterrows(
+        dataset.identifier, limit=limit
+    )
+    assert len(iterrow_response.data) == limit
+
+    dataset_list = llama_stack_client.datasets.list()
+    assert dataset.identifier in [d.identifier for d in dataset_list]
+
+    llama_stack_client.datasets.unregister(dataset.identifier)
+    dataset_list = llama_stack_client.datasets.list()
+    assert dataset.identifier not in [d.identifier for d in dataset_list]
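The updated test checks only a single page; a caller that wants every row feeds `next_index` back in as the cursor. A hedged sketch of that loop, assuming the client's `iterrows` accepts a `start_index` keyword that mirrors the server-side parameter (an assumption, not verified here):

# Drain a registered dataset page by page; reuses llama_stack_client and
# dataset.identifier from the test above.
all_rows = []
start_index = 0
while True:
    # start_index kwarg assumed to mirror the server-side parameter.
    page = llama_stack_client.datasets.iterrows(dataset.identifier, start_index=start_index, limit=2)
    all_rows.extend(page.data)
    if page.next_index is None:
        break
    start_index = page.next_index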
@@ -10,11 +10,27 @@ from rich.pretty import pprint
 
 def test_register_dataset():
     client = LlamaStackClient(base_url="http://localhost:8321")
+    # dataset = client.datasets.register(
+    #     purpose="eval/messages-answer",
+    #     source={
+    #         "type": "uri",
+    #         "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
+    #     },
+    # )
     dataset = client.datasets.register(
         purpose="eval/messages-answer",
         source={
-            "type": "uri",
-            "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
+            "type": "rows",
+            "rows": [
+                {
+                    "messages": [{"role": "user", "content": "Hello, world!"}],
+                    "answer": "Hello, world!",
+                },
+                {
+                    "messages": [{"role": "user", "content": "What is the capital of France?"}],
+                    "answer": "Paris",
+                },
+            ],
         },
     )
     dataset_id = dataset.identifier