typing comment, dataset -> dataset_id

This commit is contained in:
Xi Yan 2024-10-25 12:59:47 -07:00
parent 575e51eb76
commit 81ebd1ea92
2 changed files with 21 additions and 13 deletions

View file

@ -38,11 +38,6 @@ class UnionType(BaseModel):
type: Literal["union"] = "union" type: Literal["union"] = "union"
class CustomType(BaseModel):
type: Literal["custom"] = "custom"
validator_class: str
class ChatCompletionInputType(BaseModel): class ChatCompletionInputType(BaseModel):
# expects List[Message] for messages # expects List[Message] for messages
type: Literal["chat_completion_input"] = "chat_completion_input" type: Literal["chat_completion_input"] = "chat_completion_input"
@ -74,3 +69,16 @@ ParamType = Annotated[
], ],
Field(discriminator="type"), Field(discriminator="type"),
] ]
# TODO: recursive definition of ParamType in these containers
# will cause infinite recursion in OpenAPI generation script
# since we are going with ChatCompletionInputType and CompletionInputType
# we don't need to worry about ArrayType/ObjectType/UnionType for now
# ArrayType.model_rebuild()
# ObjectType.model_rebuild()
# UnionType.model_rebuild()
# class CustomType(BaseModel):
# type: Literal["custom"] = "custom"
# validator_class: str

View file

@ -107,8 +107,8 @@ class PostTrainingSFTRequest(BaseModel):
job_uuid: str job_uuid: str
model: str model: str
dataset: str dataset_id: str
validation_dataset: str validation_dataset_id: str
algorithm: FinetuningAlgorithm algorithm: FinetuningAlgorithm
algorithm_config: Union[ algorithm_config: Union[
@ -131,8 +131,8 @@ class PostTrainingRLHFRequest(BaseModel):
finetuned_model: URL finetuned_model: URL
dataset: str dataset_id: str
validation_dataset: str validation_dataset_id: str
algorithm: RLHFAlgorithm algorithm: RLHFAlgorithm
algorithm_config: Union[DPOAlignmentConfig] algorithm_config: Union[DPOAlignmentConfig]
@ -181,8 +181,8 @@ class PostTraining(Protocol):
self, self,
job_uuid: str, job_uuid: str,
model: str, model: str,
dataset: str, dataset_id: str,
validation_dataset: str, validation_dataset_id: str,
algorithm: FinetuningAlgorithm, algorithm: FinetuningAlgorithm,
algorithm_config: Union[ algorithm_config: Union[
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
@ -198,8 +198,8 @@ class PostTraining(Protocol):
self, self,
job_uuid: str, job_uuid: str,
finetuned_model: URL, finetuned_model: URL,
dataset: str, dataset_id: str,
validation_dataset: str, validation_dataset_id: str,
algorithm: RLHFAlgorithm, algorithm: RLHFAlgorithm,
algorithm_config: Union[DPOAlignmentConfig], algorithm_config: Union[DPOAlignmentConfig],
optimizer_config: OptimizerConfig, optimizer_config: OptimizerConfig,