added DPO

This commit is contained in:
Ashwin Bharambe 2024-07-11 00:01:58 -07:00
parent 7cade3acc3
commit 631328f556
4 changed files with 796 additions and 472 deletions

View file

@ -12,19 +12,6 @@ from agentic_system_types import (
SafetyViolation,
)
from finetuning_types import (
Checkpoint,
Dataset,
DoraFinetuningConfig,
FinetuningAlgorithm,
FinetuningJobLogStream,
FinetuningJobStatus,
LoraFinetuningConfig,
OptimizerConfig,
QLoraFinetuningConfig,
TrainingConfig,
)
from model_types import (
BuiltinTool,
Content,
@ -42,6 +29,21 @@ from model_types import (
URL,
)
from post_training_types import (
Checkpoint,
Dataset,
DoraFinetuningConfig,
DPOAlignmentConfig,
FinetuningAlgorithm,
LoraFinetuningConfig,
OptimizerConfig,
PostTrainingJobLogStream,
PostTrainingJobStatus,
QLoraFinetuningConfig,
RLHFAlgorithm,
TrainingConfig,
)
from pyopenapi import Info, Options, Server, Specification, webmethod
from strong_typing.schema import json_schema_type
@ -408,7 +410,7 @@ class Datasets(Protocol):
@json_schema_type
@dataclass
class FinetuningTrainRequest:
class PostTrainingSFTRequest:
"""Request to finetune a model."""
job_uuid: str
@ -432,11 +434,34 @@ class FinetuningTrainRequest:
@json_schema_type
@dataclass
class FinetuningJobStatusResponse:
class PostTrainingRLHFRequest:
"""Request to align a model via RLHF-style preference optimization (e.g. DPO)."""
# Identifier of the post-training job this request describes.
job_uuid: str
# NOTE(review): presumably the already-finetuned model to start from — confirm.
finetuned_model: URL
dataset: Dataset
validation_dataset: Dataset
# RLHF algorithm selector; presumably pairs with `algorithm_config` below.
algorithm: RLHFAlgorithm
# Single-member Union: DPOAlignmentConfig is the only config type listed so far.
algorithm_config: Union[DPOAlignmentConfig]
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
@json_schema_type
@dataclass
class PostTrainingJobStatusResponse:
"""Status of a finetuning job."""
job_uuid: str
status: FinetuningJobStatus
status: PostTrainingJobStatus
scheduled_at: Optional[datetime] = None
started_at: Optional[datetime] = None
@ -449,7 +474,7 @@ class FinetuningJobStatusResponse:
@json_schema_type
@dataclass
class FinetuningJobArtifactsResponse:
class PostTrainingJobArtifactsResponse:
"""Artifacts of a finetuning job."""
job_uuid: str
@ -458,27 +483,35 @@ class FinetuningJobArtifactsResponse:
# TODO(ashwin): metrics, evals
class Finetuning(Protocol):
@webmethod(route="/finetuning/text_generation/train")
def post_train(
class PostTraining(Protocol):
@webmethod(route="/post_training/supervised_fine_tune/")
def post_supervised_fine_tune(
self,
request: FinetuningTrainRequest,
request: PostTrainingSFTRequest,
) -> None: ...
@webmethod(route="/post_training/preference_optimize/")
def post_preference_optimize(
self,
request: PostTrainingRLHFRequest,
) -> None: ...
# sends SSE stream of logs
@webmethod(route="/finetuning/job/logs")
def get_training_log_stream(self, job_uuid: str) -> FinetuningJobLogStream: ...
@webmethod(route="/post_training/job/logs")
def get_training_log_stream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
@webmethod(route="/finetuning/job/status")
def get_training_job_status(self, job_uuid: str) -> FinetuningJobStatusResponse: ...
@webmethod(route="/post_training/job/status")
def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/finetuning/job/cancel")
@webmethod(route="/post_training/job/cancel")
def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/finetuning/job/artifacts")
@webmethod(route="/post_training/job/artifacts")
def get_training_job_artifacts(
self, job_uuid: str
) -> FinetuningJobArtifactsResponse: ...
) -> PostTrainingJobArtifactsResponse: ...
class LlamaStackEndpoints(
@ -487,7 +520,7 @@ class LlamaStackEndpoints(
RewardScoring,
SyntheticDataGeneration,
Datasets,
Finetuning,
PostTraining,
MemoryBanks,
): ...