fine tuning apis

Hardik Shah 2024-06-26 20:37:22 -07:00
parent 157e5ddf2e
commit 2478b76fbc
2 changed files with 400 additions and 0 deletions

fine_tuning.yaml (new file, 266 lines)

@@ -0,0 +1,266 @@
openapi: 3.0.0
info:
  title: Fine Tuning API
  version: 1.0.0
  description: API for managing fine tuning jobs for machine learning models.
paths:
  /fine_tuning/jobs/submit:
    post:
      summary: Submit a fine tuning job
      description: Submit a fine tuning job with the specified configuration.
      requestBody:
        required: true
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/Config'
      responses:
        '200':
          description: Successfully submitted the fine tuning job.
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/FineTuningJob'
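    # Illustrative request body for this endpoint, sketched as a comment so the
    # spec stays valid; the model name and file paths are hypothetical:
    #   {
    #     "model": "llama3-8b",
    #     "data": "file:///datasets/train.jsonl",
    #     "validation_data": "file:///datasets/valid.jsonl",
    #     "fine_tuning_options": {"n_epochs": 3, "batch_size": 8, "lr": 2e-5,
    #                             "seed": 42, "shuffle": true},
    #     "logger": {"log_type": "disk", "filename": "train.log"}
    #   }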
  /fine_tuning/jobs/status:
    get:
      summary: Get the status of a fine tuning job
      description: Retrieve the status of the fine tuning job with the specified job ID.
      parameters:
        - in: query
          name: job_id
          schema:
            type: string
          required: true
          description: The ID of the job to retrieve status for.
      responses:
        '200':
          description: Successfully retrieved the job status.
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/FineTuningJob'
  /fine_tuning/jobs/cancel:
    post:
      summary: Cancel a fine tuning job
      description: Cancel the fine tuning job with the specified job ID.
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                job_id:
                  type: string
      responses:
        '200':
          description: Successfully cancelled the fine tuning job.
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/FineTuningJob'
  /fine_tuning/jobs/tail:
    get:
      summary: Tail logs of a particular job
      description: Stream the logs of a particular job in real time. This endpoint supports streaming responses.
      parameters:
        - in: query
          name: job_id
          schema:
            type: string
          required: true
          description: The ID of the job to tail logs for.
      responses:
        '200':
          description: Streaming logs in real time.
          content:
            application/x-ndjson:
              schema:
                type: object
                properties:
                  logs:
                    type: array
                    items:
                      $ref: '#/components/schemas/Log'
          headers:
            Content-Type:
              schema:
                type: string
                default: 'application/x-ndjson'
            Transfer-Encoding:
              schema:
                type: string
                default: 'chunked'
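    # Illustrative shape of the NDJSON stream, sketched as a comment; the log
    # messages and timestamps are hypothetical. Each chunk is one JSON object
    # per line:
    #   {"logs": [{"message": "epoch 0 | step 10 | loss 1.93", "timestamp": "2024-06-27T03:40:01Z"}]}
    #   {"logs": [{"message": "epoch 0 | step 20 | loss 1.71", "timestamp": "2024-06-27T03:40:09Z"}]}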
components:
  schemas:
    Message:
      # keep in sync with /chat_completion
      type: object
    TrainingDataItem:
      type: object
      properties:
        dialog:
          type: array
          items:
            $ref: '#/components/schemas/Message'
        keep_loss:
          type: array
          items:
            type: boolean
    WandBLogger:
      type: object
      properties:
        project:
          type: string
          description: The project name in WandB where logs will be stored.
    DiskLogger:
      type: object
      properties:
        filename:
          type: string
          description: The filename where logs will be stored on disk.
    FullFineTuneOptions:
      type: object
      properties:
        enable_activation_checkpointing:
          type: boolean
          default: true
        memory_efficient_fsdp_wrap:
          type: boolean
          default: true
        fsdp_cpu_offload:
          type: boolean
          default: true
    LoraFineTuneOptions:
      type: object
      properties:
        lora_attn_modules:
          type: array
          items:
            type: string
        apply_lora_to_mlp:
          type: boolean
          default: false
        apply_lora_to_output:
          type: boolean
          default: false
        lora_rank:
          type: integer
        lora_alpha:
          type: integer
    FineTuningOptions:
      type: object
      properties:
        n_epochs:
          type: integer
        batch_size:
          type: integer
        lr:
          type: number
          format: float
        gradient_accumulation_steps:
          type: integer
        seed:
          type: integer
        shuffle:
          type: boolean
        custom_training_options:
          oneOf:
            - $ref: '#/components/schemas/FullFineTuneOptions'
            - $ref: '#/components/schemas/LoraFineTuneOptions'
          discriminator:
            propertyName: fine_tuning_type
        extras:
          # JSON object for other config overrides that are required by torchtune
          type: object
          additionalProperties: true
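    # Illustrative payloads for the custom_training_options union, sketched as
    # comments; the values are hypothetical. The fine_tuning_type discriminator
    # selects the variant:
    #   {"fine_tuning_type": "fft", "enable_activation_checkpointing": true}
    #   {"fine_tuning_type": "lora", "lora_attn_modules": ["q_proj", "v_proj"],
    #    "lora_rank": 8, "lora_alpha": 16}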
    Config:
      type: object
      properties:
        model:
          type: string
          description: The model identifier that you want to fine tune.
        data:
          type: string
          format: uri
          description: Path to the JSONL file with each row representing a TrainingDataItem.
        validation_data:
          type: string
          format: uri
          description: Path to the JSONL file used for validation metrics.
        fine_tuning_options:
          $ref: '#/components/schemas/FineTuningOptions'
        logger:
          oneOf:
            - $ref: '#/components/schemas/DiskLogger'
            - $ref: '#/components/schemas/WandBLogger'
          discriminator:
            propertyName: log_type
        overrides:
          # e.g. --nproc_per_node 4 instead of the default, which we need to pass
          # through to torchrun when running locally
          type: string
          description: Custom override options for the fine tuning process.
        metadata:
          type: object
          additionalProperties: true
    FineTuningJob:
      type: object
      properties:
        job_id:
          type: string
          description: Unique identifier for the fine tuning job.
        created:
          type: string
          format: date-time
          description: The creation date and time of the job.
        finished_at:
          type: string
          format: date-time
          description: The completion date and time of the job.
        status:
          type: string
          enum: [validation, queued, running, failed, success, cancelled]
          description: The current status of the job.
        error_path:
          type: string
          format: uri
          description: Path to the error log file.
        checkpoints:
          type: array
          items:
            type: string
            format: uri
          description: List of paths to checkpoint files for various epochs.
        logs:
          type: string
          format: uri
          description: Path to the logs, either local or a WandB URI.
        input_config:
          $ref: '#/components/schemas/Config'
        metadata:
          type: object
          additionalProperties: true
    Log:
      type: object
      properties:
        message:
          type: string
          description: The log message.
        timestamp:
          type: string
          format: date-time
          description: The timestamp of the log message.
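    # Illustrative FineTuningJob response, sketched as a comment; the job id,
    # timestamps, and paths are hypothetical:
    #   {
    #     "job_id": "ft-123",
    #     "created": "2024-06-27T03:37:22Z",
    #     "status": "running",
    #     "checkpoints": ["file:///checkpoints/ft-123/epoch_0.pt"],
    #     "logs": "file:///logs/ft-123.log"
    #   }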

simple_view/fine_tuning.yml (new file, 134 lines)

@@ -0,0 +1,134 @@
# Fine Tuning APIs

== Schema ==

TrainingDataItem:
  dialog: List[Message]
  keep_loss: List[bool]

WandBLogger:
  project: str

DiskLogger:
  # log_dir will be pre-configured in the environment
  filename: str

FullFineTuneOptions:
  enable_activation_checkpointing: True
  memory_efficient_fsdp_wrap: True
  fsdp_cpu_offload: True

LoraFineTuneOptions:
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

FineTuningOptions:
  n_epochs: int
  batch_size: int
  lr: float
  gradient_accumulation_steps: int
  seed: int
  shuffle: bool
  # Union type with a discriminator field that disambiguates the variants,
  # as in OpenAPI
  custom_training_options:
    discriminator:
      propertyName: fine_tuning_type
      mapping:
        fft: FullFineTuneOptions
        lora: LoraFineTuneOptions
  # other options that can be passed in
  extras: json

Config:
  model: str             # model that you want to fine tune
  data: Path             # JSONL with each row representing a TrainingDataItem
  validation_data: Path  # same as data, but used to compute validation metrics
  # fine tuning args
  fine_tuning_options: FineTuningOptions
  # metric logging
  logger:
    discriminator:
      propertyName: log_type
      mapping:
        disk: DiskLogger
        wandb: WandBLogger
  # Override options,
  # e.g. --nproc_per_node 4 instead of the defaults;
  # this might be impl specific and can allow for various customizations
  overrides: str
  metadata: json         # to carry over to job details

FineTuningJob:
  job_id: str
  created: str             # format: date-time
  finished_at: str         # format: date-time
  status: str              # enum - validation, queued, running, failed, success, cancelled
  error_path: Path         # error logging
  checkpoints: List[Path]  # checkpoints for various epochs
  logs: Path               # local path / WandB URI
  input_config: Config     # config used to submit this job
  metadata: json           # carried over from user provided input

Log:
  message: str    # The log message.
  timestamp: str  # format: date-time
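
# Illustrative Config instance in this notation, sketched as a comment; the
# model name and paths are hypothetical:
#   model: llama3-8b
#   data: /datasets/train.jsonl
#   validation_data: /datasets/valid.jsonl
#   fine_tuning_options:
#     n_epochs: 3
#     batch_size: 8
#     lr: 2e-5
#     custom_training_options:
#       fine_tuning_type: lora
#       lora_rank: 8
#       lora_alpha: 16
#   logger:
#     log_type: wandb
#     project: my-finetunes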
== Callsites ==

callsite:
  /fine_tuning/jobs/submit
request_type:
  post
description:
  Submit a fine tuning job
request:
  config: Config
response:
  fine_tuning_job: FineTuningJob

callsite:
  /fine_tuning/jobs/status
request_type:
  get
description:
  Get the status of a fine tuning job
request:
  job_id: str
response:
  fine_tuning_job: FineTuningJob

callsite:
  /fine_tuning/jobs/cancel
request_type:
  post
description:
  Cancel a fine tuning job
request:
  job_id: str
response:
  fine_tuning_job: FineTuningJob

callsite:
  /fine_tuning/jobs/tail
request_type:
  get
description:
  Tail logs of a particular job
request:
  job_id: str
response:
  logs: List[Log]
streaming:
  enabled: True
  chunkSize: 1024
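
# Illustrative end-to-end flow against these callsites, sketched as a comment;
# the job id and field values are hypothetical:
#   1. POST /fine_tuning/jobs/submit with a Config
#        -> FineTuningJob {job_id: "ft-123", status: "queued"}
#   2. GET  /fine_tuning/jobs/status?job_id=ft-123
#        -> FineTuningJob {status: "running", checkpoints: [...]}
#   3. GET  /fine_tuning/jobs/tail?job_id=ft-123
#        -> chunked NDJSON stream of Log objects until the job finishes
#   4. POST /fine_tuning/jobs/cancel with {"job_id": "ft-123"}
#        -> FineTuningJob {status: "cancelled"}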