mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-15 01:26:10 +00:00
Merge 71caa271ad
into 76dcf47320
This commit is contained in:
commit
4a7bdf1b87
11 changed files with 393 additions and 23 deletions
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
)
|
||||
|
||||
# Static list of models this provider declares up front; it is handed to
# ModelRegistryHelper by the HuggingFace post-training implementation.
#
# The granite entry used to appear twice, byte-for-byte identical. One copy is
# enough: a second identical entry carries no additional information.
model_entries = [
    ProviderModelEntry(
        provider_model_id="ibm-granite/granite-3.3-8b-instruct",
        aliases=["ibm-granite/granite-3.3-8b-instruct"],
        model_type=ModelType.llm,
    ),
]
|
|
@ -8,27 +8,35 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
ListPostTrainingJobsResponse,
|
||||
PostTraining,
|
||||
PostTrainingJob,
|
||||
PostTrainingJobArtifactsResponse,
|
||||
PostTrainingJobStatusResponse,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.post_training.huggingface.config import (
|
||||
HuggingFacePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
|
||||
HFFinetuningSingleDevice,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
from .models import model_entries
|
||||
|
||||
|
||||
class TrainingArtifactType(Enum):
|
||||
CHECKPOINT = "checkpoint"
|
||||
|
@ -37,14 +45,17 @@ class TrainingArtifactType(Enum):
|
|||
|
||||
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
|
||||
|
||||
logger = get_logger(name=__name__, category="post_training")
|
||||
|
||||
class HuggingFacePostTrainingImpl:
|
||||
|
||||
class HuggingFacePostTrainingImpl(PostTraining):
|
||||
def __init__(
|
||||
self,
|
||||
config: HuggingFacePostTrainingConfig,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets: Datasets,
|
||||
) -> None:
|
||||
self.register_helper = ModelRegistryHelper(model_entries)
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets
|
||||
|
@ -80,6 +91,10 @@ class HuggingFacePostTrainingImpl:
|
|||
checkpoint_dir: str | None = None,
|
||||
algorithm_config: AlgorithmConfig | None = None,
|
||||
) -> PostTrainingJob:
|
||||
model = await self._get_model(model)
|
||||
if model.provider_resource_id is None:
|
||||
raise ValueError(f"Model {model} has no provider_resource_id set")
|
||||
|
||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||
on_log_message_cb("Starting HF finetuning")
|
||||
|
||||
|
@ -90,7 +105,7 @@ class HuggingFacePostTrainingImpl:
|
|||
)
|
||||
|
||||
resources_allocated, checkpoints = await recipe.train(
|
||||
model=model,
|
||||
model=model.identifier,
|
||||
output_dir=checkpoint_dir,
|
||||
job_uuid=job_uuid,
|
||||
lora_config=algorithm_config,
|
||||
|
@ -110,6 +125,30 @@ class HuggingFacePostTrainingImpl:
|
|||
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
|
||||
return PostTrainingJob(job_uuid=job_uuid)
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
    """Register ``model`` with this provider and settle its provider resource id.

    The model is first passed to ``self.register_helper.register_model``. A
    ``ValueError`` from that call is treated as "not in this provider's
    model registry": it is logged as a warning and registration continues
    with the model as given.

    Args:
        model: The model being registered.

    Returns:
        The model, with ``provider_resource_id`` set to the id reported by
        ``register_helper.get_provider_model_id`` when that lookup yields a
        value, and left as supplied otherwise.

    Raises:
        ValueError: If ``model.provider_resource_id`` is ``None``.
    """
    try:
        # Consult the provider's static list of models.
        model = await self.register_helper.register_model(model)
    except ValueError:
        # A model missing from the list is probably fine, but warn the user.
        logger.warning(
            f"Model {model.identifier} is not in the model registry for this provider, there might be unexpected issues."
        )

    requested_id = model.provider_resource_id
    if requested_id is None:
        raise ValueError("Model provider_resource_id cannot be None")

    # Prefer the helper's answer; keep the caller's id when it has none.
    resolved_id = self.register_helper.get_provider_model_id(requested_id)
    model.provider_resource_id = requested_id if resolved_id is None else resolved_id

    return model
|
||||
|
||||
async def _get_model(self, model_id: str) -> Model:
    """Fetch the model stored under ``model_id`` from ``self.model_store``.

    Args:
        model_id: Identifier passed to the store's ``get_model``.

    Returns:
        Whatever ``self.model_store.get_model(model_id)`` resolves to.

    Raises:
        ValueError: If ``self.model_store`` is falsy (reported as "not set").
    """
    store = self.model_store
    if store:
        return await store.get_model(model_id)
    raise ValueError("Model store not set")
|
||||
|
||||
async def preference_optimize(
|
||||
self,
|
||||
job_uuid: str,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue