wip add datatypes

This commit is contained in:
Xi Yan 2024-10-10 19:56:19 -07:00
parent 99ed1425fc
commit 9816c9aae6
5 changed files with 175 additions and 57 deletions

View file

@ -13,20 +13,59 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
TDatasetRow = TypeVar("TDatasetRow")
# A sample (row) from raw dataset
TDatasetSample = TypeVar("TDatasetSample")
@json_schema_type
class DatasetRow(BaseModel): ...
class DatasetSample(BaseModel): ...
@json_schema_type
class DictSample(DatasetRow):
# Dataset sample whose payload is a single dictionary.
class DictSample(DatasetSample):
# Sample contents; typed Dict[str, Any], so no key/value schema is enforced here.
data: Dict[str, Any]
@json_schema_type
class Generation(BaseModel): ...
# Dataset sample that carries four dictionary payloads side by side.
# NOTE(review): the field names suggest one payload per pipeline stage
# (raw row, preprocessed input, model prediction, postprocessed output) —
# confirm against the code that populates them; nothing here enforces it.
class ProcessedDictSample(DatasetSample):
data: Dict[str, Any]
preprocessed: Dict[str, Any]
prediction: Dict[str, Any]
postprocessed: Dict[str, Any]
# # A sample (row) after preprocessing the raw dataset
# TPreprocessedSample = TypeVar("TPreprocessedSample")
# @json_schema_type
# class PreprocessedSample(BaseModel): ...
# @json_schema_type
# class InferencePreprocessedSample(PreprocessedSample):
# # TODO: either keep it generic or specific to inference API
# # messages: List[Message]
# data: Dict[str, Any]
# # A sample (row) from model prediction output
# TPredictionSample = TypeVar("TPredictionSample")
# @json_schema_type
# class PredictionSample(BaseModel): ...
# @json_schema_type
# class InferencePredictionSample(PredictionSample):
# data: Dict[str, Any]
# # A sample (row) from post-processed output
# TPostprocessedSample = TypeVar("TPostprocessedSample")
# @json_schema_type
# class PostprocessedSample(BaseModel): ...
# @json_schema_type
# class InferencePostprocessedSample(PostprocessedSample):
# data: Dict[str, Any]
@json_schema_type
@ -70,16 +109,17 @@ DatasetDef = Annotated[
]
class BaseDataset(ABC, Generic[TDatasetRow]):
class BaseDataset(ABC, Generic[TDatasetSample]):
def __init__(self) -> None:
self.type: str = self.__class__.__name__
@property
@abstractmethod
def __iter__(self) -> Iterator[TDatasetRow]:
def dataset_id(self) -> str:
raise NotImplementedError()
@abstractmethod
def load(self) -> None:
def __iter__(self) -> Iterator[TDatasetSample]:
raise NotImplementedError()
@abstractmethod
@ -90,6 +130,10 @@ class BaseDataset(ABC, Generic[TDatasetRow]):
def __len__(self) -> int:
raise NotImplementedError()
@abstractmethod
def load(self) -> None:
raise NotImplementedError()
class Datasets(Protocol):
@webmethod(route="/datasets/create")