llama-stack/llama_stack/providers/inline/datasetio/localfs/datasetio.py
Sixian Yi caf1dac114
unregister API for dataset (#507)
# What does this PR do?

1) Implement the `unregister_dataset(dataset_id)` API in both the llama stack
routing table and the providers. It removes the {dataset_id -> Dataset} mapping
from the routing table and removes the dataset_id references in the provider as
well (e.g. for huggingface, we use a KV store to store the dataset_id =>
dataset mapping, and we delete that entry during unregistering as well).

2) Expose the `unregister_dataset` API as the `datasets/unregister` endpoint.

## Test Plan

**Unit test:** 

```
pytest llama_stack/providers/tests/datasetio/test_datasetio.py -m
"huggingface" -v -s --tb=short --disable-warnings
```

**Test on endpoint:**
tested llama stack using an ollama distribution template:
1) start an ollama server 
2) Start a llama stack server with the default ollama distribution
config + dataset/datasetsio APIs + datasetio provider
```
---- .../ollama-run.yaml
...
apis:
- agents
- inference
- memory
- safety
- telemetry
- datasetio
- datasets
providers:
  datasetio:
  - provider_id: localfs
    provider_type: inline::localfs
    config: {}
...
```
   Confirmed that the new API shows up in the server startup output:
   
  ```
Serving API datasets
 GET /alpha/datasets/get
 GET /alpha/datasets/list
 POST /alpha/datasets/register
 POST /alpha/datasets/unregister
```

3) Query `/alpha/datasets/unregister` through curl (since the unregister API is not yet implemented in the llama stack client).

```
(base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets register
--dataset-id sixian --url
https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst
--schema {}
(base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list
┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┓
┃ identifier ┃ provider_id ┃ metadata ┃ type    ┃
┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━┩
│ sixian     │ localfs     │ {}       │ dataset │
└────────────┴─────────────┴──────────┴─────────┘
(base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets register
--dataset-id sixian2 --url
https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst
--schema {}
(base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list
┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┓
┃ identifier ┃ provider_id ┃ metadata ┃ type    ┃
┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━┩
│ sixian     │ localfs     │ {}       │ dataset │
│ sixian2    │ localfs     │ {}       │ dataset │
└────────────┴─────────────┴──────────┴─────────┘
(base) sxyi@sxyi-mbp llama-stack % curl
http://localhost:5001/alpha/datasets/unregister \
-H "Content-Type: application/json" \
-d '{"dataset_id": "sixian"}'
null%

(base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list
┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┓
┃ identifier ┃ provider_id ┃ metadata ┃ type    ┃
┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━┩
│ sixian2    │ localfs     │ {}       │ dataset │
└────────────┴─────────────┴──────────┴─────────┘
(base) sxyi@sxyi-mbp llama-stack % curl
http://localhost:5001/alpha/datasets/unregister \
-H "Content-Type: application/json" \
-d '{"dataset_id": "sixian2"}'
null%

(base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list
```

## Sources


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
2024-12-03 21:18:30 -08:00

133 lines
4.1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from abc import ABC, abstractmethod
from dataclasses import dataclass
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from .config import LocalFSDatasetIOConfig
class BaseDataset(ABC):
    """Abstract interface for a dataset.

    Subclasses must implement ``__len__``, ``__getitem__`` and ``load``;
    the class cannot be instantiated until all three are overridden.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Forward every argument unchanged to the next class in the MRO.
        super().__init__(*args, **kwargs)

    @abstractmethod
    def __len__(self) -> int:
        """Return the size of the dataset."""
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, idx):
        """Return the item(s) selected by ``idx``."""
        raise NotImplementedError

    @abstractmethod
    def load(self):
        """Make the dataset's contents available for ``len()`` and indexing."""
        raise NotImplementedError
@dataclass
class DatasetInfo:
    """Bundle of a dataset definition and the implementation object backing it."""

    # The dataset definition this entry describes.
    dataset_def: Dataset
    # The BaseDataset implementation associated with that definition.
    dataset_impl: BaseDataset
class PandasDataframeDataset(BaseDataset):
    """Dataset whose rows live in a pandas DataFrame fetched from the definition's URL.

    Nothing is fetched at construction time: ``load()`` must be called before
    ``len()`` or indexing is used.
    """

    def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # ``df`` stays None until load() has fetched and validated the frame.
        self.df = None
        self.dataset_def = dataset_def

    def __len__(self) -> int:
        """Return the number of rows in the loaded frame."""
        assert self.df is not None, "Dataset not loaded. Please call .load() first"
        return len(self.df)

    def __getitem__(self, idx):
        """Return the rows selected positionally by ``idx``.

        A slice yields a list of per-row dicts (``orient="records"``); any
        other index yields ``to_dict()`` of the ``iloc`` selection.
        """
        assert self.df is not None, "Dataset not loaded. Please call .load() first"
        selection = self.df.iloc[idx]
        if not isinstance(idx, slice):
            return selection.to_dict()
        return selection.to_dict(orient="records")

    def _validate_dataset_schema(self, df) -> pandas.DataFrame:
        """Project ``df`` onto the schema's columns and check the column count."""
        schema = self.dataset_def.dataset_schema
        # note that we will drop any columns in dataset that are not in the schema
        projected = df[schema.keys()]
        # check all columns in dataset schema are present
        assert len(projected.columns) == len(schema)
        # TODO: type checking against column types in dataset schema
        return projected

    def load(self) -> None:
        """Fetch the frame from the dataset URL once; later calls are no-ops.

        Raises:
            ValueError: if no dataframe could be obtained from the URL.
        """
        already_loaded = self.df is not None
        if already_loaded:
            return

        fetched = get_dataframe_from_url(self.dataset_def.url)
        if fetched is None:
            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")

        self.df = self._validate_dataset_schema(fetched)
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
    """DatasetIO provider that tracks datasets in an in-memory dict and serves
    paginated rows from them.
    """

    def __init__(self, config: LocalFSDatasetIOConfig) -> None:
        self.config = config
        # local registry for keeping track of datasets within the provider
        # (maps dataset identifier -> DatasetInfo)
        self.dataset_infos = {}

    async def initialize(self) -> None: ...

    async def shutdown(self) -> None: ...

    async def register_dataset(
        self,
        dataset: Dataset,
    ) -> None:
        """Record ``dataset`` in the local registry under its identifier.

        The data itself is not fetched here; it is loaded on first read.
        Registering an identifier that already exists replaces the old entry.
        """
        dataset_impl = PandasDataframeDataset(dataset)
        self.dataset_infos[dataset.identifier] = DatasetInfo(
            dataset_def=dataset,
            dataset_impl=dataset_impl,
        )

    async def unregister_dataset(self, dataset_id: str) -> None:
        """Remove ``dataset_id`` from the local registry.

        Raises:
            KeyError: if ``dataset_id`` is not currently registered.
        """
        del self.dataset_infos[dataset_id]

    async def get_rows_paginated(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult:
        """Return one page of rows from a registered dataset.

        Args:
            dataset_id: identifier the dataset was registered under.
            rows_in_page: maximum number of rows to return; ``-1`` means
                "everything from the start offset to the end".
            page_token: decimal string giving the start row offset; ``None``
                or ``""`` starts at row 0.
            filter_condition: accepted for interface compatibility; not used
                by this implementation.

        Returns:
            A ``PaginatedRowsResult`` whose ``next_page_token`` is the offset
            one past the last row considered.

        Raises:
            ValueError: if ``dataset_id`` is not registered, if ``page_token``
                is not a decimal number, or if the dataset cannot be loaded.
        """
        dataset_info = self.dataset_infos.get(dataset_id)
        if dataset_info is None:
            # Without this guard an unknown/unregistered id surfaced as an
            # opaque AttributeError on None.
            raise ValueError(f"Dataset with id '{dataset_id}' is not registered")

        # Validate the cheap argument before triggering a potentially
        # expensive load. isdecimal() (rather than isnumeric()) matches what
        # int() can actually parse, so e.g. "²" is rejected here with a clear
        # message instead of failing inside int().
        if page_token and not page_token.isdecimal():
            raise ValueError("Invalid page_token")
        start = int(page_token) if page_token else 0

        dataset_info.dataset_impl.load()
        available = len(dataset_info.dataset_impl)

        if rows_in_page == -1:
            end = available
        else:
            end = min(start + rows_in_page, available)

        rows = dataset_info.dataset_impl[start:end]

        # NOTE(review): total_count is the number of rows in this page, not
        # the size of the whole dataset - confirm this is what callers expect.
        return PaginatedRowsResult(
            rows=rows,
            total_count=len(rows),
            next_page_token=str(end),
        )