mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
71 lines
2.2 KiB
Python
71 lines
2.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
|
|
|
|
from typing import Any
|
|
|
|
from llama_stack.apis.common.responses import PaginatedResponse
|
|
from llama_stack.apis.datasetio import DatasetIO
|
|
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
|
from llama_stack.log import get_logger
|
|
from llama_stack.providers.datatypes import RoutingTable
|
|
|
|
# Module-level logger for this file, created under the "core" category.
logger = get_logger(name=__name__, category="core")
|
|
|
|
|
|
class DatasetIORouter(DatasetIO):
    """DatasetIO implementation that forwards every call via a routing table.

    Dataset registration is handed to the routing table itself. Row-level
    operations (``iterrows`` / ``append_rows``) are handed to whatever provider
    implementation the routing table returns for the given ``dataset_id``.
    Each method emits a debug log line before delegating.
    """

    def __init__(self, routing_table: RoutingTable) -> None:
        """Keep a reference to the routing table used for dispatch.

        Args:
            routing_table: Object used to register datasets and to look up
                the provider implementation for a dataset id.
        """
        logger.debug("Initializing DatasetIORouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        """Log the call; the router performs no other setup work."""
        logger.debug("DatasetIORouter.initialize")

    async def shutdown(self) -> None:
        """Log the call; the router performs no other teardown work."""
        logger.debug("DatasetIORouter.shutdown")

    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> None:
        """Register a dataset by delegating to the routing table.

        All four arguments are passed through unchanged, by keyword, to
        ``routing_table.register_dataset``. The result of that call is
        awaited and discarded.
        """
        logger.debug(
            f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
        )
        registration = self.routing_table.register_dataset(
            purpose=purpose, source=source, metadata=metadata, dataset_id=dataset_id
        )
        await registration

    async def iterrows(
        self,
        dataset_id: str,
        start_index: int | None = None,
        limit: int | None = None,
    ) -> PaginatedResponse:
        """Fetch rows of a dataset from the provider that serves it.

        The provider is looked up through the routing table using
        ``dataset_id``; ``start_index`` and ``limit`` are forwarded to the
        provider's ``iterrows`` unchanged.

        Returns:
            Whatever the provider's ``iterrows`` coroutine returns.
        """
        logger.debug(
            f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
        )
        provider = self.routing_table.get_provider_impl(dataset_id)
        return await provider.iterrows(dataset_id=dataset_id, start_index=start_index, limit=limit)

    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
        """Append rows to a dataset through the provider that serves it.

        The provider is looked up through the routing table using
        ``dataset_id``; ``rows`` is forwarded to the provider's
        ``append_rows`` unchanged, and that call's result is returned as-is.
        """
        logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
        provider = self.routing_table.get_provider_impl(dataset_id)
        return await provider.append_rows(dataset_id=dataset_id, rows=rows)
|