skeleton dataset / datasetio

This commit is contained in:
Xi Yan 2024-10-22 11:22:39 -07:00
parent 668a495aba
commit e8de70fdbe
12 changed files with 233 additions and 2 deletions

View file

@@ -14,11 +14,12 @@ from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
@@ -30,18 +31,21 @@ RoutableObject = Union[
ModelDef,
ShieldDef,
MemoryBankDef,
DatasetDef,
]
RoutableObjectWithProvider = Union[
ModelDefWithProvider,
ShieldDefWithProvider,
MemoryBankDefWithProvider,
DatasetDefWithProvider,
]
RoutedProtocol = Union[
Inference,
Safety,
Memory,
DatasetIO,
]

View file

@@ -35,6 +35,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
routing_table_api=Api.memory_banks,
router_api=Api.memory,
),
AutoRoutedApiInfo(
routing_table_api=Api.datasets,
router_api=Api.datasetio,
),
]

View file

@@ -12,6 +12,8 @@ from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
@@ -38,6 +40,8 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.safety: Safety,
Api.shields: Shields,
Api.telemetry: Telemetry,
Api.datasets: Datasets,
Api.datasetio: DatasetIO,
}

View file

@@ -8,6 +8,7 @@ from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403
from .routing_tables import (
DatasetsRoutingTable,
MemoryBanksRoutingTable,
ModelsRoutingTable,
ShieldsRoutingTable,
@@ -23,6 +24,7 @@ async def get_routing_table_impl(
"memory_banks": MemoryBanksRoutingTable,
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,
}
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
@@ -33,12 +35,13 @@ async def get_routing_table_impl(
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
from .routers import InferenceRouter, MemoryRouter, SafetyRouter
from .routers import DatasetIORouter, InferenceRouter, MemoryRouter, SafetyRouter
api_to_routers = {
"memory": MemoryRouter,
"inference": InferenceRouter,
"safety": SafetyRouter,
"datasetio": DatasetIORouter,
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")

View file

@@ -6,11 +6,13 @@
from typing import Any, AsyncGenerator, Dict, List
from llama_stack.apis.datasetio.datasetio import DatasetIO
from llama_stack.distribution.datatypes import RoutingTable
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
class MemoryRouter(Memory):
@@ -156,3 +158,33 @@ class SafetyRouter(Safety):
messages=messages,
params=params,
)
class DatasetIORouter(DatasetIO):
    """Routes DatasetIO calls to whichever provider owns the target dataset.

    The routing table maps a dataset identifier to the provider
    implementation registered for it; every API call is forwarded there.
    """

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        # No startup work: the routing table is constructed and managed externally.
        pass

    async def shutdown(self) -> None:
        # Nothing to release.
        pass

    async def get_rows_paginated(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult:
        """Resolve the provider for ``dataset_id`` and forward the paginated read."""
        provider = self.routing_table.get_provider_impl(dataset_id)
        return await provider.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=rows_in_page,
            page_token=page_token,
            filter_condition=filter_condition,
        )

View file

@@ -11,6 +11,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
@@ -190,3 +191,23 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
self, memory_bank: MemoryBankDefWithProvider
) -> None:
await self.register_object(memory_bank)
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
    """Routing-table implementation of the Datasets registry API."""

    async def list_datasets(self) -> List[DatasetDefWithProvider]:
        """Return every registered dataset across all providers."""
        objects = []
        for objs in self.registry.values():
            objects.extend(objs)
        return objects

    async def get_dataset(
        self, dataset_identifier: str
    ) -> Optional[DatasetDefWithProvider]:
        # BUG FIX: the original body referenced an undefined name
        # `identifier` (guaranteed NameError at call time) and annotated
        # the return as Optional[ModelDefWithProvider] — a copy-paste
        # from the models routing table. Use the actual parameter and
        # the dataset type.
        return self.get_object_by_identifier(dataset_identifier)

    async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None:
        """Register a dataset definition with its owning provider."""
        await self.register_object(dataset_def)

    async def delete_dataset(self, dataset_identifier: str) -> None:
        # TODO: pass through for now
        return

View file

@@ -10,6 +10,8 @@ from typing import Any, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_stack.apis.datasets import DatasetDef
from llama_stack.apis.memory_banks import MemoryBankDef
from llama_stack.apis.models import ModelDef
@@ -22,12 +24,14 @@ class Api(Enum):
safety = "safety"
agents = "agents"
memory = "memory"
datasetio = "datasetio"
telemetry = "telemetry"
models = "models"
shields = "shields"
memory_banks = "memory_banks"
datasets = "datasets"
# built-in API
inspect = "inspect"
@@ -51,6 +55,12 @@ class MemoryBanksProtocolPrivate(Protocol):
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
class DatasetsProtocolPrivate(Protocol):
    """Private provider-facing protocol for dataset management.

    Providers that serve datasets implement these hooks so the routing
    layer can enumerate and register datasets on them.
    """

    async def list_datasets(self) -> List[DatasetDef]: ...

    # CONSISTENCY FIX: was `register_datasets` (plural). The concrete
    # implementation (MetaReferenceDatasetioImpl.register_dataset) defines
    # the singular form, so the structural protocol must use the same name
    # for conformance checks to hold.
    async def register_dataset(self, dataset_def: DatasetDef) -> None: ...
@json_schema_type
class ProviderSpec(BaseModel):
api: Api

View file

@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import MetaReferenceDatasetIOConfig
async def get_provider_impl(
config: MetaReferenceDatasetIOConfig,
_deps,
):
from .datasetio import MetaReferenceDatasetioImpl
impl = MetaReferenceDatasetioImpl(config)
await impl.initialize()
return impl

View file

@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.datasetio import * # noqa: F401, F403
class MetaReferenceDatasetIOConfig(BaseModel):
    """Config for the meta-reference datasetio provider; no options yet."""

View file

@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from .config import MetaReferenceDatasetIOConfig
class MetaReferenceDatasetioImpl(DatasetIO, DatasetsProtocolPrivate):
    """Skeleton in-process datasetio provider; all methods are stubs."""

    def __init__(self, config: MetaReferenceDatasetIOConfig) -> None:
        self.config = config

    async def initialize(self) -> None: ...

    async def shutdown(self) -> None: ...

    async def register_dataset(
        self,
        # CONSISTENCY FIX: parameter was named `memory_bank` — a copy-paste
        # leftover from the memory-banks provider. Renamed to `dataset_def`
        # to match DatasetsProtocolPrivate's signature.
        dataset_def: DatasetDef,
    ) -> None:
        print("register dataset")

    async def list_datasets(self) -> List[DatasetDef]:
        print("list datasets")
        return []

    async def get_rows_paginated(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult:
        print("get rows paginated")
        # NOTE(review): returns an empty page yet reports total_count=1 —
        # looks like stub debris; presumably should become 0 (or the real
        # count) once data is wired in. Confirm before relying on it.
        return PaginatedRowsResult(rows=[], total_count=1, next_page_token=None)

View file

@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
    """List datasetio providers: the inline meta-reference implementation
    and a sample remote adapter."""
    inline = InlineProviderSpec(
        api=Api.datasetio,
        provider_type="meta-reference",
        pip_packages=[],
        module="llama_stack.providers.impls.meta_reference.datasetio",
        config_class="llama_stack.providers.impls.meta_reference.datasetio.MetaReferenceDatasetIOConfig",
        api_dependencies=[],
    )
    sample_remote = remote_provider_spec(
        api=Api.datasetio,
        adapter=AdapterSpec(
            adapter_type="sample",
            pip_packages=[],
            module="llama_stack.providers.adapters.datasetio.sample",
            config_class="llama_stack.providers.adapters.datasetio.sample.SampleConfig",
        ),
    )
    return [inline, sample_remote]

View file

@@ -0,0 +1,52 @@
version: '2'
built_at: '2024-10-08T17:40:45.325529'
image_name: local
docker_image: null
conda_env: local
apis:
- shields
- safety
- agents
- models
- memory
- memory_banks
- inference
- datasets
- datasetio
providers:
datasetio:
- provider_id: meta0
provider_type: meta-reference
config: {}
inference:
- provider_id: tgi0
provider_type: remote::tgi
config:
url: http://127.0.0.1:5009
memory:
- provider_id: meta-reference
provider_type: meta-reference
config: {}
agents:
- provider_id: meta-reference
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: /home/xiyan/.llama/runtime/kvstore.db
telemetry:
- provider_id: meta-reference
provider_type: meta-reference
config: {}
safety:
- provider_id: meta-reference
provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M