From 81ebd1ea925292af3370797957c7271192623f1a Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 25 Oct 2024 12:59:47 -0700 Subject: [PATCH] typing comment, dataset -> dataset_id --- llama_stack/apis/common/type_system.py | 18 +++++++++++++----- .../apis/post_training/post_training.py | 16 ++++++++-------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/llama_stack/apis/common/type_system.py b/llama_stack/apis/common/type_system.py index cffc4e936..4808e8238 100644 --- a/llama_stack/apis/common/type_system.py +++ b/llama_stack/apis/common/type_system.py @@ -38,11 +38,6 @@ class UnionType(BaseModel): type: Literal["union"] = "union" -class CustomType(BaseModel): - type: Literal["custom"] = "custom" - validator_class: str - - class ChatCompletionInputType(BaseModel): # expects List[Message] for messages type: Literal["chat_completion_input"] = "chat_completion_input" @@ -74,3 +69,16 @@ ParamType = Annotated[ ], Field(discriminator="type"), ] + +# TODO: recursive definition of ParamType in these containers +# will cause infinite recursion in OpenAPI generation script +# since we are going with ChatCompletionInputType and CompletionInputType +# we don't need to worry about ArrayType/ObjectType/UnionType for now +# ArrayType.model_rebuild() +# ObjectType.model_rebuild() +# UnionType.model_rebuild() + + +# class CustomType(BaseModel): +# type: Literal["custom"] = "custom" +# validator_class: str diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index abf21f6c2..eb4992cc6 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -107,8 +107,8 @@ class PostTrainingSFTRequest(BaseModel): job_uuid: str model: str - dataset: str - validation_dataset: str + dataset_id: str + validation_dataset_id: str algorithm: FinetuningAlgorithm algorithm_config: Union[ @@ -131,8 +131,8 @@ class PostTrainingRLHFRequest(BaseModel): finetuned_model: URL - dataset: str - validation_dataset: str + dataset_id: str + validation_dataset_id: str algorithm: RLHFAlgorithm algorithm_config: Union[DPOAlignmentConfig] @@ -181,8 +181,8 @@ class PostTraining(Protocol): self, job_uuid: str, model: str, - dataset: str, - validation_dataset: str, + dataset_id: str, + validation_dataset_id: str, algorithm: FinetuningAlgorithm, algorithm_config: Union[ LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig @@ -198,8 +198,8 @@ class PostTraining(Protocol): self, job_uuid: str, finetuned_model: URL, - dataset: str, - validation_dataset: str, + dataset_id: str, + validation_dataset_id: str, algorithm: RLHFAlgorithm, algorithm_config: Union[DPOAlignmentConfig], optimizer_config: OptimizerConfig,