llama-stack-mirror/src/llama_stack_api/post_training.py
Sébastien Han 97f535c4f1
feat(openapi): switch to fastapi-based generator (#3944)
# What does this PR do?
This replaces the legacy "pyopenapi + strong_typing" pipeline with a
FastAPI-backed generator that has an explicit schema registry inside
`llama_stack_api`. The key changes:

1. **New generator architecture.** FastAPI now builds the OpenAPI schema
directly from the real routes, while helper modules
(`schema_collection`, `endpoints`, `schema_transforms`, etc.)
post-process the result. The old pyopenapi stack and its strong_typing
helpers are removed entirely, so we no longer rely on fragile AST
analysis or top-level import side effects.

2. **Schema registry in `llama_stack_api`.** `schema_utils.py` keeps a
`SchemaInfo` record for every `@json_schema_type`, `register_schema`,
and dynamically created request model. The OpenAPI generator and other
tooling query this registry instead of scanning the package tree,
producing deterministic names (e.g., `{MethodName}Request`), capturing
all optional/nullable fields, and making schema discovery testable. A
new unit test covers the registry behavior.

3. **Regenerated specs + CI alignment.** All docs/Stainless specs are
regenerated from the new pipeline, so optional/nullable fields now match
the actual Pydantic models (expect the API Conformance workflow to report
breaking changes; this PR establishes the new baseline). The workflow
itself is back to the stock oasdiff invocation, so future regressions
surface normally.
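The registry-driven flow in items 1 and 2 can be illustrated with a toy, stdlib-only sketch. A decorator records each POST route and registers its request body under a deterministic `{MethodName}Request` name, and the generator walks those records instead of scanning the package tree. This is not the actual `schema_utils.py` or `scripts.openapi_generator` code, and it does not use FastAPI. `SchemaInfo` is re-declared here with an assumed shape, and `register`, `post_route`, and `build_openapi` are hypothetical helpers. GET routes, query parameters, and response schemas are left out.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class SchemaInfo:
    """Hypothetical registry record; the real SchemaInfo may hold more fields."""

    name: str
    source: str  # e.g. "json_schema_type", "register_schema", or "request"
    json_schema: dict[str, Any]


SCHEMA_REGISTRY: dict[str, SchemaInfo] = {}
ROUTES: list[dict[str, str]] = []


def register(name: str, source: str, json_schema: dict[str, Any]) -> None:
    # Refuse silent overwrites so every component name maps to exactly one schema.
    if name in SCHEMA_REGISTRY:
        raise ValueError(f"duplicate schema name: {name}")
    SCHEMA_REGISTRY[name] = SchemaInfo(name, source, json_schema)


def request_model_name(method_name: str) -> str:
    # Assumed naming rule: supervised_fine_tune -> SupervisedFineTuneRequest
    return "".join(part.capitalize() for part in method_name.split("_")) + "Request"


def post_route(path: str, properties: dict[str, Any], required: list[str]) -> Callable:
    """Record a POST route and register its request body as {MethodName}Request."""

    def decorator(fn: Callable) -> Callable:
        name = request_model_name(fn.__name__)
        register(name, "request", {"type": "object", "properties": properties, "required": required})
        ROUTES.append({"path": path, "operation_id": fn.__name__, "request": name})
        return fn

    return decorator


def build_openapi() -> dict[str, Any]:
    # Walk the recorded routes rather than importing or parsing source files.
    paths: dict[str, Any] = {}
    for r in ROUTES:
        ref = {"$ref": f"#/components/schemas/{r['request']}"}
        paths.setdefault(r["path"], {})["post"] = {
            "operationId": r["operation_id"],
            "requestBody": {"required": True, "content": {"application/json": {"schema": ref}}},
        }
    # Sort component names so repeated runs emit identical output.
    schemas = {name: SCHEMA_REGISTRY[name].json_schema for name in sorted(SCHEMA_REGISTRY)}
    return {"openapi": "3.1.0", "paths": paths, "components": {"schemas": schemas}}


@post_route(
    "/post-training/preference-optimize",
    properties={"job_uuid": {"type": "string"}, "finetuned_model": {"type": "string"}},
    required=["job_uuid", "finetuned_model"],
)
async def preference_optimize(job_uuid: str, finetuned_model: str) -> None: ...


spec = build_openapi()
```

Because names are derived from the method name and the components are emitted in sorted order, two runs over the same routes produce the same document, which is what makes this kind of registry easy to unit-test.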

*Conformance will be RED on this PR; we choose to accept the
deviations.*
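One way to see why the conformance check reports deviations: when a property leaves `required` or gains a `null` branch in the regenerated spec, a spec diff flags it. The stdlib-only sketch below is a toy stand-in for that comparison, not oasdiff. `diff_object_schema` is a hypothetical helper, and both schemas are hand-written for illustration. The new one approximates how Pydantic v2 renders `packed: bool | None = False` from `DataConfig`; the old one is an invented earlier rendering.

```python
from typing import Any


def is_nullable(prop: dict[str, Any]) -> bool:
    # A property accepts null if its own type is "null" or any anyOf branch is.
    if prop.get("type") == "null":
        return True
    return any(branch.get("type") == "null" for branch in prop.get("anyOf", []))


def diff_object_schema(old: dict[str, Any], new: dict[str, Any]) -> list[str]:
    """Report required/optional and nullability changes between two object schemas."""
    changes: list[str] = []
    old_required, new_required = set(old.get("required", [])), set(new.get("required", []))
    changes += [f"{name}: became optional" for name in sorted(old_required - new_required)]
    changes += [f"{name}: became required" for name in sorted(new_required - old_required)]
    old_props = old.get("properties", {})
    for name, prop in sorted(new.get("properties", {}).items()):
        if name in old_props and is_nullable(prop) and not is_nullable(old_props[name]):
            changes.append(f"{name}: became nullable")
    return changes


# Hypothetical earlier rendering: "packed" listed as required and non-nullable.
old_schema = {
    "type": "object",
    "properties": {"dataset_id": {"type": "string"}, "packed": {"type": "boolean"}},
    "required": ["dataset_id", "packed"],
}
# Approximately what Pydantic v2 emits for `packed: bool | None = False`.
new_schema = {
    "type": "object",
    "properties": {
        "dataset_id": {"type": "string"},
        "packed": {"anyOf": [{"type": "boolean"}, {"type": "null"}], "default": False},
    },
    "required": ["dataset_id"],
}

changes = diff_object_schema(old_schema, new_schema)
```

Whether such a change counts as breaking usually depends on direction: relaxing a request field tends to be safe for clients, while relaxing a response field can break clients that assumed the value is always present.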

## Test Plan
- `uv run pytest tests/unit/server/test_schema_registry.py`
- `uv run python -m scripts.openapi_generator.main docs/static`

---------

Signed-off-by: Sébastien Han <seb@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
2025-11-14 15:53:53 -08:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Literal, Protocol

from pydantic import BaseModel, Field

from llama_stack_api.common.content_types import URL
from llama_stack_api.common.job_types import JobStatus
from llama_stack_api.common.training_types import Checkpoint
from llama_stack_api.schema_utils import json_schema_type, register_schema, webmethod
from llama_stack_api.version import LLAMA_STACK_API_V1ALPHA


@json_schema_type
class OptimizerType(Enum):
    """Available optimizer algorithms for training.

    :cvar adam: Adaptive Moment Estimation optimizer
    :cvar adamw: AdamW optimizer with weight decay
    :cvar sgd: Stochastic Gradient Descent optimizer
    """

    adam = "adam"
    adamw = "adamw"
    sgd = "sgd"


@json_schema_type
class DatasetFormat(Enum):
    """Format of the training dataset.

    :cvar instruct: Instruction-following format with prompt and completion
    :cvar dialog: Multi-turn conversation format with messages
    """

    instruct = "instruct"
    dialog = "dialog"


@json_schema_type
class DataConfig(BaseModel):
    """Configuration for training data and data loading.

    :param dataset_id: Unique identifier for the training dataset
    :param batch_size: Number of samples per training batch
    :param shuffle: Whether to shuffle the dataset during training
    :param data_format: Format of the dataset (instruct or dialog)
    :param validation_dataset_id: (Optional) Unique identifier for the validation dataset
    :param packed: (Optional) Whether to pack multiple samples into a single sequence for efficiency
    :param train_on_input: (Optional) Whether to compute loss on input tokens as well as output tokens
    """

    dataset_id: str
    batch_size: int
    shuffle: bool
    data_format: DatasetFormat
    validation_dataset_id: str | None = None
    packed: bool | None = False
    train_on_input: bool | None = False


@json_schema_type
class OptimizerConfig(BaseModel):
    """Configuration parameters for the optimization algorithm.

    :param optimizer_type: Type of optimizer to use (adam, adamw, or sgd)
    :param lr: Learning rate for the optimizer
    :param weight_decay: Weight decay coefficient for regularization
    :param num_warmup_steps: Number of steps for learning rate warmup
    """

    optimizer_type: OptimizerType
    lr: float
    weight_decay: float
    num_warmup_steps: int


@json_schema_type
class EfficiencyConfig(BaseModel):
    """Configuration for memory and compute efficiency optimizations.

    :param enable_activation_checkpointing: (Optional) Whether to use activation checkpointing to reduce memory usage
    :param enable_activation_offloading: (Optional) Whether to offload activations to CPU to save GPU memory
    :param memory_efficient_fsdp_wrap: (Optional) Whether to use memory-efficient FSDP wrapping
    :param fsdp_cpu_offload: (Optional) Whether to offload FSDP parameters to CPU
    """

    enable_activation_checkpointing: bool | None = False
    enable_activation_offloading: bool | None = False
    memory_efficient_fsdp_wrap: bool | None = False
    fsdp_cpu_offload: bool | None = False


@json_schema_type
class TrainingConfig(BaseModel):
    """Comprehensive configuration for the training process.

    :param n_epochs: Number of training epochs to run
    :param max_steps_per_epoch: Maximum number of steps to run per epoch
    :param gradient_accumulation_steps: Number of steps to accumulate gradients before updating
    :param max_validation_steps: (Optional) Maximum number of validation steps per epoch
    :param data_config: (Optional) Configuration for data loading and formatting
    :param optimizer_config: (Optional) Configuration for the optimization algorithm
    :param efficiency_config: (Optional) Configuration for memory and compute optimizations
    :param dtype: (Optional) Data type for model parameters (bf16, fp16, fp32)
    """

    n_epochs: int
    max_steps_per_epoch: int = 1
    gradient_accumulation_steps: int = 1
    max_validation_steps: int | None = 1
    data_config: DataConfig | None = None
    optimizer_config: OptimizerConfig | None = None
    efficiency_config: EfficiencyConfig | None = None
    dtype: str | None = "bf16"


@json_schema_type
class LoraFinetuningConfig(BaseModel):
    """Configuration for Low-Rank Adaptation (LoRA) fine-tuning.

    :param type: Algorithm type identifier, always "LoRA"
    :param lora_attn_modules: List of attention module names to apply LoRA to
    :param apply_lora_to_mlp: Whether to apply LoRA to MLP layers
    :param apply_lora_to_output: Whether to apply LoRA to output projection layers
    :param rank: Rank of the LoRA adaptation (lower rank = fewer parameters)
    :param alpha: LoRA scaling parameter that controls adaptation strength
    :param use_dora: (Optional) Whether to use DoRA (Weight-Decomposed Low-Rank Adaptation)
    :param quantize_base: (Optional) Whether to quantize the base model weights
    """

    type: Literal["LoRA"] = "LoRA"
    lora_attn_modules: list[str]
    apply_lora_to_mlp: bool
    apply_lora_to_output: bool
    rank: int
    alpha: int
    use_dora: bool | None = False
    quantize_base: bool | None = False


@json_schema_type
class QATFinetuningConfig(BaseModel):
    """Configuration for Quantization-Aware Training (QAT) fine-tuning.

    :param type: Algorithm type identifier, always "QAT"
    :param quantizer_name: Name of the quantization algorithm to use
    :param group_size: Size of groups for grouped quantization
    """

    type: Literal["QAT"] = "QAT"
    quantizer_name: str
    group_size: int


AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")


@json_schema_type
class PostTrainingJobLogStream(BaseModel):
    """Stream of logs from a finetuning job.

    :param job_uuid: Unique identifier for the training job
    :param log_lines: List of log message strings from the training process
    """

    job_uuid: str
    log_lines: list[str]
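

@json_schema_type
class RLHFAlgorithm(Enum):
    """Available reinforcement learning from human feedback algorithms.

    :cvar dpo: Direct Preference Optimization algorithm
    """

    dpo = "dpo"


@json_schema_type
class DPOLossType(Enum):
    """Loss function variants for Direct Preference Optimization (DPO).

    :cvar sigmoid: Standard DPO loss based on the log-sigmoid of the preference margin
    :cvar hinge: Hinge loss on the preference margin
    :cvar ipo: Identity Preference Optimization (IPO) loss
    :cvar kto_pair: Paired variant of the Kahneman-Tversky Optimization (KTO) loss
    """

    sigmoid = "sigmoid"
    hinge = "hinge"
    ipo = "ipo"
    kto_pair = "kto_pair"


@json_schema_type
class DPOAlignmentConfig(BaseModel):
    """Configuration for Direct Preference Optimization (DPO) alignment.

    :param beta: Temperature parameter for the DPO loss
    :param loss_type: The type of loss function to use for DPO
    """

    beta: float
    loss_type: DPOLossType = DPOLossType.sigmoid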


@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
    """Request to finetune a model using reinforcement learning from human feedback.

    :param job_uuid: Unique identifier for the training job
    :param finetuned_model: URL or path to the base model to fine-tune
    :param dataset_id: Unique identifier for the training dataset
    :param validation_dataset_id: Unique identifier for the validation dataset
    :param algorithm: RLHF algorithm to use for training
    :param algorithm_config: Configuration parameters for the RLHF algorithm
    :param optimizer_config: Configuration parameters for the optimization algorithm
    :param training_config: Configuration parameters for the training process
    :param hyperparam_search_config: Configuration for hyperparameter search
    :param logger_config: Configuration for training logging
    """

    job_uuid: str

    finetuned_model: URL

    dataset_id: str
    validation_dataset_id: str

    algorithm: RLHFAlgorithm
    algorithm_config: DPOAlignmentConfig

    optimizer_config: OptimizerConfig
    training_config: TrainingConfig

    # TODO: define these
    hyperparam_search_config: dict[str, Any]
    logger_config: dict[str, Any]


@json_schema_type
class PostTrainingJob(BaseModel):
    """Handle for a post-training job.

    :param job_uuid: Unique identifier for the training job
    """

    job_uuid: str


@json_schema_type
class PostTrainingJobStatusResponse(BaseModel):
    """Status of a finetuning job.

    :param job_uuid: Unique identifier for the training job
    :param status: Current status of the training job
    :param scheduled_at: (Optional) Timestamp when the job was scheduled
    :param started_at: (Optional) Timestamp when the job execution began
    :param completed_at: (Optional) Timestamp when the job finished, if completed
    :param resources_allocated: (Optional) Information about computational resources allocated to the job
    :param checkpoints: List of model checkpoints created during training
    """

    job_uuid: str
    status: JobStatus

    scheduled_at: datetime | None = None
    started_at: datetime | None = None
    completed_at: datetime | None = None

    resources_allocated: dict[str, Any] | None = None

    checkpoints: list[Checkpoint] = Field(default_factory=list)


@json_schema_type
class ListPostTrainingJobsResponse(BaseModel):
    """Response containing the list of post-training jobs.

    :param data: List of post-training jobs
    """

    data: list[PostTrainingJob]


@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
    """Artifacts of a finetuning job.

    :param job_uuid: Unique identifier for the training job
    :param checkpoints: List of model checkpoints created during training
    """

    job_uuid: str
    checkpoints: list[Checkpoint] = Field(default_factory=list)

    # TODO(ashwin): metrics, evals


class PostTraining(Protocol):
    @webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1ALPHA)
    async def supervised_fine_tune(
        self,
        job_uuid: str,
        training_config: TrainingConfig,
        hyperparam_search_config: dict[str, Any],
        logger_config: dict[str, Any],
        model: str | None = Field(
            default=None,
            description="Model descriptor for training if not in provider config",
        ),
        checkpoint_dir: str | None = None,
        algorithm_config: AlgorithmConfig | None = None,
    ) -> PostTrainingJob:
        """Run supervised fine-tuning of a model.

        :param job_uuid: The UUID of the job to create.
        :param training_config: The training configuration.
        :param hyperparam_search_config: The hyperparam search configuration.
        :param logger_config: The logger configuration.
        :param model: The model to fine-tune.
        :param checkpoint_dir: The directory to save checkpoint(s) to.
        :param algorithm_config: The algorithm configuration.
        :returns: A PostTrainingJob.
        """
        ...

    @webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1ALPHA)
    async def preference_optimize(
        self,
        job_uuid: str,
        finetuned_model: str,
        algorithm_config: DPOAlignmentConfig,
        training_config: TrainingConfig,
        hyperparam_search_config: dict[str, Any],
        logger_config: dict[str, Any],
    ) -> PostTrainingJob:
        """Run preference optimization of a model.

        :param job_uuid: The UUID of the job to create.
        :param finetuned_model: The model to fine-tune.
        :param algorithm_config: The algorithm configuration.
        :param training_config: The training configuration.
        :param hyperparam_search_config: The hyperparam search configuration.
        :param logger_config: The logger configuration.
        :returns: A PostTrainingJob.
        """
        ...

    @webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1ALPHA)
    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
        """Get all training jobs.

        :returns: A ListPostTrainingJobsResponse.
        """
        ...

    @webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1ALPHA)
    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
        """Get the status of a training job.

        :param job_uuid: The UUID of the job to get the status of.
        :returns: A PostTrainingJobStatusResponse.
        """
        ...

    @webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1ALPHA)
    async def cancel_training_job(self, job_uuid: str) -> None:
        """Cancel a training job.

        :param job_uuid: The UUID of the job to cancel.
        """
        ...

    @webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1ALPHA)
    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
        """Get the artifacts of a training job.

        :param job_uuid: The UUID of the job to get the artifacts of.
        :returns: A PostTrainingJobArtifactsResponse.
        """
        ...