skeleton dataset / datasetio

This commit is contained in:
Xi Yan 2024-10-22 11:22:39 -07:00
parent 668a495aba
commit e8de70fdbe
12 changed files with 233 additions and 2 deletions

View file

@ -10,6 +10,8 @@ from typing import Any, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_stack.apis.datasets import DatasetDef
from llama_stack.apis.memory_banks import MemoryBankDef
from llama_stack.apis.models import ModelDef
@ -22,12 +24,14 @@ class Api(Enum):
safety = "safety"
agents = "agents"
memory = "memory"
datasetio = "datasetio"
telemetry = "telemetry"
models = "models"
shields = "shields"
memory_banks = "memory_banks"
datasets = "datasets"
# built-in API
inspect = "inspect"
@ -51,6 +55,12 @@ class MemoryBanksProtocolPrivate(Protocol):
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
    """Protocol stub: takes a ``MemoryBankDef`` and is declared to return ``None``.

    No behavior is defined here; the body is intentionally empty.
    """
    ...
class DatasetsProtocolPrivate(Protocol):
    """Structural interface for providers that list and register datasets.

    The registration method uses the singular name ``register_dataset``: it
    takes exactly one ``DatasetDef``, and that is the name the meta-reference
    DatasetIO implementation defines. It also mirrors the singular
    ``register_memory_bank`` on the memory-bank protocol.
    """

    async def list_datasets(self) -> List[DatasetDef]:
        """Return a list of ``DatasetDef`` objects."""
        ...

    async def register_dataset(self, dataset_def: DatasetDef) -> None:
        """Accept a single ``DatasetDef``; declared to return ``None``."""
        ...
@json_schema_type
class ProviderSpec(BaseModel):
api: Api

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import MetaReferenceDatasetIOConfig
async def get_provider_impl(config: MetaReferenceDatasetIOConfig, _deps):
    """Construct the meta-reference DatasetIO implementation and initialize it.

    Args:
        config: Passed unchanged to the implementation's constructor.
        _deps: Accepted but not used by this function.

    Returns:
        The ``MetaReferenceDatasetioImpl`` instance, after its
        ``initialize()`` coroutine has been awaited.
    """
    # Imported at call time rather than when this module is imported.
    from .datasetio import MetaReferenceDatasetioImpl as _ProviderImpl

    provider = _ProviderImpl(config)
    await provider.initialize()
    return provider

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.datasetio import * # noqa: F401, F403
# Configuration model for the meta-reference DatasetIO provider. It declares
# no fields of its own. A `#` comment is used instead of a docstring so the
# class's `__doc__` is left exactly as it was.
class MetaReferenceDatasetIOConfig(BaseModel):
    pass

View file

@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from .config import MetaReferenceDatasetIOConfig
class MetaReferenceDatasetioImpl(DatasetIO, DatasetsProtocolPrivate):
    """Skeleton meta-reference DatasetIO provider.

    Every method is a placeholder: it prints a trace message and returns an
    empty result. Nothing is stored or read.
    """

    def __init__(self, config: MetaReferenceDatasetIOConfig) -> None:
        """Keep the provider configuration on the instance."""
        self.config = config

    async def initialize(self) -> None:
        """Placeholder; does nothing."""
        ...

    async def shutdown(self) -> None:
        """Placeholder; does nothing."""
        ...

    async def register_dataset(
        self,
        dataset_def: DatasetDef,
    ) -> None:
        """Placeholder registration hook; the definition is not stored.

        Args:
            dataset_def: The dataset definition to register. The parameter
                name matches the one declared on ``DatasetsProtocolPrivate``
                so that keyword calls written against the protocol work.
        """
        print("register dataset")

    async def list_datasets(self) -> List[DatasetDef]:
        """Return the registered dataset definitions (always empty here)."""
        print("list datasets")
        return []

    async def get_rows_paginated(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult:
        """Return one page of rows for a dataset (always an empty page here).

        Args:
            dataset_id: Identifier of the dataset; unused by this placeholder.
            rows_in_page: Requested page size; unused by this placeholder.
            page_token: Optional continuation token; unused by this placeholder.
            filter_condition: Optional filter; unused by this placeholder.

        Returns:
            A ``PaginatedRowsResult`` with no rows, a ``total_count`` of 0
            and no next-page token.
        """
        print("get rows paginated")
        # Derive the count from the rows so the two can never disagree.
        rows: List = []
        return PaginatedRowsResult(
            rows=rows,
            total_count=len(rows),
            next_page_token=None,
        )

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
    """List the provider specs declared for the ``datasetio`` API.

    Returns:
        Two specs, in this order: the inline ``meta-reference`` provider,
        then the spec produced by ``remote_provider_spec`` for the
        ``sample`` adapter.
    """
    inline_meta_reference = InlineProviderSpec(
        api=Api.datasetio,
        provider_type="meta-reference",
        pip_packages=[],
        module="llama_stack.providers.impls.meta_reference.datasetio",
        config_class="llama_stack.providers.impls.meta_reference.datasetio.MetaReferenceDatasetIOConfig",
        api_dependencies=[],
    )

    sample_adapter = AdapterSpec(
        adapter_type="sample",
        pip_packages=[],
        module="llama_stack.providers.adapters.datasetio.sample",
        config_class="llama_stack.providers.adapters.datasetio.sample.SampleConfig",
    )
    remote_sample = remote_provider_spec(
        api=Api.datasetio,
        adapter=sample_adapter,
    )

    return [inline_meta_reference, remote_sample]