Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 12:07:34 +00:00)

commit 631328f556: added DPO
parent: 7cade3acc3
4 changed files with 796 additions and 472 deletions
@@ -12,19 +12,6 @@ from agentic_system_types import (
     SafetyViolation,
 )
 
-from finetuning_types import (
-    Checkpoint,
-    Dataset,
-    DoraFinetuningConfig,
-    FinetuningAlgorithm,
-    FinetuningJobLogStream,
-    FinetuningJobStatus,
-    LoraFinetuningConfig,
-    OptimizerConfig,
-    QLoraFinetuningConfig,
-    TrainingConfig,
-)
-
 from model_types import (
     BuiltinTool,
     Content,
@@ -42,6 +29,21 @@ from model_types import (
     URL,
 )
 
+from post_training_types import (
+    Checkpoint,
+    Dataset,
+    DoraFinetuningConfig,
+    DPOAlignmentConfig,
+    FinetuningAlgorithm,
+    LoraFinetuningConfig,
+    OptimizerConfig,
+    PostTrainingJobLogStream,
+    PostTrainingJobStatus,
+    QLoraFinetuningConfig,
+    RLHFAlgorithm,
+    TrainingConfig,
+)
+
 from pyopenapi import Info, Options, Server, Specification, webmethod
 from strong_typing.schema import json_schema_type
@@ -408,7 +410,7 @@ class Datasets(Protocol):
 
 @json_schema_type
 @dataclass
-class FinetuningTrainRequest:
+class PostTrainingSFTRequest:
     """Request to finetune a model."""
 
     job_uuid: str
@@ -432,11 +434,34 @@ class FinetuningTrainRequest:
 
 @json_schema_type
 @dataclass
-class FinetuningJobStatusResponse:
+class PostTrainingRLHFRequest:
+    """Request to finetune a model."""
+
+    job_uuid: str
+
+    finetuned_model: URL
+
+    dataset: Dataset
+    validation_dataset: Dataset
+
+    algorithm: RLHFAlgorithm
+    algorithm_config: Union[DPOAlignmentConfig]
+
+    optimizer_config: OptimizerConfig
+    training_config: TrainingConfig
+
+    # TODO: define these
+    hyperparam_search_config: Dict[str, Any]
+    logger_config: Dict[str, Any]
+
+
+@json_schema_type
+@dataclass
+class PostTrainingJobStatusResponse:
     """Status of a finetuning job."""
 
     job_uuid: str
-    status: FinetuningJobStatus
+    status: PostTrainingJobStatus
 
     scheduled_at: Optional[datetime] = None
     started_at: Optional[datetime] = None
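The PostTrainingRLHFRequest added above is the payload for the new preference-optimization endpoint introduced further down. As a rough, hypothetical sketch of how a caller might populate it for a DPO run; every variable here is a placeholder, the RLHFAlgorithm.dpo member name is an assumption, and the constructors for Dataset, DPOAlignmentConfig, OptimizerConfig, and TrainingConfig live in post_training_types, whose definitions are not part of this diff:

# Hypothetical usage sketch; field values are placeholders and the
# post_training_types constructors are assumed, not shown in this diff.
request = PostTrainingRLHFRequest(
    job_uuid="dpo-job-0001",
    finetuned_model=sft_checkpoint_url,   # URL of the SFT'd base model
    dataset=preference_dataset,           # Dataset of preference pairs
    validation_dataset=heldout_dataset,   # held-out Dataset
    algorithm=RLHFAlgorithm.dpo,          # assumed enum member name
    algorithm_config=dpo_config,          # a DPOAlignmentConfig instance
    optimizer_config=optimizer_config,
    training_config=training_config,
    hyperparam_search_config={},          # still TODO in the schema
    logger_config={},
)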
@@ -449,7 +474,7 @@ class FinetuningJobStatusResponse:
 
 @json_schema_type
 @dataclass
-class FinetuningJobArtifactsResponse:
+class PostTrainingJobArtifactsResponse:
     """Artifacts of a finetuning job."""
 
     job_uuid: str
@@ -458,27 +483,35 @@ class FinetuningJobArtifactsResponse:
     # TODO(ashwin): metrics, evals
 
 
-class Finetuning(Protocol):
-    @webmethod(route="/finetuning/text_generation/train")
-    def post_train(
+class PostTraining(Protocol):
+    @webmethod(route="/post_training/supervised_fine_tune/")
+    def post_supervised_fine_tune(
         self,
-        request: FinetuningTrainRequest,
+        request: PostTrainingSFTRequest,
+    ) -> None: ...
+
+    @webmethod(route="/post_training/preference_optimize/")
+    def post_preference_optimize(
+        self,
+        request: PostTrainingRLHFRequest,
     ) -> None: ...
 
     # sends SSE stream of logs
-    @webmethod(route="/finetuning/job/logs")
-    def get_training_log_stream(self, job_uuid: str) -> FinetuningJobLogStream: ...
+    @webmethod(route="/post_training/job/logs")
+    def get_training_log_stream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
 
-    @webmethod(route="/finetuning/job/status")
-    def get_training_job_status(self, job_uuid: str) -> FinetuningJobStatusResponse: ...
+    @webmethod(route="/post_training/job/status")
+    def get_training_job_status(
+        self, job_uuid: str
+    ) -> PostTrainingJobStatusResponse: ...
 
-    @webmethod(route="/finetuning/job/cancel")
+    @webmethod(route="/post_training/job/cancel")
     def cancel_training_job(self, job_uuid: str) -> None: ...
 
-    @webmethod(route="/finetuning/job/artifacts")
+    @webmethod(route="/post_training/job/artifacts")
     def get_training_job_artifacts(
         self, job_uuid: str
-    ) -> FinetuningJobArtifactsResponse: ...
+    ) -> PostTrainingJobArtifactsResponse: ...
 
 
 class LlamaStackEndpoints(
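With the rename, a single client can submit either flavor of job and then manage it through the shared /post_training/job/* routes. A minimal sketch, assuming some concrete object implementing the PostTraining protocol is available as client; the run_dpo helper itself is hypothetical, but it uses only methods defined on the protocol above:

# Hypothetical helper: submit a DPO job, then check on it via the
# shared job-management endpoints of the PostTraining protocol.
def run_dpo(
    client: PostTraining, request: PostTrainingRLHFRequest
) -> PostTrainingJobStatusResponse:
    client.post_preference_optimize(request=request)
    return client.get_training_job_status(job_uuid=request.job_uuid)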
@@ -487,7 +520,7 @@ class LlamaStackEndpoints(
     RewardScoring,
     SyntheticDataGeneration,
     Datasets,
-    Finetuning,
+    PostTraining,
     MemoryBanks,
 ): ...