This commit is contained in:
Charlie Doern 2025-06-02 17:32:30 -04:00 committed by GitHub
commit 4a7bdf1b87
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 393 additions and 23 deletions

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
model_entries = [
ProviderModelEntry(
provider_model_id="ibm-granite/granite-3.3-8b-instruct",
aliases=["ibm-granite/granite-3.3-8b-instruct"],
model_type=ModelType.llm,
),
ProviderModelEntry(
provider_model_id="ibm-granite/granite-3.3-8b-instruct",
aliases=["ibm-granite/granite-3.3-8b-instruct"],
model_type=ModelType.llm,
),
]

View file

@ -8,27 +8,35 @@ from typing import Any
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.models import Model
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
PostTraining,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.huggingface.config import (
HuggingFacePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
HFFinetuningSingleDevice,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
from .models import model_entries
class TrainingArtifactType(Enum):
CHECKPOINT = "checkpoint"
@ -37,14 +45,17 @@ class TrainingArtifactType(Enum):
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
logger = get_logger(name=__name__, category="post_training")
class HuggingFacePostTrainingImpl:
class HuggingFacePostTrainingImpl(PostTraining):
def __init__(
self,
config: HuggingFacePostTrainingConfig,
datasetio_api: DatasetIO,
datasets: Datasets,
) -> None:
self.register_helper = ModelRegistryHelper(model_entries)
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets
@ -80,6 +91,10 @@ class HuggingFacePostTrainingImpl:
checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob:
model = await self._get_model(model)
if model.provider_resource_id is None:
raise ValueError(f"Model {model} has no provider_resource_id set")
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting HF finetuning")
@ -90,7 +105,7 @@ class HuggingFacePostTrainingImpl:
)
resources_allocated, checkpoints = await recipe.train(
model=model,
model=model.identifier,
output_dir=checkpoint_dir,
job_uuid=job_uuid,
lora_config=algorithm_config,
@ -110,6 +125,30 @@ class HuggingFacePostTrainingImpl:
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
async def register_model(self, model: Model) -> Model:
try:
# get static list of models
model = await self.register_helper.register_model(model)
except ValueError:
# if model is NOT in the list, its probably ok, but warn the user.
#
logger.warning(
f"Model {model.identifier} is not in the model registry for this provider, there might be unexpected issues."
)
if model.provider_resource_id is None:
raise ValueError("Model provider_resource_id cannot be None")
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
if provider_resource_id is None:
provider_resource_id = model.provider_resource_id
model.provider_resource_id = provider_resource_id
return model
async def _get_model(self, model_id: str) -> Model:
if not self.model_store:
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
async def preference_optimize(
self,
job_uuid: str,