mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-09 11:20:58 +00:00
wip add datatypes
This commit is contained in:
parent
99ed1425fc
commit
9816c9aae6
5 changed files with 175 additions and 57 deletions
|
|
@ -13,20 +13,59 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
TDatasetRow = TypeVar("TDatasetRow")
|
||||
# A sample (row) from raw dataset
|
||||
TDatasetSample = TypeVar("TDatasetSample")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DatasetRow(BaseModel): ...
|
||||
class DatasetSample(BaseModel): ...
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DictSample(DatasetRow):
|
||||
class DictSample(DatasetSample):
|
||||
data: Dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Generation(BaseModel): ...
|
||||
class ProcessedDictSample(DatasetSample):
|
||||
data: Dict[str, Any]
|
||||
preprocessed: Dict[str, Any]
|
||||
prediction: Dict[str, Any]
|
||||
postprocessed: Dict[str, Any]
|
||||
|
||||
|
||||
# # A sample (row) after preprocessing the raw dataset
|
||||
# TPreprocessedSample = TypeVar("TPreprocessedSample")
|
||||
|
||||
# @json_schema_type
|
||||
# class PreprocessedSample(BaseModel): ...
|
||||
|
||||
# @json_schema_type
|
||||
# class InferencePreprocessedSample(PreprocessedSample):
|
||||
# # TODO: either keep it generic or specific to inference API
|
||||
# # messages: List[Message]
|
||||
# data: Dict[str, Any]
|
||||
|
||||
# # A sample (row) from model prediction output
|
||||
# TPredictionSample = TypeVar("TPredictionSample")
|
||||
|
||||
# @json_schema_type
|
||||
# class PredictionSample(BaseModel): ...
|
||||
|
||||
# @json_schema_type
|
||||
# class InferencePredictionSample(PredictionSample):
|
||||
# data: Dict[str, Any]
|
||||
|
||||
|
||||
# # A sample (row) from post-processed output
|
||||
# TPostprocessedSample = TypeVar("TPostprocessedSample")
|
||||
|
||||
# @json_schema_type
|
||||
# class PostprocessedSample(BaseModel): ...
|
||||
|
||||
# @json_schema_type
|
||||
# class InferencePostprocessedSample(PredictionSample):
|
||||
# data: Dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -70,16 +109,17 @@ DatasetDef = Annotated[
|
|||
]
|
||||
|
||||
|
||||
class BaseDataset(ABC, Generic[TDatasetRow]):
|
||||
class BaseDataset(ABC, Generic[TDatasetSample]):
|
||||
def __init__(self) -> None:
|
||||
self.type: str = self.__class__.__name__
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def __iter__(self) -> Iterator[TDatasetRow]:
|
||||
def dataset_id(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def load(self) -> None:
|
||||
def __iter__(self) -> Iterator[TDatasetSample]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
|
|
@ -90,6 +130,10 @@ class BaseDataset(ABC, Generic[TDatasetRow]):
|
|||
def __len__(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def load(self) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class Datasets(Protocol):
|
||||
@webmethod(route="/datasets/create")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue