forked from phoenix-oss/llama-stack-mirror
# What does this PR do? 1) Implement `unregister_dataset(dataset_id)` API in both llama stack routing table and providers: It removes {dataset_id -> Dataset} mapping from routing table and removes the dataset_id references in provider as well (ex. for huggingface, we use a KV store to store the dataset id => dataset. we delete it during unregistering as well) 2) expose the datasets/unregister_dataset api endpoint ## Test Plan **Unit test:** ` pytest llama_stack/providers/tests/datasetio/test_datasetio.py -m "huggingface" -v -s --tb=short --disable-warnings ` **Test on endpoint:** tested llama stack using an ollama distribution template: 1) start an ollama server 2) Start a llama stack server with the default ollama distribution config + dataset/datasetsio APIs + datasetio provider ``` ---- .../ollama-run.yaml ... apis: - agents - inference - memory - safety - telemetry - datasetio - datasets providers: datasetio: - provider_id: localfs provider_type: inline::localfs config: {} ... ``` saw that the new API showed up in startup script ``` Serving API datasets GET /alpha/datasets/get GET /alpha/datasets/list POST /alpha/datasets/register POST /alpha/datasets/unregister ``` 3) query `/alpha/datasets/unregister` through curl (since we have not implemented unregister api in llama stack client) ``` (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets register --dataset-id sixian --url https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst --schema {} (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┓ ┃ identifier ┃ provider_id ┃ metadata ┃ type ┃ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━┩ │ sixian │ localfs │ {} │ dataset │ └────────────┴─────────────┴──────────┴─────────┘ (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets register --dataset-id sixian2 --url https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst 
--schema {} (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┓ ┃ identifier ┃ provider_id ┃ metadata ┃ type ┃ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━┩ │ sixian │ localfs │ {} │ dataset │ │ sixian2 │ localfs │ {} │ dataset │ └────────────┴─────────────┴──────────┴─────────┘ (base) sxyi@sxyi-mbp llama-stack % curl http://localhost:5001/alpha/datasets/unregister \ -H "Content-Type: application/json" \ -d '{"dataset_id": "sixian"}' null% (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┓ ┃ identifier ┃ provider_id ┃ metadata ┃ type ┃ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━┩ │ sixian2 │ localfs │ {} │ dataset │ └────────────┴─────────────┴──────────┴─────────┘ (base) sxyi@sxyi-mbp llama-stack % curl http://localhost:5001/alpha/datasets/unregister \ -H "Content-Type: application/json" \ -d '{"dataset_id": "sixian2"}' null% (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list ``` ## Sources ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
133 lines
4.1 KiB
Python
133 lines
4.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
from typing import Optional
|
|
|
|
import pandas
|
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
|
|
|
from llama_stack.apis.datasetio import * # noqa: F403
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
|
|
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
|
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
|
|
|
from .config import LocalFSDatasetIOConfig
|
|
|
|
|
|
class BaseDataset(ABC):
    """Abstract interface for a row-indexable dataset owned by a provider.

    Concrete subclasses must implement ``load`` (materialise the data),
    ``__len__`` (row count) and ``__getitem__`` (row access by index or slice).
    """

    def __init__(self, *args, **kwargs) -> None:
        # Forward everything so cooperative multiple inheritance keeps working.
        super().__init__(*args, **kwargs)

    @abstractmethod
    def load(self):
        """Materialise the underlying data so rows can be read."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of rows in the dataset."""
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, idx):
        """Return the row(s) addressed by ``idx``."""
        raise NotImplementedError
|
|
|
|
|
|
@dataclass
class DatasetInfo:
    """Registry entry pairing a dataset definition with its row-serving wrapper.

    The provider stores one of these per registered dataset identifier.
    """

    # The Dataset definition that was passed in at registration time.
    dataset_def: Dataset
    # Wrapper responsible for loading and serving the dataset's rows.
    dataset_impl: BaseDataset
|
|
|
|
|
|
class PandasDataframeDataset(BaseDataset):
    """Dataset backed by a pandas DataFrame fetched from ``dataset_def.url``.

    The DataFrame is loaded lazily: nothing is fetched until ``load()`` is
    called, and later ``load()`` calls return immediately once loaded.
    """

    def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Definition supplying the source URL and the column schema.
        self.dataset_def = dataset_def
        # Populated by load(); None means "not loaded yet".
        self.df = None

    def __len__(self) -> int:
        """Return the number of rows; requires ``load()`` to have been called."""
        assert self.df is not None, "Dataset not loaded. Please call .load() first"
        return len(self.df)

    def __getitem__(self, idx):
        """Return a row as a dict, or a list of row dicts when ``idx`` is a slice.

        Indexing is positional (``DataFrame.iloc``).
        """
        assert self.df is not None, "Dataset not loaded. Please call .load() first"
        if isinstance(idx, slice):
            return self.df.iloc[idx].to_dict(orient="records")
        return self.df.iloc[idx].to_dict()

    def _validate_dataset_schema(self, df) -> pandas.DataFrame:
        """Restrict ``df`` to the schema's columns, in schema order.

        Columns in ``df`` that are not part of the schema are dropped.

        Raises:
            ValueError: if a schema column is missing from ``df``, or if the
                source has duplicate column names that collide with the schema.
        """
        schema_columns = list(self.dataset_def.dataset_schema.keys())

        # Check explicitly rather than letting df[...] raise an opaque KeyError.
        missing = [col for col in schema_columns if col not in df.columns]
        if missing:
            raise ValueError(
                f"Dataset is missing columns required by its schema: {missing}"
            )

        # note that we will drop any columns in dataset that are not in the schema
        df = df[schema_columns]

        # Duplicate column names in the source would yield extra columns here.
        if len(df.columns) != len(schema_columns):
            raise ValueError(
                "Dataset has duplicate column names; expected exactly the "
                f"schema columns {schema_columns}, got {list(df.columns)}"
            )

        # TODO: type checking against column types in dataset schema
        return df

    def load(self) -> None:
        """Fetch and validate the DataFrame once; no-op if already loaded.

        Raises:
            ValueError: if the URL yields no DataFrame or fails schema validation.
        """
        if self.df is not None:
            return

        df = get_dataframe_from_url(self.dataset_def.url)
        if df is None:
            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")

        self.df = self._validate_dataset_schema(df)
|
|
|
|
|
|
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
    """DatasetIO provider that serves registered datasets via pandas DataFrames.

    Registrations live only in the in-memory ``dataset_infos`` dict, keyed by
    dataset identifier; ``initialize``/``shutdown`` do no extra work.
    """

    def __init__(self, config: LocalFSDatasetIOConfig) -> None:
        self.config = config
        # local registry for keeping track of datasets within the provider
        self.dataset_infos = {}

    async def initialize(self) -> None: ...

    async def shutdown(self) -> None: ...

    async def register_dataset(
        self,
        dataset: Dataset,
    ) -> None:
        """Track ``dataset`` under its identifier; rows are loaded on first read.

        Re-registering an existing identifier replaces the previous entry.
        """
        dataset_impl = PandasDataframeDataset(dataset)
        self.dataset_infos[dataset.identifier] = DatasetInfo(
            dataset_def=dataset,
            dataset_impl=dataset_impl,
        )

    async def unregister_dataset(self, dataset_id: str) -> None:
        """Remove ``dataset_id`` from the local registry.

        Raises:
            KeyError: if ``dataset_id`` was never registered (or already removed).
        """
        del self.dataset_infos[dataset_id]

    async def get_rows_paginated(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult:
        """Return one page of rows from a registered dataset.

        Args:
            dataset_id: identifier used at registration time.
            rows_in_page: page size; ``-1`` means "all remaining rows".
            page_token: decimal row offset from a previous ``next_page_token``;
                ``None`` or ``""`` starts at row 0.
            filter_condition: accepted for interface compatibility; not applied here.

        Raises:
            ValueError: if the dataset is not registered or ``page_token`` is
                not a non-negative decimal integer.
        """
        dataset_info = self.dataset_infos.get(dataset_id)
        if dataset_info is None:
            raise ValueError(
                f"Dataset '{dataset_id}' is not registered with this provider"
            )

        # Validate before load() so a bad request never triggers a fetch.
        # isdecimal() (unlike isnumeric()) only accepts strings int() can parse.
        if page_token and not page_token.isdecimal():
            raise ValueError("Invalid page_token")
        start = int(page_token) if page_token else 0

        dataset_impl = dataset_info.dataset_impl
        dataset_impl.load()
        num_rows = len(dataset_impl)

        if rows_in_page == -1:
            end = num_rows
        else:
            end = min(start + rows_in_page, num_rows)

        rows = dataset_impl[start:end]

        # NOTE(review): total_count is the number of rows in this page, not the
        # dataset size; kept as-is for compatibility with existing clients.
        return PaginatedRowsResult(
            rows=rows,
            total_count=len(rows),
            next_page_token=str(end),
        )
|