mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 23:29:43 +00:00
wip
This commit is contained in:
parent
821810657f
commit
aefa84e70a
5 changed files with 117 additions and 54 deletions
|
@ -20,7 +20,7 @@ class DatasetDef(BaseModel):
|
||||||
identifier: str = Field(
|
identifier: str = Field(
|
||||||
description="A unique name for the dataset",
|
description="A unique name for the dataset",
|
||||||
)
|
)
|
||||||
columns_schema: Dict[str, ParamType] = Field(
|
dataset_schema: Dict[str, ParamType] = Field(
|
||||||
description="The schema definition for this dataset",
|
description="The schema definition for this dataset",
|
||||||
)
|
)
|
||||||
url: URL
|
url: URL
|
||||||
|
|
|
@ -3,10 +3,10 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
import io
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import pandas
|
import pandas
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.apis.datasetio import * # noqa: F403
|
from llama_stack.apis.datasetio import * # noqa: F403
|
||||||
|
@ -14,6 +14,7 @@ from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.memory.vector_store import parse_data_url
|
||||||
|
|
||||||
from .config import MetaReferenceDatasetIOConfig
|
from .config import MetaReferenceDatasetIOConfig
|
||||||
|
|
||||||
|
@ -57,6 +58,15 @@ class PandasDataframeDataset(BaseDataset):
|
||||||
else:
|
else:
|
||||||
return self.df.iloc[idx].to_dict()
|
return self.df.iloc[idx].to_dict()
|
||||||
|
|
||||||
|
def validate_dataset_schema(self) -> None:
|
||||||
|
if self.df is None:
|
||||||
|
self.load()
|
||||||
|
|
||||||
|
print(self.dataset_def.dataset_schema)
|
||||||
|
# get columns names
|
||||||
|
# columns = self.df[self.dataset_def.dataset_schema.keys()]
|
||||||
|
print(self.df.columns)
|
||||||
|
|
||||||
def load(self) -> None:
|
def load(self) -> None:
|
||||||
if self.df is not None:
|
if self.df is not None:
|
||||||
return
|
return
|
||||||
|
@ -105,6 +115,7 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
dataset_def: DatasetDef,
|
dataset_def: DatasetDef,
|
||||||
) -> None:
|
) -> None:
|
||||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||||
|
dataset_impl.validate_dataset_schema()
|
||||||
self.dataset_infos[dataset_def.identifier] = DatasetInfo(
|
self.dataset_infos[dataset_def.identifier] = DatasetInfo(
|
||||||
dataset_def=dataset_def,
|
dataset_def=dataset_def,
|
||||||
dataset_impl=dataset_impl,
|
dataset_impl=dataset_impl,
|
||||||
|
|
|
@ -36,7 +36,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
pip_packages=(
|
pip_packages=(
|
||||||
META_REFERENCE_DEPS
|
META_REFERENCE_DEPS
|
||||||
+ [
|
+ [
|
||||||
"fbgemm-gpu==0.8.0",
|
"fbgemm-gpu==1.0.0",
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
module="llama_stack.providers.impls.meta_reference.inference",
|
module="llama_stack.providers.impls.meta_reference.inference",
|
||||||
|
|
5
llama_stack/providers/tests/datasetio/test_dataset.jsonl
Normal file
5
llama_stack/providers/tests/datasetio/test_dataset.jsonl
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
{"input_query": "What is the capital of France?", "generated_answer": "London", "expected_answer": "Paris"}
|
||||||
|
{"input_query": "Who is the CEO of Meta?", "generated_answer": "Mark Zuckerberg", "expected_answer": "Mark Zuckerberg"}
|
||||||
|
{"input_query": "What is the largest planet in our solar system?", "generated_answer": "Jupiter", "expected_answer": "Jupiter"}
|
||||||
|
{"input_query": "What is the smallest country in the world?", "generated_answer": "China", "expected_answer": "Vatican City"}
|
||||||
|
{"input_query": "What is the currency of Japan?", "generated_answer": "Yen", "expected_answer": "Yen"}
|
|
@ -8,8 +8,13 @@ import os
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||||
from llama_stack.apis.datasetio import * # noqa: F403
|
from llama_stack.apis.datasetio import * # noqa: F403
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
import base64
|
||||||
|
import mimetypes
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||||
|
|
||||||
# How to run this test:
|
# How to run this test:
|
||||||
|
@ -41,6 +46,22 @@ async def datasetio_settings():
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def data_url_from_file(file_path: str) -> str:
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
|
|
||||||
|
with open(file_path, "rb") as file:
|
||||||
|
file_content = file.read()
|
||||||
|
|
||||||
|
base64_content = base64.b64encode(file_content).decode("utf-8")
|
||||||
|
mime_type, _ = mimetypes.guess_type(file_path)
|
||||||
|
print(mime_type)
|
||||||
|
|
||||||
|
data_url = f"data:{mime_type};base64,{base64_content}"
|
||||||
|
|
||||||
|
return data_url
|
||||||
|
|
||||||
|
|
||||||
async def register_dataset(datasets_impl: Datasets):
|
async def register_dataset(datasets_impl: Datasets):
|
||||||
dataset = DatasetDefWithProvider(
|
dataset = DatasetDefWithProvider(
|
||||||
identifier="test_dataset",
|
identifier="test_dataset",
|
||||||
|
@ -48,62 +69,88 @@ async def register_dataset(datasets_impl: Datasets):
|
||||||
url=URL(
|
url=URL(
|
||||||
uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
|
uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
|
||||||
),
|
),
|
||||||
columns_schema={},
|
dataset_schema={},
|
||||||
)
|
)
|
||||||
await datasets_impl.register_dataset(dataset)
|
await datasets_impl.register_dataset(dataset)
|
||||||
|
|
||||||
|
|
||||||
|
async def register_local_dataset(datasets_impl: Datasets):
|
||||||
|
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.jsonl"
|
||||||
|
test_jsonl_url = data_url_from_file(str(test_file))
|
||||||
|
dataset = DatasetDefWithProvider(
|
||||||
|
identifier="test_dataset",
|
||||||
|
provider_id=os.environ["PROVIDER_ID"],
|
||||||
|
url=URL(
|
||||||
|
uri=test_jsonl_url,
|
||||||
|
),
|
||||||
|
dataset_schema={
|
||||||
|
"generated_answer": StringType(),
|
||||||
|
"expected_answer": StringType(),
|
||||||
|
"input_query": StringType(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await datasets_impl.register_dataset(dataset)
|
||||||
|
|
||||||
|
|
||||||
|
# @pytest.mark.asyncio
|
||||||
|
# async def test_datasets_list(datasetio_settings):
|
||||||
|
# # NOTE: this needs you to ensure that you are starting from a clean state
|
||||||
|
# # but so far we don't have an unregister API unfortunately, so be careful
|
||||||
|
# datasets_impl = datasetio_settings["datasets_impl"]
|
||||||
|
# response = await datasets_impl.list_datasets()
|
||||||
|
# assert isinstance(response, list)
|
||||||
|
# assert len(response) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# @pytest.mark.asyncio
|
||||||
|
# async def test_datasets_register(datasetio_settings):
|
||||||
|
# # NOTE: this needs you to ensure that you are starting from a clean state
|
||||||
|
# # but so far we don't have an unregister API unfortunately, so be careful
|
||||||
|
# datasets_impl = datasetio_settings["datasets_impl"]
|
||||||
|
# await register_dataset(datasets_impl)
|
||||||
|
|
||||||
|
# response = await datasets_impl.list_datasets()
|
||||||
|
# assert isinstance(response, list)
|
||||||
|
# assert len(response) == 1
|
||||||
|
|
||||||
|
# # register same dataset with same id again will fail
|
||||||
|
# await register_dataset(datasets_impl)
|
||||||
|
# response = await datasets_impl.list_datasets()
|
||||||
|
# assert isinstance(response, list)
|
||||||
|
# assert len(response) == 1
|
||||||
|
# assert response[0].identifier == "test_dataset"
|
||||||
|
|
||||||
|
|
||||||
|
# @pytest.mark.asyncio
|
||||||
|
# async def test_get_rows_paginated(datasetio_settings):
|
||||||
|
# datasetio_impl = datasetio_settings["datasetio_impl"]
|
||||||
|
# datasets_impl = datasetio_settings["datasets_impl"]
|
||||||
|
# await register_dataset(datasets_impl)
|
||||||
|
|
||||||
|
# response = await datasetio_impl.get_rows_paginated(
|
||||||
|
# dataset_id="test_dataset",
|
||||||
|
# rows_in_page=3,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# assert isinstance(response.rows, list)
|
||||||
|
# assert len(response.rows) == 3
|
||||||
|
# assert response.next_page_token == "3"
|
||||||
|
|
||||||
|
# # iterate over all rows
|
||||||
|
# response = await datasetio_impl.get_rows_paginated(
|
||||||
|
# dataset_id="test_dataset",
|
||||||
|
# rows_in_page=10,
|
||||||
|
# page_token=response.next_page_token,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# assert isinstance(response.rows, list)
|
||||||
|
# assert len(response.rows) == 10
|
||||||
|
# assert response.next_page_token == "13"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_datasets_list(datasetio_settings):
|
async def test_datasets_validation(datasetio_settings):
|
||||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||||
# but so far we don't have an unregister API unfortunately, so be careful
|
# but so far we don't have an unregister API unfortunately, so be careful
|
||||||
datasets_impl = datasetio_settings["datasets_impl"]
|
datasets_impl = datasetio_settings["datasets_impl"]
|
||||||
response = await datasets_impl.list_datasets()
|
await register_local_dataset(datasets_impl)
|
||||||
assert isinstance(response, list)
|
|
||||||
assert len(response) == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_datasets_register(datasetio_settings):
|
|
||||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
|
||||||
# but so far we don't have an unregister API unfortunately, so be careful
|
|
||||||
datasets_impl = datasetio_settings["datasets_impl"]
|
|
||||||
await register_dataset(datasets_impl)
|
|
||||||
|
|
||||||
response = await datasets_impl.list_datasets()
|
|
||||||
assert isinstance(response, list)
|
|
||||||
assert len(response) == 1
|
|
||||||
|
|
||||||
# register same dataset with same id again will fail
|
|
||||||
await register_dataset(datasets_impl)
|
|
||||||
response = await datasets_impl.list_datasets()
|
|
||||||
assert isinstance(response, list)
|
|
||||||
assert len(response) == 1
|
|
||||||
assert response[0].identifier == "test_dataset"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_rows_paginated(datasetio_settings):
|
|
||||||
datasetio_impl = datasetio_settings["datasetio_impl"]
|
|
||||||
datasets_impl = datasetio_settings["datasets_impl"]
|
|
||||||
await register_dataset(datasets_impl)
|
|
||||||
|
|
||||||
response = await datasetio_impl.get_rows_paginated(
|
|
||||||
dataset_id="test_dataset",
|
|
||||||
rows_in_page=3,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(response.rows, list)
|
|
||||||
assert len(response.rows) == 3
|
|
||||||
assert response.next_page_token == "3"
|
|
||||||
|
|
||||||
# iterate over all rows
|
|
||||||
response = await datasetio_impl.get_rows_paginated(
|
|
||||||
dataset_id="test_dataset",
|
|
||||||
rows_in_page=10,
|
|
||||||
page_token=response.next_page_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(response.rows, list)
|
|
||||||
assert len(response.rows) == 10
|
|
||||||
assert response.next_page_token == "13"
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue