mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-24 16:57:21 +00:00
A bunch of miscellaneous cleanup focusing on tests, which also ended up speeding up the starter distro substantially.

- Pulled llama stack client init for tests into `pytest_sessionstart` so it does not clobber output
- Profiling that init showed where we were doing lots of heavy imports for starter, so those are now lazy
- starter now starts 20+ seconds faster on my Mac
- A few other smallish refactors for `compat_client`
83 lines
2.9 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Literal

from pydantic import BaseModel


class HuggingFacePostTrainingConfig(BaseModel):
    # Device to run training on (cuda, cpu, mps)
    device: str = "cuda"

    # Distributed training backend if using multiple devices
    # fsdp: Fully Sharded Data Parallel
    # deepspeed: DeepSpeed ZeRO optimization
    distributed_backend: Literal["fsdp", "deepspeed"] | None = None

    # Format for saving model checkpoints
    # full_state: Save complete model state
    # huggingface: Save in HuggingFace format (recommended for compatibility)
    checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface"

    # Template for formatting chat inputs and outputs
    # Used to structure the conversation format for training
    chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}"

    # Model-specific configuration parameters
    # trust_remote_code: Allow execution of custom model code
    # attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance
    model_specific_config: dict = {
        "trust_remote_code": True,
        "attn_implementation": "sdpa",
    }

    # Maximum sequence length for training
    # Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon)
    # Longer sequences may cause memory issues on MPS devices
    max_seq_length: int = 2048

    # Enable gradient checkpointing to reduce memory usage
    # Trades computation for memory by recomputing activations
    gradient_checkpointing: bool = False

    # Maximum number of checkpoints to keep
    # Older checkpoints are deleted when this limit is reached
    save_total_limit: int = 3

    # Number of training steps between logging updates
    logging_steps: int = 10

    # Ratio of training steps used for learning rate warmup
    # Helps stabilize early training
    warmup_ratio: float = 0.1

    # L2 regularization coefficient
    # Helps prevent overfitting
    weight_decay: float = 0.01

    # Number of worker processes for data loading
    # Higher values can improve data loading speed but increase memory usage
    dataloader_num_workers: int = 4

    # Whether to pin memory in data loader
    # Can improve data transfer speed to GPU but uses more memory
    dataloader_pin_memory: bool = True

    # DPO-specific parameters
    dpo_beta: float = 0.1
    use_reference_model: bool = True
    dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
    dpo_output_dir: str

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
        return {
            "checkpoint_format": "huggingface",
            "distributed_backend": None,
            "device": "cpu",
            "dpo_output_dir": __distro_dir__ + "/dpo_output",
        }
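
The following is a minimal sketch of how two pieces of this config might be exercised: the default `chat_template` string and the dict returned by `sample_run_config`. It is standard-library only and does not import llama-stack. The helper `render_example` and the stand-in function `sample_run_config` are hypothetical names introduced for illustration, and filling the template with `str.format` is an assumption about how the provider consumes it, not something this file confirms.

```python
# Illustrative sketch only; nothing here is part of llama-stack itself.

# Copy of the default chat_template value from HuggingFacePostTrainingConfig.
DEFAULT_CHAT_TEMPLATE = "<|user|>\n{input}\n<|assistant|>\n{output}"


def render_example(template: str, user_input: str, expected_output: str) -> str:
    """Fill the {input} and {output} placeholders.

    Assumption: the template is applied with str.format semantics.
    """
    return template.format(input=user_input, output=expected_output)


def sample_run_config(distro_dir: str) -> dict:
    """Hypothetical stand-in that mirrors the dict built by the classmethod."""
    return {
        "checkpoint_format": "huggingface",
        "distributed_backend": None,
        "device": "cpu",
        "dpo_output_dir": distro_dir + "/dpo_output",
    }


# "/tmp/distro" is a made-up directory used only for this example.
rendered = render_example(DEFAULT_CHAT_TEMPLATE, "What is 2 + 2?", "4")
run_config = sample_run_config("/tmp/distro")

print(rendered)
print(run_config["dpo_output_dir"])
```

Note that the sample run config sets `device` to `"cpu"` while the class default is `"cuda"`, and that `dpo_output_dir` is the only field without a default, so it must always be supplied.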