Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-13 13:22:36 +00:00
typing comment, dataset -> dataset_id
commit 81ebd1ea92 (parent 575e51eb76)
2 changed files with 21 additions and 13 deletions
@@ -38,11 +38,6 @@ class UnionType(BaseModel):
     type: Literal["union"] = "union"
 
 
-class CustomType(BaseModel):
-    type: Literal["custom"] = "custom"
-    validator_class: str
-
-
 class ChatCompletionInputType(BaseModel):
     # expects List[Message] for messages
     type: Literal["chat_completion_input"] = "chat_completion_input"
@@ -74,3 +69,16 @@ ParamType = Annotated[
     ],
     Field(discriminator="type"),
 ]
+
+# TODO: recursive definition of ParamType in these containers
+# will cause infinite recursion in OpenAPI generation script
+# since we are going with ChatCompletionInputType and CompletionInputType
+# we don't need to worry about ArrayType/ObjectType/UnionType for now
+# ArrayType.model_rebuild()
+# ObjectType.model_rebuild()
+# UnionType.model_rebuild()
+
+
+# class CustomType(BaseModel):
+#     type: Literal["custom"] = "custom"
+#     validator_class: str
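For context on the first file, the sketch below shows how a discriminated union like ParamType resolves on the "type" field. It is a standalone illustration assuming Pydantic v2, not the repository file: StringType is a made-up second member, while ChatCompletionInputType and the Field(discriminator="type") annotation mirror the diff above.

# Standalone sketch (assumes Pydantic v2); StringType is invented here so the
# union has two members, ChatCompletionInputType mirrors the diff.
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class StringType(BaseModel):
    type: Literal["string"] = "string"


class ChatCompletionInputType(BaseModel):
    # expects List[Message] for messages
    type: Literal["chat_completion_input"] = "chat_completion_input"


ParamType = Annotated[
    Union[StringType, ChatCompletionInputType],
    Field(discriminator="type"),
]

# Validation dispatches on the "type" field to the matching model.
adapter = TypeAdapter(ParamType)
print(adapter.validate_python({"type": "chat_completion_input"}))
# -> ChatCompletionInputType(type='chat_completion_input')

This also gives some context for the new TODO comment: a container member such as ArrayType that referred back to ParamType would need a deferred model_rebuild(), which the commit sidesteps for now by leaving those members out and relying on ChatCompletionInputType / CompletionInputType.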
@@ -107,8 +107,8 @@ class PostTrainingSFTRequest(BaseModel):
     job_uuid: str
 
     model: str
-    dataset: str
-    validation_dataset: str
+    dataset_id: str
+    validation_dataset_id: str
 
     algorithm: FinetuningAlgorithm
     algorithm_config: Union[
@@ -131,8 +131,8 @@ class PostTrainingRLHFRequest(BaseModel):
 
     finetuned_model: URL
 
-    dataset: str
-    validation_dataset: str
+    dataset_id: str
+    validation_dataset_id: str
 
     algorithm: RLHFAlgorithm
     algorithm_config: Union[DPOAlignmentConfig]
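To see what the rename means for callers of these request models, here is a minimal sketch with a trimmed-down PostTrainingSFTRequest (the real model carries more fields, the example values are placeholders, and Pydantic v2 is assumed): the old dataset / validation_dataset keywords are no longer recognized, and validation now requires dataset_id / validation_dataset_id.

# Trimmed-down sketch of the renamed fields; not the full request model.
from pydantic import BaseModel, ValidationError


class PostTrainingSFTRequest(BaseModel):
    job_uuid: str
    model: str
    dataset_id: str
    validation_dataset_id: str


# New field names validate as expected (placeholder values).
req = PostTrainingSFTRequest(
    job_uuid="1234",
    model="my-model",
    dataset_id="my-dataset",
    validation_dataset_id="my-validation-dataset",
)

# The old keyword names no longer satisfy the model: the unknown keys are
# ignored and the required *_id fields are reported as missing.
try:
    PostTrainingSFTRequest(
        job_uuid="1234",
        model="my-model",
        dataset="my-dataset",
        validation_dataset="my-validation-dataset",
    )
except ValidationError as exc:
    print(exc.error_count(), "missing fields")  # 2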
@@ -181,8 +181,8 @@ class PostTraining(Protocol):
         self,
         job_uuid: str,
         model: str,
-        dataset: str,
-        validation_dataset: str,
+        dataset_id: str,
+        validation_dataset_id: str,
         algorithm: FinetuningAlgorithm,
         algorithm_config: Union[
             LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
@@ -198,8 +198,8 @@ class PostTraining(Protocol):
         self,
         job_uuid: str,
         finetuned_model: URL,
-        dataset: str,
-        validation_dataset: str,
+        dataset_id: str,
+        validation_dataset_id: str,
         algorithm: RLHFAlgorithm,
         algorithm_config: Union[DPOAlignmentConfig],
         optimizer_config: OptimizerConfig,
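Finally, a self-contained sketch of how the Protocol change propagates to implementations: any provider satisfying PostTraining must now accept dataset_id and validation_dataset_id. The method name supervised_fine_tune, the return type, and the omission of the algorithm/config parameters are simplifications assumed for illustration; only the renamed parameter names come from the diff.

# Simplified Protocol sketch; method name and return type are assumptions,
# and the algorithm/config parameters from the real signature are omitted.
from typing import Protocol


class PostTraining(Protocol):
    def supervised_fine_tune(
        self,
        job_uuid: str,
        model: str,
        dataset_id: str,
        validation_dataset_id: str,
    ) -> str: ...


class ReferenceImpl:
    def supervised_fine_tune(
        self,
        job_uuid: str,
        model: str,
        dataset_id: str,
        validation_dataset_id: str,
    ) -> str:
        return f"job {job_uuid}: fine-tuning {model} on {dataset_id}"


# Structural typing: ReferenceImpl satisfies the Protocol after the rename.
impl: PostTraining = ReferenceImpl()
print(impl.supervised_fine_tune("1234", "my-model", "my-dataset", "my-validation-dataset"))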