skeleton dataset / datasetio

This commit is contained in:
Xi Yan 2024-10-22 11:22:39 -07:00
parent 668a495aba
commit e8de70fdbe
12 changed files with 233 additions and 2 deletions

View file

@ -14,11 +14,12 @@ from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
@ -30,18 +31,21 @@ RoutableObject = Union[
ModelDef,
ShieldDef,
MemoryBankDef,
DatasetDef,
]
RoutableObjectWithProvider = Union[
ModelDefWithProvider,
ShieldDefWithProvider,
MemoryBankDefWithProvider,
DatasetDefWithProvider,
]
RoutedProtocol = Union[
Inference,
Safety,
Memory,
DatasetIO,
]

View file

@ -35,6 +35,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
routing_table_api=Api.memory_banks,
router_api=Api.memory,
),
AutoRoutedApiInfo(
routing_table_api=Api.datasets,
router_api=Api.datasetio,
),
]

View file

@ -12,6 +12,8 @@ from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
@ -38,6 +40,8 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.safety: Safety,
Api.shields: Shields,
Api.telemetry: Telemetry,
Api.datasets: Datasets,
Api.datasetio: DatasetIO,
}

View file

@ -8,6 +8,7 @@ from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403
from .routing_tables import (
DatasetsRoutingTable,
MemoryBanksRoutingTable,
ModelsRoutingTable,
ShieldsRoutingTable,
@ -23,6 +24,7 @@ async def get_routing_table_impl(
"memory_banks": MemoryBanksRoutingTable,
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,
}
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
@ -33,12 +35,13 @@ async def get_routing_table_impl(
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
from .routers import InferenceRouter, MemoryRouter, SafetyRouter
from .routers import DatasetIORouter, InferenceRouter, MemoryRouter, SafetyRouter
api_to_routers = {
"memory": MemoryRouter,
"inference": InferenceRouter,
"safety": SafetyRouter,
"datasetio": DatasetIORouter,
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")

View file

@ -6,11 +6,13 @@
from typing import Any, AsyncGenerator, Dict, List
from llama_stack.apis.datasetio.datasetio import DatasetIO
from llama_stack.distribution.datatypes import RoutingTable
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
class MemoryRouter(Memory):
@ -156,3 +158,33 @@ class SafetyRouter(Safety):
messages=messages,
params=params,
)
class DatasetIORouter(DatasetIO):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def get_rows_paginated(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
return await self.routing_table.get_provider_impl(
dataset_id
).get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=rows_in_page,
page_token=page_token,
filter_condition=filter_condition,
)

View file

@ -11,6 +11,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
@ -190,3 +191,23 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
self, memory_bank: MemoryBankDefWithProvider
) -> None:
await self.register_object(memory_bank)
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> List[DatasetDefWithProvider]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
async def get_dataset(
self, dataset_identifier: str
) -> Optional[ModelDefWithProvider]:
return self.get_object_by_identifier(identifier)
async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None:
await self.register_object(dataset_def)
async def delete_dataset(self, dataset_identifier: str) -> None:
# TODO: pass through for now
return