mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-15 18:33:09 +00:00
skeleton dataset / datasetio
This commit is contained in:
parent
668a495aba
commit
e8de70fdbe
12 changed files with 233 additions and 2 deletions
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from .routing_tables import (
|
||||
DatasetsRoutingTable,
|
||||
MemoryBanksRoutingTable,
|
||||
ModelsRoutingTable,
|
||||
ShieldsRoutingTable,
|
||||
|
|
@ -23,6 +24,7 @@ async def get_routing_table_impl(
|
|||
"memory_banks": MemoryBanksRoutingTable,
|
||||
"models": ModelsRoutingTable,
|
||||
"shields": ShieldsRoutingTable,
|
||||
"datasets": DatasetsRoutingTable,
|
||||
}
|
||||
if api.value not in api_to_tables:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
|
@ -33,12 +35,13 @@ async def get_routing_table_impl(
|
|||
|
||||
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
|
||||
from .routers import InferenceRouter, MemoryRouter, SafetyRouter
|
||||
from .routers import DatasetIORouter, InferenceRouter, MemoryRouter, SafetyRouter
|
||||
|
||||
api_to_routers = {
|
||||
"memory": MemoryRouter,
|
||||
"inference": InferenceRouter,
|
||||
"safety": SafetyRouter,
|
||||
"datasetio": DatasetIORouter,
|
||||
}
|
||||
if api.value not in api_to_routers:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
|
|
|||
|
|
@ -6,11 +6,13 @@
|
|||
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
|
||||
from llama_stack.apis.datasetio.datasetio import DatasetIO
|
||||
from llama_stack.distribution.datatypes import RoutingTable
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
|
||||
|
||||
class MemoryRouter(Memory):
|
||||
|
|
@ -156,3 +158,33 @@ class SafetyRouter(Safety):
|
|||
messages=messages,
|
||||
params=params,
|
||||
)
|
||||
|
||||
|
||||
class DatasetIORouter(DatasetIO):
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def get_rows_paginated(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
return await self.routing_table.get_provider_impl(
|
||||
dataset_id
|
||||
).get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=rows_in_page,
|
||||
page_token=page_token,
|
||||
filter_condition=filter_condition,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
|||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
|
@ -190,3 +191,23 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
|||
self, memory_bank: MemoryBankDefWithProvider
|
||||
) -> None:
|
||||
await self.register_object(memory_bank)
|
||||
|
||||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
async def list_datasets(self) -> List[DatasetDefWithProvider]:
|
||||
objects = []
|
||||
for objs in self.registry.values():
|
||||
objects.extend(objs)
|
||||
return objects
|
||||
|
||||
async def get_dataset(
|
||||
self, dataset_identifier: str
|
||||
) -> Optional[ModelDefWithProvider]:
|
||||
return self.get_object_by_identifier(identifier)
|
||||
|
||||
async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None:
|
||||
await self.register_object(dataset_def)
|
||||
|
||||
async def delete_dataset(self, dataset_identifier: str) -> None:
|
||||
# TODO: pass through for now
|
||||
return
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue