From 631328f5560c2e0fa84a01f97a8e5a542187a492 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 11 Jul 2024 00:01:58 -0700 Subject: [PATCH] added DPO --- source/api_definitions.py | 91 ++- source/openapi.html | 507 +++++++++----- source/openapi.yaml | 651 ++++++++++-------- ...tuning_types.py => post_training_types.py} | 19 +- 4 files changed, 796 insertions(+), 472 deletions(-) rename source/{finetuning_types.py => post_training_types.py} (86%) diff --git a/source/api_definitions.py b/source/api_definitions.py index b6283c83b..e13e21873 100644 --- a/source/api_definitions.py +++ b/source/api_definitions.py @@ -12,19 +12,6 @@ from agentic_system_types import ( SafetyViolation, ) -from finetuning_types import ( - Checkpoint, - Dataset, - DoraFinetuningConfig, - FinetuningAlgorithm, - FinetuningJobLogStream, - FinetuningJobStatus, - LoraFinetuningConfig, - OptimizerConfig, - QLoraFinetuningConfig, - TrainingConfig, -) - from model_types import ( BuiltinTool, Content, @@ -42,6 +29,21 @@ from model_types import ( URL, ) +from post_training_types import ( + Checkpoint, + Dataset, + DoraFinetuningConfig, + DPOAlignmentConfig, + FinetuningAlgorithm, + LoraFinetuningConfig, + OptimizerConfig, + PostTrainingJobLogStream, + PostTrainingJobStatus, + QLoraFinetuningConfig, + RLHFAlgorithm, + TrainingConfig, +) + from pyopenapi import Info, Options, Server, Specification, webmethod from strong_typing.schema import json_schema_type @@ -408,7 +410,7 @@ class Datasets(Protocol): @json_schema_type @dataclass -class FinetuningTrainRequest: +class PostTrainingSFTRequest: """Request to finetune a model.""" job_uuid: str @@ -432,11 +434,34 @@ class FinetuningTrainRequest: @json_schema_type @dataclass -class FinetuningJobStatusResponse: +class PostTrainingRLHFRequest: + """Request to finetune a model.""" + + job_uuid: str + + finetuned_model: URL + + dataset: Dataset + validation_dataset: Dataset + + algorithm: RLHFAlgorithm + algorithm_config: Union[DPOAlignmentConfig] + + optimizer_config: OptimizerConfig + training_config: TrainingConfig + + # TODO: define these + hyperparam_search_config: Dict[str, Any] + logger_config: Dict[str, Any] + + +@json_schema_type +@dataclass +class PostTrainingJobStatusResponse: """Status of a finetuning job.""" job_uuid: str - status: FinetuningJobStatus + status: PostTrainingJobStatus scheduled_at: Optional[datetime] = None started_at: Optional[datetime] = None @@ -449,7 +474,7 @@ class FinetuningJobStatusResponse: @json_schema_type @dataclass -class FinetuningJobArtifactsResponse: +class PostTrainingJobArtifactsResponse: """Artifacts of a finetuning job.""" job_uuid: str @@ -458,27 +483,35 @@ class FinetuningJobArtifactsResponse: # TODO(ashwin): metrics, evals -class Finetuning(Protocol): - @webmethod(route="/finetuning/text_generation/train") - def post_train( +class PostTraining(Protocol): + @webmethod(route="/post_training/supervised_fine_tune/") + def post_supervised_fine_tune( self, - request: FinetuningTrainRequest, + request: PostTrainingSFTRequest, + ) -> None: ... + + @webmethod(route="/post_training/preference_optimize/") + def post_preference_optimize( + self, + request: PostTrainingRLHFRequest, ) -> None: ... # sends SSE stream of logs - @webmethod(route="/finetuning/job/logs") - def get_training_log_stream(self, job_uuid: str) -> FinetuningJobLogStream: ... + @webmethod(route="/post_training/job/logs") + def get_training_log_stream(self, job_uuid: str) -> PostTrainingJobLogStream: ... 
- @webmethod(route="/finetuning/job/status") - def get_training_job_status(self, job_uuid: str) -> FinetuningJobStatusResponse: ... + @webmethod(route="/post_training/job/status") + def get_training_job_status( + self, job_uuid: str + ) -> PostTrainingJobStatusResponse: ... - @webmethod(route="/finetuning/job/cancel") + @webmethod(route="/post_training/job/cancel") def cancel_training_job(self, job_uuid: str) -> None: ... - @webmethod(route="/finetuning/job/artifacts") + @webmethod(route="/post_training/job/artifacts") def get_training_job_artifacts( self, job_uuid: str - ) -> FinetuningJobArtifactsResponse: ... + ) -> PostTrainingJobArtifactsResponse: ... class LlamaStackEndpoints( @@ -487,7 +520,7 @@ class LlamaStackEndpoints( RewardScoring, SyntheticDataGeneration, Datasets, - Finetuning, + PostTraining, MemoryBanks, ): ... diff --git a/source/openapi.html b/source/openapi.html index b61378bad..413e73a21 100644 --- a/source/openapi.html +++ b/source/openapi.html @@ -299,7 +299,7 @@ "parameters": [] } }, - "/finetuning/job/artifacts": { + "/post_training/job/artifacts": { "get": { "responses": { "200": { @@ -307,14 +307,14 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/FinetuningJobArtifactsResponse" + "$ref": "#/components/schemas/PostTrainingJobArtifactsResponse" } } } } }, "tags": [ - "Finetuning" + "PostTraining" ], "parameters": [ { @@ -328,7 +328,7 @@ ] } }, - "/finetuning/job/status": { + "/post_training/job/status": { "get": { "responses": { "200": { @@ -336,14 +336,14 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/FinetuningJobStatusResponse" + "$ref": "#/components/schemas/PostTrainingJobStatusResponse" } } } } }, "tags": [ - "Finetuning" + "PostTraining" ], "parameters": [ { @@ -357,7 +357,7 @@ ] } }, - "/finetuning/job/logs": { + "/post_training/job/logs": { "get": { "responses": { "200": { @@ -365,14 +365,14 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/FinetuningJobLogStream" + "$ref": "#/components/schemas/PostTrainingJobLogStream" } } } } }, "tags": [ - "Finetuning" + "PostTraining" ], "parameters": [ { @@ -664,6 +664,29 @@ } } }, + "/post_training/preference_optimize/": { + "post": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "PostTraining" + ], + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PostTrainingRLHFRequest" + } + } + }, + "required": true + } + } + }, "/reward_scoring/score": { "post": { "responses": { @@ -694,7 +717,7 @@ } } }, - "/finetuning/text_generation/train": { + "/post_training/supervised_fine_tune/": { "post": { "responses": { "200": { @@ -702,14 +725,14 @@ } }, "tags": [ - "Finetuning" + "PostTraining" ], "parameters": [], "requestBody": { "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/FinetuningTrainRequest" + "$ref": "#/components/schemas/PostTrainingSFTRequest" } } }, @@ -1697,7 +1720,7 @@ "name" ] }, - "FinetuningJobArtifactsResponse": { + "PostTrainingJobArtifactsResponse": { "type": "object", "properties": { "job_uuid": { @@ -1730,7 +1753,7 @@ ], "title": "Artifacts of a finetuning job." }, - "FinetuningJobStatusResponse": { + "PostTrainingJobStatusResponse": { "type": "object", "properties": { "job_uuid": { @@ -1810,7 +1833,7 @@ ], "title": "Status of a finetuning job." 
}, - "FinetuningJobLogStream": { + "PostTrainingJobLogStream": { "type": "object", "properties": { "job_uuid": { @@ -2672,6 +2695,191 @@ ], "title": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold." }, + "DPOAlignmentConfig": { + "type": "object", + "properties": { + "reward_scale": { + "type": "number" + }, + "reward_clip": { + "type": "number" + }, + "epsilon": { + "type": "number" + }, + "gamma": { + "type": "number" + } + }, + "additionalProperties": false, + "required": [ + "reward_scale", + "reward_clip", + "epsilon", + "gamma" + ] + }, + "OptimizerConfig": { + "type": "object", + "properties": { + "optimizer_type": { + "type": "string", + "enum": [ + "adam", + "adamw", + "sgd" + ] + }, + "lr": { + "type": "number" + }, + "lr_min": { + "type": "number" + }, + "weight_decay": { + "type": "number" + } + }, + "additionalProperties": false, + "required": [ + "optimizer_type", + "lr", + "lr_min", + "weight_decay" + ] + }, + "PostTrainingRLHFRequest": { + "type": "object", + "properties": { + "job_uuid": { + "type": "string" + }, + "finetuned_model": { + "$ref": "#/components/schemas/URL" + }, + "dataset": { + "$ref": "#/components/schemas/Dataset" + }, + "validation_dataset": { + "$ref": "#/components/schemas/Dataset" + }, + "algorithm": { + "type": "string", + "enum": [ + "dpo" + ] + }, + "algorithm_config": { + "$ref": "#/components/schemas/DPOAlignmentConfig" + }, + "optimizer_config": { + "$ref": "#/components/schemas/OptimizerConfig" + }, + "training_config": { + "$ref": "#/components/schemas/TrainingConfig" + }, + "hyperparam_search_config": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "logger_config": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "job_uuid", + "finetuned_model", + "dataset", + "validation_dataset", + "algorithm", + "algorithm_config", + "optimizer_config", + "training_config", + "hyperparam_search_config", + "logger_config" + ], + "title": "Request to finetune a model." + }, + "TrainingConfig": { + "type": "object", + "properties": { + "n_epochs": { + "type": "integer" + }, + "batch_size": { + "type": "integer" + }, + "shuffle": { + "type": "boolean" + }, + "n_iters": { + "type": "integer" + }, + "enable_activation_checkpointing": { + "type": "boolean" + }, + "memory_efficient_fsdp_wrap": { + "type": "boolean" + }, + "fsdp_cpu_offload": { + "type": "boolean" + } + }, + "additionalProperties": false, + "required": [ + "n_epochs", + "batch_size", + "shuffle", + "n_iters", + "enable_activation_checkpointing", + "memory_efficient_fsdp_wrap", + "fsdp_cpu_offload" + ] + }, "RewardScoringRequest": { "type": "object", "properties": { @@ -2727,7 +2935,69 @@ ], "title": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold." 
}, - "FinetuningTrainRequest": { + "DoraFinetuningConfig": { + "type": "object", + "properties": { + "lora_attn_modules": { + "type": "array", + "items": { + "type": "string" + } + }, + "apply_lora_to_mlp": { + "type": "boolean" + }, + "apply_lora_to_output": { + "type": "boolean" + }, + "rank": { + "type": "integer" + }, + "alpha": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "lora_attn_modules", + "apply_lora_to_mlp", + "apply_lora_to_output", + "rank", + "alpha" + ] + }, + "LoraFinetuningConfig": { + "type": "object", + "properties": { + "lora_attn_modules": { + "type": "array", + "items": { + "type": "string" + } + }, + "apply_lora_to_mlp": { + "type": "boolean" + }, + "apply_lora_to_output": { + "type": "boolean" + }, + "rank": { + "type": "integer" + }, + "alpha": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "lora_attn_modules", + "apply_lora_to_mlp", + "apply_lora_to_output", + "rank", + "alpha" + ] + }, + "PostTrainingSFTRequest": { "type": "object", "properties": { "job_uuid": { @@ -2761,66 +3031,10 @@ "$ref": "#/components/schemas/LoraFinetuningConfig" }, { - "type": "object", - "properties": { - "lora_attn_modules": { - "type": "array", - "items": { - "type": "string" - } - }, - "apply_lora_to_mlp": { - "type": "boolean" - }, - "apply_lora_to_output": { - "type": "boolean" - }, - "rank": { - "type": "integer" - }, - "alpha": { - "type": "integer" - } - }, - "additionalProperties": false, - "required": [ - "lora_attn_modules", - "apply_lora_to_mlp", - "apply_lora_to_output", - "rank", - "alpha" - ] + "$ref": "#/components/schemas/QLoraFinetuningConfig" }, { - "type": "object", - "properties": { - "lora_attn_modules": { - "type": "array", - "items": { - "type": "string" - } - }, - "apply_lora_to_mlp": { - "type": "boolean" - }, - "apply_lora_to_output": { - "type": "boolean" - }, - "rank": { - "type": "integer" - }, - "alpha": { - "type": "integer" - } - }, - "additionalProperties": false, - "required": [ - "lora_attn_modules", - "apply_lora_to_mlp", - "apply_lora_to_output", - "rank", - "alpha" - ] + "$ref": "#/components/schemas/DoraFinetuningConfig" } ] }, @@ -2896,7 +3110,7 @@ ], "title": "Request to finetune a model." 
}, - "LoraFinetuningConfig": { + "QLoraFinetuningConfig": { "type": "object", "properties": { "lora_attn_modules": { @@ -2926,71 +3140,6 @@ "rank", "alpha" ] - }, - "OptimizerConfig": { - "type": "object", - "properties": { - "optimizer_type": { - "type": "string", - "enum": [ - "adam", - "adamw", - "sgd" - ] - }, - "lr": { - "type": "number" - }, - "lr_min": { - "type": "number" - }, - "weight_decay": { - "type": "number" - } - }, - "additionalProperties": false, - "required": [ - "optimizer_type", - "lr", - "lr_min", - "weight_decay" - ] - }, - "TrainingConfig": { - "type": "object", - "properties": { - "n_epochs": { - "type": "integer" - }, - "batch_size": { - "type": "integer" - }, - "shuffle": { - "type": "boolean" - }, - "n_iters": { - "type": "integer" - }, - "enable_activation_checkpointing": { - "type": "boolean" - }, - "memory_efficient_fsdp_wrap": { - "type": "boolean" - }, - "fsdp_cpu_offload": { - "type": "boolean" - } - }, - "additionalProperties": false, - "required": [ - "n_epochs", - "batch_size", - "shuffle", - "n_iters", - "enable_activation_checkpointing", - "memory_efficient_fsdp_wrap", - "fsdp_cpu_offload" - ] } }, "responses": {} @@ -3001,27 +3150,27 @@ } ], "tags": [ - { - "name": "RewardScoring" - }, - { - "name": "MemoryBanks" - }, - { - "name": "SyntheticDataGeneration" - }, - { - "name": "Finetuning" - }, { "name": "AgenticSystem" }, + { + "name": "RewardScoring" + }, { "name": "Inference" }, + { + "name": "SyntheticDataGeneration" + }, { "name": "Datasets" }, + { + "name": "PostTraining" + }, + { + "name": "MemoryBanks" + }, { "name": "ShieldConfig", "description": "" @@ -3075,16 +3224,16 @@ "description": "" }, { - "name": "FinetuningJobArtifactsResponse", - "description": "Artifacts of a finetuning job.\n\n" + "name": "PostTrainingJobArtifactsResponse", + "description": "Artifacts of a finetuning job.\n\n" }, { - "name": "FinetuningJobStatusResponse", - "description": "Status of a finetuning job.\n\n" + "name": "PostTrainingJobStatusResponse", + "description": "Status of a finetuning job.\n\n" }, { - "name": "FinetuningJobLogStream", - "description": "Stream of logs from a finetuning job.\n\n" + "name": "PostTrainingJobLogStream", + "description": "Stream of logs from a finetuning job.\n\n" }, { "name": "BatchChatCompletionRequest", @@ -3138,6 +3287,22 @@ "name": "SyntheticDataGenerationResponse", "description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.\n\n" }, + { + "name": "DPOAlignmentConfig", + "description": "" + }, + { + "name": "OptimizerConfig", + "description": "" + }, + { + "name": "PostTrainingRLHFRequest", + "description": "Request to finetune a model.\n\n" + }, + { + "name": "TrainingConfig", + "description": "" + }, { "name": "RewardScoringRequest", "description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n" @@ -3147,20 +3312,20 @@ "description": "Response from the reward scoring. 
Batch of (prompt, response, score) tuples that pass the threshold.\n\n" }, { - "name": "FinetuningTrainRequest", - "description": "Request to finetune a model.\n\n" + "name": "DoraFinetuningConfig", + "description": "" }, { "name": "LoraFinetuningConfig", "description": "" }, { - "name": "OptimizerConfig", - "description": "" + "name": "PostTrainingSFTRequest", + "description": "Request to finetune a model.\n\n" }, { - "name": "TrainingConfig", - "description": "" + "name": "QLoraFinetuningConfig", + "description": "" } ], "x-tagGroups": [ @@ -3169,9 +3334,9 @@ "tags": [ "AgenticSystem", "Datasets", - "Finetuning", "Inference", "MemoryBanks", + "PostTraining", "RewardScoring", "SyntheticDataGeneration" ] @@ -3195,18 +3360,22 @@ "CompletionResponse", "CompletionResponseStreamChunk", "CreateDatasetRequest", + "DPOAlignmentConfig", "Dataset", "Dialog", - "FinetuningJobArtifactsResponse", - "FinetuningJobLogStream", - "FinetuningJobStatusResponse", - "FinetuningTrainRequest", + "DoraFinetuningConfig", "KScoredPromptGenerations", "LoraFinetuningConfig", "MemoryBank", "Message", "MessageScore", "OptimizerConfig", + "PostTrainingJobArtifactsResponse", + "PostTrainingJobLogStream", + "PostTrainingJobStatusResponse", + "PostTrainingRLHFRequest", + "PostTrainingSFTRequest", + "QLoraFinetuningConfig", "RewardScoringRequest", "RewardScoringResponse", "ShieldConfig", diff --git a/source/openapi.yaml b/source/openapi.yaml index da53c4a56..b8d3b0285 100644 --- a/source/openapi.yaml +++ b/source/openapi.yaml @@ -879,6 +879,23 @@ components: - dataset title: Request to create a dataset. type: object + DPOAlignmentConfig: + additionalProperties: false + properties: + epsilon: + type: number + gamma: + type: number + reward_clip: + type: number + reward_scale: + type: number + required: + - reward_scale + - reward_clip + - epsilon + - gamma + type: object Dataset: additionalProperties: false properties: @@ -923,195 +940,27 @@ components: - message - message_history type: object - FinetuningJobArtifactsResponse: + DoraFinetuningConfig: additionalProperties: false properties: - checkpoints: - items: - additionalProperties: false - properties: - iters: - type: integer - path: - $ref: '#/components/schemas/URL' - required: - - iters - - path - type: object - type: array - job_uuid: - type: string - required: - - job_uuid - - checkpoints - title: Artifacts of a finetuning job. - type: object - FinetuningJobLogStream: - additionalProperties: false - properties: - job_uuid: - type: string - log_lines: + alpha: + type: integer + apply_lora_to_mlp: + type: boolean + apply_lora_to_output: + type: boolean + lora_attn_modules: items: type: string type: array + rank: + type: integer required: - - job_uuid - - log_lines - title: Stream of logs from a finetuning job. 
- type: object - FinetuningJobStatusResponse: - additionalProperties: false - properties: - checkpoints: - items: - additionalProperties: false - properties: - iters: - type: integer - path: - $ref: '#/components/schemas/URL' - required: - - iters - - path - type: object - type: array - completed_at: - format: date-time - type: string - job_uuid: - type: string - resources_allocated: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - scheduled_at: - format: date-time - type: string - started_at: - format: date-time - type: string - status: - enum: - - running - - completed - - failed - - scheduled - type: string - required: - - job_uuid - - status - - checkpoints - title: Status of a finetuning job. - type: object - FinetuningTrainRequest: - additionalProperties: false - properties: - algorithm: - enum: - - full - - lora - - qlora - - dora - type: string - algorithm_config: - oneOf: - - $ref: '#/components/schemas/LoraFinetuningConfig' - - additionalProperties: false - properties: - alpha: - type: integer - apply_lora_to_mlp: - type: boolean - apply_lora_to_output: - type: boolean - lora_attn_modules: - items: - type: string - type: array - rank: - type: integer - required: - - lora_attn_modules - - apply_lora_to_mlp - - apply_lora_to_output - - rank - - alpha - type: object - - additionalProperties: false - properties: - alpha: - type: integer - apply_lora_to_mlp: - type: boolean - apply_lora_to_output: - type: boolean - lora_attn_modules: - items: - type: string - type: array - rank: - type: integer - required: - - lora_attn_modules - - apply_lora_to_mlp - - apply_lora_to_output - - rank - - alpha - type: object - dataset: - $ref: '#/components/schemas/Dataset' - hyperparam_search_config: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - job_uuid: - type: string - logger_config: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - model: - enum: - - llama3_8b - - llama3_70b - type: string - optimizer_config: - $ref: '#/components/schemas/OptimizerConfig' - training_config: - $ref: '#/components/schemas/TrainingConfig' - validation_dataset: - $ref: '#/components/schemas/Dataset' - required: - - job_uuid - - model - - dataset - - validation_dataset - - algorithm - - algorithm_config - - optimizer_config - - training_config - - hyperparam_search_config - - logger_config - title: Request to finetune a model. + - lora_attn_modules + - apply_lora_to_mlp + - apply_lora_to_output + - rank + - alpha type: object KScoredPromptGenerations: additionalProperties: false @@ -1259,6 +1108,232 @@ components: - lr_min - weight_decay type: object + PostTrainingJobArtifactsResponse: + additionalProperties: false + properties: + checkpoints: + items: + additionalProperties: false + properties: + iters: + type: integer + path: + $ref: '#/components/schemas/URL' + required: + - iters + - path + type: object + type: array + job_uuid: + type: string + required: + - job_uuid + - checkpoints + title: Artifacts of a finetuning job. + type: object + PostTrainingJobLogStream: + additionalProperties: false + properties: + job_uuid: + type: string + log_lines: + items: + type: string + type: array + required: + - job_uuid + - log_lines + title: Stream of logs from a finetuning job. 
+ type: object + PostTrainingJobStatusResponse: + additionalProperties: false + properties: + checkpoints: + items: + additionalProperties: false + properties: + iters: + type: integer + path: + $ref: '#/components/schemas/URL' + required: + - iters + - path + type: object + type: array + completed_at: + format: date-time + type: string + job_uuid: + type: string + resources_allocated: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + scheduled_at: + format: date-time + type: string + started_at: + format: date-time + type: string + status: + enum: + - running + - completed + - failed + - scheduled + type: string + required: + - job_uuid + - status + - checkpoints + title: Status of a finetuning job. + type: object + PostTrainingRLHFRequest: + additionalProperties: false + properties: + algorithm: + enum: + - dpo + type: string + algorithm_config: + $ref: '#/components/schemas/DPOAlignmentConfig' + dataset: + $ref: '#/components/schemas/Dataset' + finetuned_model: + $ref: '#/components/schemas/URL' + hyperparam_search_config: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + job_uuid: + type: string + logger_config: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + optimizer_config: + $ref: '#/components/schemas/OptimizerConfig' + training_config: + $ref: '#/components/schemas/TrainingConfig' + validation_dataset: + $ref: '#/components/schemas/Dataset' + required: + - job_uuid + - finetuned_model + - dataset + - validation_dataset + - algorithm + - algorithm_config + - optimizer_config + - training_config + - hyperparam_search_config + - logger_config + title: Request to finetune a model. + type: object + PostTrainingSFTRequest: + additionalProperties: false + properties: + algorithm: + enum: + - full + - lora + - qlora + - dora + type: string + algorithm_config: + oneOf: + - $ref: '#/components/schemas/LoraFinetuningConfig' + - $ref: '#/components/schemas/QLoraFinetuningConfig' + - $ref: '#/components/schemas/DoraFinetuningConfig' + dataset: + $ref: '#/components/schemas/Dataset' + hyperparam_search_config: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + job_uuid: + type: string + logger_config: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + model: + enum: + - llama3_8b + - llama3_70b + type: string + optimizer_config: + $ref: '#/components/schemas/OptimizerConfig' + training_config: + $ref: '#/components/schemas/TrainingConfig' + validation_dataset: + $ref: '#/components/schemas/Dataset' + required: + - job_uuid + - model + - dataset + - validation_dataset + - algorithm + - algorithm_config + - optimizer_config + - training_config + - hyperparam_search_config + - logger_config + title: Request to finetune a model. 
+ type: object + QLoraFinetuningConfig: + additionalProperties: false + properties: + alpha: + type: integer + apply_lora_to_mlp: + type: boolean + apply_lora_to_output: + type: boolean + lora_attn_modules: + items: + type: string + type: array + rank: + type: integer + required: + - lora_attn_modules + - apply_lora_to_mlp + - apply_lora_to_output + - rank + - alpha + type: object RewardScoringRequest: additionalProperties: false properties: @@ -1581,71 +1656,6 @@ paths: description: OK tags: - Datasets - /finetuning/job/artifacts: - get: - parameters: - - in: query - name: job_uuid - required: true - schema: - type: string - responses: - '200': - content: - application/json: - schema: - $ref: '#/components/schemas/FinetuningJobArtifactsResponse' - description: OK - tags: - - Finetuning - /finetuning/job/logs: - get: - parameters: - - in: query - name: job_uuid - required: true - schema: - type: string - responses: - '200': - content: - application/json: - schema: - $ref: '#/components/schemas/FinetuningJobLogStream' - description: OK - tags: - - Finetuning - /finetuning/job/status: - get: - parameters: - - in: query - name: job_uuid - required: true - schema: - type: string - responses: - '200': - content: - application/json: - schema: - $ref: '#/components/schemas/FinetuningJobStatusResponse' - description: OK - tags: - - Finetuning - /finetuning/text_generation/train: - post: - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/FinetuningTrainRequest' - required: true - responses: - '200': - description: OK - tags: - - Finetuning /memory_banks/create: post: parameters: @@ -1787,6 +1797,85 @@ paths: description: OK tags: - MemoryBanks + /post_training/job/artifacts: + get: + parameters: + - in: query + name: job_uuid + required: true + schema: + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/PostTrainingJobArtifactsResponse' + description: OK + tags: + - PostTraining + /post_training/job/logs: + get: + parameters: + - in: query + name: job_uuid + required: true + schema: + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/PostTrainingJobLogStream' + description: OK + tags: + - PostTraining + /post_training/job/status: + get: + parameters: + - in: query + name: job_uuid + required: true + schema: + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/PostTrainingJobStatusResponse' + description: OK + tags: + - PostTraining + /post_training/preference_optimize/: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PostTrainingRLHFRequest' + required: true + responses: + '200': + description: OK + tags: + - PostTraining + /post_training/supervised_fine_tune/: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PostTrainingSFTRequest' + required: true + responses: + '200': + description: OK + tags: + - PostTraining /reward_scoring/score: post: parameters: [] @@ -1828,13 +1917,13 @@ security: servers: - url: http://llama.meta.com tags: -- name: RewardScoring -- name: MemoryBanks -- name: SyntheticDataGeneration -- name: Finetuning - name: AgenticSystem +- name: RewardScoring - name: Inference +- name: SyntheticDataGeneration - name: Datasets +- name: PostTraining +- name: MemoryBanks - description: name: ShieldConfig - description: ' - name: 
FinetuningJobArtifactsResponse + name: PostTrainingJobArtifactsResponse - description: 'Status of a finetuning job. - ' - name: FinetuningJobStatusResponse + name: PostTrainingJobStatusResponse - description: 'Stream of logs from a finetuning job. - ' - name: FinetuningJobLogStream + ' + name: PostTrainingJobLogStream - description: name: BatchChatCompletionRequest @@ -1961,6 +2050,19 @@ tags: ' name: SyntheticDataGenerationResponse +- description: + name: DPOAlignmentConfig +- description: + name: OptimizerConfig +- description: 'Request to finetune a model. + + + ' + name: PostTrainingRLHFRequest +- description: + name: TrainingConfig - description: 'Request to score a reward function. A list of prompts and a list of responses per prompt. @@ -1973,27 +2075,28 @@ tags: ' name: RewardScoringResponse -- description: 'Request to finetune a model. - - - ' - name: FinetuningTrainRequest +- description: + name: DoraFinetuningConfig - description: name: LoraFinetuningConfig -- description: ' + name: PostTrainingSFTRequest +- description: - name: OptimizerConfig -- description: - name: TrainingConfig + name: QLoraFinetuningConfig x-tagGroups: - name: Operations tags: - AgenticSystem - Datasets - - Finetuning - Inference - MemoryBanks + - PostTraining - RewardScoring - SyntheticDataGeneration - name: Types @@ -2014,18 +2117,22 @@ x-tagGroups: - CompletionResponse - CompletionResponseStreamChunk - CreateDatasetRequest + - DPOAlignmentConfig - Dataset - Dialog - - FinetuningJobArtifactsResponse - - FinetuningJobLogStream - - FinetuningJobStatusResponse - - FinetuningTrainRequest + - DoraFinetuningConfig - KScoredPromptGenerations - LoraFinetuningConfig - MemoryBank - Message - MessageScore - OptimizerConfig + - PostTrainingJobArtifactsResponse + - PostTrainingJobLogStream + - PostTrainingJobStatusResponse + - PostTrainingRLHFRequest + - PostTrainingSFTRequest + - QLoraFinetuningConfig - RewardScoringRequest - RewardScoringResponse - ShieldConfig diff --git a/source/finetuning_types.py b/source/post_training_types.py similarity index 86% rename from source/finetuning_types.py rename to source/post_training_types.py index 9aa897eac..f67fce4d8 100644 --- a/source/finetuning_types.py +++ b/source/post_training_types.py @@ -72,11 +72,13 @@ class LoraFinetuningConfig: alpha: int +@json_schema_type @dataclass class QLoraFinetuningConfig(LoraFinetuningConfig): pass +@json_schema_type @dataclass class DoraFinetuningConfig(LoraFinetuningConfig): pass @@ -84,14 +86,14 @@ class DoraFinetuningConfig(LoraFinetuningConfig): @json_schema_type @dataclass -class FinetuningJobLogStream: +class PostTrainingJobLogStream: """Stream of logs from a finetuning job.""" job_uuid: str log_lines: List[str] -class FinetuningJobStatus(Enum): +class PostTrainingJobStatus(Enum): running = "running" completed = "completed" failed = "failed" @@ -102,3 +104,16 @@ class FinetuningJobStatus(Enum): class Checkpoint: iters: int path: URL + + +class RLHFAlgorithm(Enum): + dpo = "dpo" + + +@json_schema_type +@dataclass +class DPOAlignmentConfig: + reward_scale: float + reward_clip: float + epsilon: float + gamma: float
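

A minimal usage sketch for the new preference_optimize path, written against the Python definitions in this patch. It uses only names this diff introduces or renames (PostTraining, PostTrainingRLHFRequest, RLHFAlgorithm, DPOAlignmentConfig, OptimizerConfig, TrainingConfig, plus Dataset and URL from the existing modules). The client object, the dataset and model-URL arguments, and every hyperparameter value are illustrative assumptions: the patch defines the protocol and schemas, not a concrete client or recommended settings. The string passed for optimizer_type mirrors the enum in the generated schema (adam, adamw, sgd); the Python-side type of that field is not visible in this diff.

# Sketch: build and submit a DPO preference-optimization job via the new
# PostTraining protocol. Everything marked "assumed" is illustrative and is
# not defined by this patch.
from api_definitions import PostTraining, PostTrainingRLHFRequest
from model_types import URL
from post_training_types import (
    Dataset,
    DPOAlignmentConfig,
    OptimizerConfig,
    RLHFAlgorithm,
    TrainingConfig,
)


def build_dpo_request(
    train_ds: Dataset,  # assumed: an existing preference dataset
    val_ds: Dataset,    # assumed: an existing validation dataset
    model_url: URL,     # assumed: URL of the fine-tuned checkpoint to align
) -> PostTrainingRLHFRequest:
    return PostTrainingRLHFRequest(
        job_uuid="dpo-job-0001",
        finetuned_model=model_url,
        dataset=train_ds,
        validation_dataset=val_ds,
        algorithm=RLHFAlgorithm.dpo,
        # Fields match the new DPOAlignmentConfig schema.
        algorithm_config=DPOAlignmentConfig(
            reward_scale=1.0, reward_clip=5.0, epsilon=0.1, gamma=0.99
        ),
        # OptimizerConfig / TrainingConfig are shared with the SFT path;
        # the generated schema enumerates optimizer_type as adam | adamw | sgd.
        optimizer_config=OptimizerConfig(
            optimizer_type="adamw", lr=1e-5, lr_min=1e-6, weight_decay=0.01
        ),
        training_config=TrainingConfig(
            n_epochs=1,
            batch_size=8,
            shuffle=True,
            n_iters=1000,
            enable_activation_checkpointing=True,
            memory_efficient_fsdp_wrap=True,
            fsdp_cpu_offload=False,
        ),
        hyperparam_search_config={},  # still marked TODO in the API definition
        logger_config={},
    )


def launch_dpo(client: PostTraining, request: PostTrainingRLHFRequest) -> None:
    # Maps to POST /post_training/preference_optimize/ in the generated spec.
    client.post_preference_optimize(request)
    # Progress is then polled with GET /post_training/job/status (by job_uuid),
    # which returns a PostTrainingJobStatusResponse; logs stream from
    # GET /post_training/job/logs as a PostTrainingJobLogStream.

On the SFT side the request shape is unchanged apart from the rename to PostTrainingSFTRequest; the notable spec-level change is that QLoraFinetuningConfig and DoraFinetuningConfig are now emitted as named schemas (via json_schema_type) and referenced from the algorithm_config oneOf instead of being inlined, so generated clients get distinct types for the three LoRA variants.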