mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 04:02:25 +00:00
chore(misc): make tests and starter faster
This commit is contained in:
parent
ea46f74092
commit
2b4e88a3de
19 changed files with 2860 additions and 1660 deletions
|
|
@ -6,7 +6,8 @@
|
|||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
import tempfile
|
||||
|
||||
|
||||
class HuggingFacePostTrainingConfig(BaseModel):
|
||||
|
|
@ -71,7 +72,7 @@ class HuggingFacePostTrainingConfig(BaseModel):
|
|||
dpo_beta: float = 0.1
|
||||
use_reference_model: bool = True
|
||||
dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
|
||||
dpo_output_dir: str = "./checkpoints/dpo"
|
||||
dpo_output_dir: str = Field(default_factory=lambda: tempfile.mkdtemp(prefix="dpo_output_"))
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -22,15 +22,8 @@ from llama_stack.apis.post_training import (
|
|||
from llama_stack.providers.inline.post_training.huggingface.config import (
|
||||
HuggingFacePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
|
||||
HFFinetuningSingleDevice,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
|
||||
HFDPOAlignmentSingleDevice,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
class TrainingArtifactType(Enum):
|
||||
|
|
@ -85,6 +78,10 @@ class HuggingFacePostTrainingImpl:
|
|||
algorithm_config: AlgorithmConfig | None = None,
|
||||
) -> PostTrainingJob:
|
||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
|
||||
HFFinetuningSingleDevice,
|
||||
)
|
||||
|
||||
on_log_message_cb("Starting HF finetuning")
|
||||
|
||||
recipe = HFFinetuningSingleDevice(
|
||||
|
|
@ -124,6 +121,10 @@ class HuggingFacePostTrainingImpl:
|
|||
logger_config: dict[str, Any],
|
||||
) -> PostTrainingJob:
|
||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
|
||||
HFDPOAlignmentSingleDevice,
|
||||
)
|
||||
|
||||
on_log_message_cb("Starting HF DPO alignment")
|
||||
|
||||
recipe = HFDPOAlignmentSingleDevice(
|
||||
|
|
@ -168,7 +169,6 @@ class HuggingFacePostTrainingImpl:
|
|||
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
|
||||
return data[0] if data else None
|
||||
|
||||
@webmethod(route="/post-training/job/status")
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
|
||||
|
|
@ -195,16 +195,13 @@ class HuggingFacePostTrainingImpl:
|
|||
resources_allocated=self._get_resources_allocated(job),
|
||||
)
|
||||
|
||||
@webmethod(route="/post-training/job/cancel")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
self._scheduler.cancel(job_uuid)
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
|
||||
|
||||
@webmethod(route="/post-training/jobs", method="GET")
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
return ListPostTrainingJobsResponse(
|
||||
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
|
||||
|
|
|
|||
|
|
@ -23,9 +23,6 @@ from llama_stack.apis.post_training import (
|
|||
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||
TorchtunePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
|
||||
LoraFinetuningSingleDevice,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
|
@ -84,6 +81,10 @@ class TorchtunePostTrainingImpl:
|
|||
if isinstance(algorithm_config, LoraFinetuningConfig):
|
||||
|
||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
|
||||
LoraFinetuningSingleDevice,
|
||||
)
|
||||
|
||||
on_log_message_cb("Starting Lora finetuning")
|
||||
|
||||
recipe = LoraFinetuningSingleDevice(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue