diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py
index 63a764725..ea6ea0357 100644
--- a/llama_stack/apis/datatypes.py
+++ b/llama_stack/apis/datatypes.py
@@ -27,6 +27,7 @@ class Api(Enum):
     telemetry = "telemetry"
     models = "models"
+    post_training_models = "post_training_models"
     shields = "shields"
     vector_dbs = "vector_dbs"
     datasets = "datasets"
diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index b196c8a17..f38c705c1 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -13,6 +13,7 @@ from pydantic import BaseModel, Field
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.common.training_types import Checkpoint
+from llama_stack.apis.models import Model
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@@ -168,7 +169,13 @@ class PostTrainingJobArtifactsResponse(BaseModel):
     # TODO(ashwin): metrics, evals


+class ModelStore(Protocol):
+    async def get_model(self, identifier: str) -> Model: ...
+
+
 class PostTraining(Protocol):
+    model_store: ModelStore | None = None
+
     @webmethod(route="/post-training/supervised-fine-tune", method="POST")
     async def supervised_fine_tune(
         self,
diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py
index b860d15ab..8322ad9b9 100644
--- a/llama_stack/distribution/distribution.py
+++ b/llama_stack/distribution/distribution.py
@@ -39,6 +39,10 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
             routing_table_api=Api.models,
             router_api=Api.inference,
         ),
+        AutoRoutedApiInfo(
+            routing_table_api=Api.post_training_models,
+            router_api=Api.post_training,
+        ),
         AutoRoutedApiInfo(
             routing_table_api=Api.shields,
             router_api=Api.safety,
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index b7c7cb87f..4f00e786f 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -67,6 +67,7 @@ def api_protocol_map() -> dict[Api, Any]:
         Api.vector_io: VectorIO,
         Api.vector_dbs: VectorDBs,
         Api.models: Models,
+        Api.post_training_models: Models,
         Api.safety: Safety,
         Api.shields: Shields,
         Api.telemetry: Telemetry,
@@ -93,6 +94,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
 def additional_protocols_map() -> dict[Api, Any]:
     return {
         Api.inference: (ModelsProtocolPrivate, Models, Api.models),
+        Api.post_training: (ModelsProtocolPrivate, Models, Api.post_training_models),
         Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
         Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
         Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
@@ -251,6 +253,8 @@ async def instantiate_providers(
     """Instantiates providers asynchronously while managing dependencies."""
     impls: dict[Api, Any] = {}
     inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
+
+    # First pass: instantiate all providers
     for api_str, provider in sorted_providers:
         deps = {a: impls[a] for a in provider.spec.api_dependencies}
         for a in provider.spec.optional_api_dependencies:
@@ -269,6 +273,10 @@ async def instantiate_providers(
         api = Api(api_str)
         impls[api] = impl

+    # Second pass: connect routing tables
+    if Api.models in impls and Api.post_training_models in impls:
+        impls[Api.models].post_training_models_table = impls[Api.post_training_models]
+
     return impls
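The second pass exists because the two tables are constructed independently and in provider order, so neither can reference the other during construction. Here is a minimal, self-contained sketch of the same pattern; the classes are simplified stand-ins, not the real routing tables:

```python
# Simplified stand-ins for the routing tables; only the two-pass wiring
# pattern from the resolver change above is reproduced here.
class InferenceModelsTable:
    def __init__(self) -> None:
        self.post_training_models_table = None  # filled in by the second pass


class PostTrainingModelsTable:
    pass


def instantiate(apis: list[str]) -> dict[str, object]:
    impls: dict[str, object] = {}
    # First pass: construct every impl with no cross-references.
    for api in apis:
        impls[api] = InferenceModelsTable() if api == "models" else PostTrainingModelsTable()
    # Second pass: connect the tables once both are guaranteed to exist.
    if "models" in impls and "post_training_models" in impls:
        impls["models"].post_training_models_table = impls["post_training_models"]
    return impls


impls = instantiate(["models", "post_training_models"])
assert impls["models"].post_training_models_table is impls["post_training_models"]
```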
diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py
index 1358d5812..7d5ed21f4 100644
--- a/llama_stack/distribution/routers/__init__.py
+++ b/llama_stack/distribution/routers/__init__.py
@@ -21,7 +21,8 @@ async def get_routing_table_impl(
 ) -> Any:
     from ..routing_tables.benchmarks import BenchmarksRoutingTable
     from ..routing_tables.datasets import DatasetsRoutingTable
-    from ..routing_tables.models import ModelsRoutingTable
+    from ..routing_tables.models import InferenceModelsRoutingTable
+    from ..routing_tables.post_training_models import PostTrainingModelsRoutingTable
     from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
     from ..routing_tables.shields import ShieldsRoutingTable
     from ..routing_tables.toolgroups import ToolGroupsRoutingTable
@@ -29,7 +30,8 @@ async def get_routing_table_impl(
     api_to_tables = {
         "vector_dbs": VectorDBsRoutingTable,
-        "models": ModelsRoutingTable,
+        "models": InferenceModelsRoutingTable,
+        "post_training_models": PostTrainingModelsRoutingTable,
         "shields": ShieldsRoutingTable,
         "datasets": DatasetsRoutingTable,
         "scoring_functions": ScoringFunctionsRoutingTable,
@@ -40,7 +42,12 @@ async def get_routing_table_impl(
     if api.value not in api_to_tables:
         raise ValueError(f"API {api.value} not found in router map")

-    impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
+    # For the post-training API, use the post-training models routing table
+    if api == Api.post_training:
+        impl = PostTrainingModelsRoutingTable(impls_by_provider_id, dist_registry)
+    else:
+        impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
+
     await impl.initialize()
     return impl
@@ -51,6 +58,7 @@ async def get_auto_router_impl(
     from .datasets import DatasetIORouter
     from .eval_scoring import EvalRouter, ScoringRouter
     from .inference import InferenceRouter
+    from .post_training import PostTrainingRouter
     from .safety import SafetyRouter
     from .tool_runtime import ToolRuntimeRouter
     from .vector_io import VectorIORouter
@@ -63,6 +71,7 @@ async def get_auto_router_impl(
         "scoring": ScoringRouter,
         "eval": EvalRouter,
         "tool_runtime": ToolRuntimeRouter,
+        "post_training": PostTrainingRouter,
     }
     api_to_deps = {
         "inference": {"telemetry": Api.telemetry},
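A condensed stand-in for the dispatch above, using only names introduced in this diff; it isolates the special case that sends `Api.post_training` to the post-training models table rather than through `api_to_tables`:

```python
from enum import Enum


class Api(Enum):
    models = "models"
    post_training = "post_training"
    post_training_models = "post_training_models"


# Table classes stand in as strings; the mapping mirrors api_to_tables above.
api_to_tables = {
    "models": "InferenceModelsRoutingTable",
    "post_training_models": "PostTrainingModelsRoutingTable",
}


def table_for(api: Api) -> str:
    # post_training requests are routed by model, so they share the
    # post-training models table rather than having an entry of their own.
    if api == Api.post_training:
        return "PostTrainingModelsRoutingTable"
    if api.value not in api_to_tables:
        raise ValueError(f"API {api.value} not found in router map")
    return api_to_tables[api.value]


assert table_for(Api.post_training) == "PostTrainingModelsRoutingTable"
assert table_for(Api.models) == "InferenceModelsRoutingTable"
```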
diff --git a/llama_stack/distribution/routers/post_training.py b/llama_stack/distribution/routers/post_training.py
new file mode 100644
index 000000000..05062c227
--- /dev/null
+++ b/llama_stack/distribution/routers/post_training.py
@@ -0,0 +1,101 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from llama_stack.apis.models import Model
+from llama_stack.apis.post_training import (
+    AlgorithmConfig,
+    DPOAlignmentConfig,
+    ListPostTrainingJobsResponse,
+    PostTraining,
+    PostTrainingJob,
+    PostTrainingJobArtifactsResponse,
+    PostTrainingJobStatusResponse,
+    TrainingConfig,
+)
+from llama_stack.log import get_logger
+from llama_stack.providers.datatypes import RoutingTable
+
+logger = get_logger(name=__name__, category="core")
+
+
+class PostTrainingRouter(PostTraining):
+    """Routes post-training requests to a provider based on the model."""
+
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        logger.debug("Initializing PostTrainingRouter")
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        pass
+
+    async def supervised_fine_tune(
+        self,
+        job_uuid: str,
+        training_config: TrainingConfig,
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
+        model: str,
+        checkpoint_dir: str | None = None,
+        algorithm_config: AlgorithmConfig | None = None,
+    ) -> PostTrainingJob:
+        provider = self.routing_table.get_provider_impl(model)
+        params = dict(
+            job_uuid=job_uuid,
+            training_config=training_config,
+            hyperparam_search_config=hyperparam_search_config,
+            logger_config=logger_config,
+            model=model,
+            checkpoint_dir=checkpoint_dir,
+            algorithm_config=algorithm_config,
+        )
+        return await provider.supervised_fine_tune(**params)
+
+    async def register_model(self, model: Model) -> Model:
+        try:
+            # Check the provider's static list of known models
+            model = await self.register_helper.register_model(model)
+        except ValueError:
+            # If the model is NOT in the list, it is probably still usable,
+            # but warn the user.
+            logger.warning(
+                f"Model {model.identifier} is not in the model registry for this provider; there might be unexpected issues."
+            )
+        if model.provider_resource_id is None:
+            raise ValueError("Model provider_resource_id cannot be None")
+        provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
+        if provider_resource_id is None:
+            provider_resource_id = model.provider_resource_id
+        model.provider_resource_id = provider_resource_id
+
+        return model
+
+    async def preference_optimize(
+        self,
+        job_uuid: str,
+        finetuned_model: str,
+        algorithm_config: DPOAlignmentConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
+    ) -> PostTrainingJob:
+        raise NotImplementedError("preference_optimize is not routed yet")
+
+    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
+        raise NotImplementedError("get_training_jobs is not routed yet")
+
+    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
+        raise NotImplementedError("get_training_job_status is not routed yet")
+
+    async def cancel_training_job(self, job_uuid: str) -> None:
+        raise NotImplementedError("cancel_training_job is not routed yet")
+
+    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
+        raise NotImplementedError("get_training_job_artifacts is not routed yet")
diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py
index 8ec87ca50..961066ff6 100644
--- a/llama_stack/distribution/routing_tables/common.py
+++ b/llama_stack/distribution/routing_tables/common.py
@@ -33,7 +33,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable

     assert obj.provider_id != "remote", "Remote provider should not be registered"

-    if api == Api.inference:
+    if api == Api.inference or api == Api.post_training:
         return await p.register_model(obj)
     elif api == Api.safety:
         return await p.register_shield(obj)
@@ -55,7 +55,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
     api = get_impl_api(p)
     if api == Api.vector_io:
         return await p.unregister_vector_db(obj.identifier)
-    elif api == Api.inference:
+    elif api == Api.inference or api == Api.post_training:
         return await p.unregister_model(obj.identifier)
     elif api == Api.datasetio:
         return await p.unregister_dataset(obj.identifier)
@@ -89,11 +89,18 @@ class CommonRoutingTableImpl(RoutingTable):
             obj = cls(**model_data)
             await self.dist_registry.register(obj)

+        # Import routing table classes here to avoid circular imports
+        from .models import InferenceModelsRoutingTable
+        from .post_training_models import PostTrainingModelsRoutingTable
+
         # Register all objects from providers
         for pid, p in self.impls_by_provider_id.items():
             api = get_impl_api(p)
-            if api == Api.inference:
-                p.model_store = self
+            if api == Api.inference or api == Api.post_training:
+                # Both inference and post-training providers resolve models
+                # through the models routing table that owns them
+                if isinstance(self, InferenceModelsRoutingTable | PostTrainingModelsRoutingTable):
+                    p.model_store = self
             elif api == Api.safety:
                 p.shield_store = self
             elif api == Api.vector_io:
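A minimal sketch of how the router delegates by model, using stub objects in place of a real routing table, provider, and `TrainingConfig`; it assumes the module path added in this diff is importable:

```python
import asyncio

from llama_stack.distribution.routers.post_training import PostTrainingRouter


class StubProvider:
    async def supervised_fine_tune(self, **params):
        # A real provider returns a PostTrainingJob; a dict suffices here.
        return {"job_uuid": params["job_uuid"]}


class StubRoutingTable:
    def get_provider_impl(self, routing_key: str):
        # The real table resolves the provider that registered this model.
        return StubProvider()


async def main() -> None:
    router = PostTrainingRouter(routing_table=StubRoutingTable())
    job = await router.supervised_fine_tune(
        job_uuid="job-1",
        training_config=None,  # placeholder; a real TrainingConfig is required
        hyperparam_search_config={},
        logger_config={},
        model="ibm-granite/granite-3.3-8b-instruct",
    )
    print(job)  # {'job_uuid': 'job-1'}


asyncio.run(main())
```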
@@ -116,15 +123,16 @@ class CommonRoutingTableImpl(RoutingTable):
     def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
         from .benchmarks import BenchmarksRoutingTable
         from .datasets import DatasetsRoutingTable
-        from .models import ModelsRoutingTable
+        from .models import InferenceModelsRoutingTable
+        from .post_training_models import PostTrainingModelsRoutingTable
         from .scoring_functions import ScoringFunctionsRoutingTable
         from .shields import ShieldsRoutingTable
         from .toolgroups import ToolGroupsRoutingTable
         from .vector_dbs import VectorDBsRoutingTable

         def apiname_object():
-            if isinstance(self, ModelsRoutingTable):
-                return ("Inference", "model")
+            if isinstance(self, InferenceModelsRoutingTable | PostTrainingModelsRoutingTable):
+                return ("Models", "model")
             elif isinstance(self, ShieldsRoutingTable):
                 return ("Safety", "shield")
             elif isinstance(self, VectorDBsRoutingTable):
@@ -155,7 +163,25 @@ class CommonRoutingTableImpl(RoutingTable):
             )

         if not provider_id or provider_id == obj.provider_id:
-            return self.impls_by_provider_id[obj.provider_id]
+            provider = self.impls_by_provider_id[obj.provider_id]
+            # Check whether the provider supports the requested API
+            if not hasattr(provider, "__provider_spec__"):
+                return provider
+            api = provider.__provider_spec__.api
+
+            # Only check API compatibility for model routing tables
+            if isinstance(self, InferenceModelsRoutingTable | PostTrainingModelsRoutingTable):
+                if api not in [Api.inference, Api.post_training]:
+                    raise ValueError(f"Provider {obj.provider_id} does not support the requested API")
+                # If both inference and post-training providers are present,
+                # prefer routing to an inference provider
+                if api == Api.post_training and Api.inference in [
+                    p.__provider_spec__.api for p in self.impls_by_provider_id.values()
+                ]:
+                    for p in self.impls_by_provider_id.values():
+                        if hasattr(p, "__provider_spec__") and p.__provider_spec__.api == Api.inference:
+                            return p
+            return provider

         raise ValueError(f"Provider not found for `{routing_key}`")
@@ -198,7 +224,6 @@ class CommonRoutingTableImpl(RoutingTable):
             if obj.type == ResourceType.model.value:
                 await self.dist_registry.register(registered_obj)
                 return registered_obj
-        else:

         await self.dist_registry.register(obj)
         return obj
diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py
index 7216d9935..4ed96d416 100644
--- a/llama_stack/distribution/routing_tables/models.py
+++ b/llama_stack/distribution/routing_tables/models.py
@@ -8,9 +8,8 @@ import time
 from typing import Any

 from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
-from llama_stack.distribution.datatypes import (
-    ModelWithACL,
-)
+from llama_stack.distribution.datatypes import ModelWithACL
+from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl
@@ -18,12 +17,37 @@ from .common import CommonRoutingTableImpl
 logger = get_logger(name=__name__, category="core")


-class ModelsRoutingTable(CommonRoutingTableImpl, Models):
+class InferenceModelsRoutingTable(CommonRoutingTableImpl, Models):
+    """Routing table for inference models."""
+
+    def __init__(
+        self,
+        impls_by_provider_id: dict[str, Any],
+        dist_registry: DistributionRegistry,
+    ) -> None:
+        super().__init__(impls_by_provider_id, dist_registry)
+        self.post_training_models_table = None
+
+    async def initialize(self) -> None:
+        await super().initialize()
+
     async def list_models(self) -> ListModelsResponse:
-        return ListModelsResponse(data=await self.get_all_with_type("model"))
+        """List all inference models, plus any post-training-only models."""
+        models = await self.get_all_with_type("model")
+        if self.post_training_models_table:
+            post_training_models = await self.post_training_models_table.get_all_with_type("model")
+            # Skip post-training models whose identifiers already exist
+            existing_ids = {model.identifier for model in models}
+            models.extend([model for model in post_training_models if model.identifier not in existing_ids])
+        return ListModelsResponse(data=models)
     async def openai_list_models(self) -> OpenAIListModelsResponse:
+        """List all models (including post-training models) in OpenAI format."""
         models = await self.get_all_with_type("model")
+        if self.post_training_models_table:
+            post_training_models = await self.post_training_models_table.get_all_with_type("model")
+            # Skip post-training models whose identifiers already exist
+            existing_ids = {model.identifier for model in models}
+            models.extend([model for model in post_training_models if model.identifier not in existing_ids])
         openai_models = [
             OpenAIModel(
                 id=model.identifier,
@@ -36,7 +60,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         return OpenAIListModelsResponse(data=openai_models)

     async def get_model(self, model_id: str) -> Model:
+        """Get a model by ID, falling back to the post-training table."""
         model = await self.get_object_by_identifier("model", model_id)
+        if model is None and self.post_training_models_table:
+            model = await self.post_training_models_table.get_object_by_identifier("model", model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         return model
@@ -49,6 +76,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         metadata: dict[str, Any] | None = None,
         model_type: ModelType | None = None,
     ) -> Model:
+        """Register an inference model with the routing table."""
         if provider_model_id is None:
             provider_model_id = model_id
         if provider_id is None:
@@ -65,6 +93,25 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             model_type = ModelType.llm
         if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
             raise ValueError("Embedding model must have an embedding dimension in its metadata")
+
+        # Check whether the provider lives in this table or the post-training table
+        if provider_id not in self.impls_by_provider_id:
+            if self.post_training_models_table and provider_id in self.post_training_models_table.impls_by_provider_id:
+                # The provider is registered with the post-training table; delegate to it
+                return await self.post_training_models_table.register_model(
+                    model_id=model_id,
+                    provider_model_id=provider_model_id,
+                    provider_id=provider_id,
+                    metadata=metadata,
+                    model_type=model_type,
+                )
+            # Report all available providers from both tables
+            available_providers = list(self.impls_by_provider_id.keys())
+            if self.post_training_models_table:
+                available_providers.extend(self.post_training_models_table.impls_by_provider_id.keys())
+            raise ValueError(f"Provider `{provider_id}` not found. Available providers: {available_providers}")
+
         model = ModelWithACL(
             identifier=model_id,
             provider_resource_id=provider_model_id,
@@ -76,7 +123,14 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         return registered_model

     async def unregister_model(self, model_id: str) -> None:
-        existing_model = await self.get_model(model_id)
-        if existing_model is None:
-            raise ValueError(f"Model {model_id} not found")
-        await self.unregister_object(existing_model)
+        """Unregister a model, checking the post-training table on a miss."""
+        try:
+            existing_model = await self.get_model(model_id)
+        except ValueError:
+            if self.post_training_models_table:
+                return await self.post_training_models_table.unregister_model(model_id)
+            raise
+        await self.unregister_object(existing_model)
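The duplicate-suppressing merge in `list_models` can be exercised in isolation; this sketch uses plain dicts in place of `Model` objects:

```python
def merge_models(inference_models: list[dict], post_training_models: list[dict]) -> list[dict]:
    # Mirrors list_models above: inference models win on identifier clashes.
    merged = list(inference_models)
    existing_ids = {m["identifier"] for m in merged}
    merged.extend(m for m in post_training_models if m["identifier"] not in existing_ids)
    return merged


inference = [{"identifier": "llama-3", "source": "inference"}]
post_training = [
    {"identifier": "llama-3", "source": "post_training"},  # duplicate, dropped
    {"identifier": "granite-3.3", "source": "post_training"},
]
assert [m["identifier"] for m in merge_models(inference, post_training)] == ["llama-3", "granite-3.3"]
```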
diff --git a/llama_stack/distribution/routing_tables/post_training_models.py b/llama_stack/distribution/routing_tables/post_training_models.py
new file mode 100644
index 000000000..b34aa7f10
--- /dev/null
+++ b/llama_stack/distribution/routing_tables/post_training_models.py
@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import time
+from typing import Any
+
+from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
+from llama_stack.distribution.datatypes import ModelWithACL
+from llama_stack.distribution.store import DistributionRegistry
+from llama_stack.log import get_logger
+
+from .common import CommonRoutingTableImpl
+
+logger = get_logger(name=__name__, category="core")
+
+
+class PostTrainingModelsRoutingTable(CommonRoutingTableImpl, Models):
+    """Routing table for post-training models."""
+
+    def __init__(
+        self,
+        impls_by_provider_id: dict[str, Any],
+        dist_registry: DistributionRegistry,
+    ) -> None:
+        super().__init__(impls_by_provider_id, dist_registry)
+
+    async def initialize(self) -> None:
+        await super().initialize()
+
+    async def list_models(self) -> ListModelsResponse:
+        """List all post-training models."""
+        models = await self.get_all_with_type("model")
+        return ListModelsResponse(data=models)
+
+    async def openai_list_models(self) -> OpenAIListModelsResponse:
+        """List all post-training models in OpenAI format."""
+        models = await self.get_all_with_type("model")
+        openai_models = [
+            OpenAIModel(
+                id=model.identifier,
+                object="model",
+                created=int(time.time()),
+                owned_by="llama_stack",
+            )
+            for model in models
+        ]
+        return OpenAIListModelsResponse(data=openai_models)
+
+    async def get_model(self, model_id: str) -> Model:
+        """Get a post-training model by ID."""
+        model = await self.get_object_by_identifier("model", model_id)
+        if model is None:
+            raise ValueError(f"Post-training model '{model_id}' not found")
+        return model
+
+    async def register_model(
+        self,
+        model_id: str,
+        provider_model_id: str | None = None,
+        provider_id: str | None = None,
+        metadata: dict[str, Any] | None = None,
+        model_type: ModelType | None = None,
+    ) -> Model:
+        """Register a post-training model with the routing table."""
+        if provider_model_id is None:
+            provider_model_id = model_id
+        if provider_id is None:
+            # If no provider_id is given, fall back to the only available provider
+            if len(self.impls_by_provider_id) == 1:
+                provider_id = next(iter(self.impls_by_provider_id))
+            else:
+                raise ValueError(
+                    "No provider specified and multiple providers available. "
+                    f"Please specify a provider_id. Available providers: {list(self.impls_by_provider_id)}"
+                )
+        if metadata is None:
+            metadata = {}
+        if model_type is None:
+            model_type = ModelType.llm
+        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
+            raise ValueError("Embedding model must have an embedding dimension in its metadata")
+        model = ModelWithACL(
+            identifier=model_id,
+            provider_resource_id=provider_model_id,
+            provider_id=provider_id,
+            metadata=metadata,
+            model_type=model_type,
+        )
+        registered_model = await self.register_object(model)
+        return registered_model
+
+    async def unregister_model(self, model_id: str) -> None:
+        """Unregister a post-training model from the routing table."""
+        existing_model = await self.get_model(model_id)
+        await self.unregister_object(existing_model)
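The provider-defaulting rule in `register_model` above, extracted into a small self-contained function for illustration (not part of the diff):

```python
def resolve_provider_id(provider_id: str | None, impls_by_provider_id: dict) -> str:
    # Mirrors register_model above: fall back to the only provider, else fail.
    if provider_id is not None:
        return provider_id
    if len(impls_by_provider_id) == 1:
        return next(iter(impls_by_provider_id))
    raise ValueError(
        "No provider specified and multiple providers available. "
        f"Please specify a provider_id. Available providers: {list(impls_by_provider_id)}"
    )


assert resolve_provider_id(None, {"huggingface": object()}) == "huggingface"
assert resolve_provider_id("torchtune", {"huggingface": 1, "torchtune": 2}) == "torchtune"
```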
diff --git a/llama_stack/providers/inline/post_training/huggingface/models.py b/llama_stack/providers/inline/post_training/huggingface/models.py
new file mode 100644
index 000000000..effe75814
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/huggingface/models.py
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.models.models import ModelType
+from llama_stack.providers.utils.inference.model_registry import (
+    ProviderModelEntry,
+)
+
+model_entries = [
+    ProviderModelEntry(
+        provider_model_id="ibm-granite/granite-3.3-8b-instruct",
+        aliases=["ibm-granite/granite-3.3-8b-instruct"],
+        model_type=ModelType.llm,
+    ),
+]
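A sketch of how the provider consumes these entries, assuming `ModelRegistryHelper.get_provider_model_id` resolves aliases the same way the provider code below uses it:

```python
from llama_stack.providers.inline.post_training.huggingface.models import model_entries
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

# The helper maps an alias (or the provider model id itself) back to the
# provider model id; this is what register_model relies on below.
helper = ModelRegistryHelper(model_entries)
print(helper.get_provider_model_id("ibm-granite/granite-3.3-8b-instruct"))
# -> "ibm-granite/granite-3.3-8b-instruct"
```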
diff --git a/llama_stack/providers/inline/post_training/huggingface/post_training.py b/llama_stack/providers/inline/post_training/huggingface/post_training.py
index 0b2760792..fa7e4215d 100644
--- a/llama_stack/providers/inline/post_training/huggingface/post_training.py
+++ b/llama_stack/providers/inline/post_training/huggingface/post_training.py
@@ -8,27 +8,35 @@ from typing import Any

 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.models import Model
 from llama_stack.apis.post_training import (
     AlgorithmConfig,
     Checkpoint,
     DPOAlignmentConfig,
     JobStatus,
     ListPostTrainingJobsResponse,
+    PostTraining,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
     TrainingConfig,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.inline.post_training.huggingface.config import (
     HuggingFacePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
     HFFinetuningSingleDevice,
 )
+from llama_stack.providers.utils.inference.model_registry import (
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack.schema_utils import webmethod

+from .models import model_entries
+

 class TrainingArtifactType(Enum):
     CHECKPOINT = "checkpoint"
@@ -37,14 +45,17 @@ class TrainingArtifactType(Enum):

 _JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"

+logger = get_logger(name=__name__, category="post_training")

-class HuggingFacePostTrainingImpl:
+
+class HuggingFacePostTrainingImpl(PostTraining):
     def __init__(
         self,
         config: HuggingFacePostTrainingConfig,
         datasetio_api: DatasetIO,
         datasets: Datasets,
     ) -> None:
+        self.register_helper = ModelRegistryHelper(model_entries)
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets
@@ -80,6 +91,10 @@ class HuggingFacePostTrainingImpl:
         checkpoint_dir: str | None = None,
         algorithm_config: AlgorithmConfig | None = None,
     ) -> PostTrainingJob:
+        model_obj = await self._get_model(model)
+        if model_obj.provider_resource_id is None:
+            raise ValueError(f"Model {model_obj.identifier} has no provider_resource_id set")
+
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
             on_log_message_cb("Starting HF finetuning")

@@ -90,7 +105,7 @@ class HuggingFacePostTrainingImpl:
             )

             resources_allocated, checkpoints = await recipe.train(
-                model=model,
+                model=model_obj.identifier,
                 output_dir=checkpoint_dir,
                 job_uuid=job_uuid,
                 lora_config=algorithm_config,
@@ -110,6 +125,30 @@ class HuggingFacePostTrainingImpl:
         job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)

+    async def register_model(self, model: Model) -> Model:
+        try:
+            # Check the provider's static list of known models
+            model = await self.register_helper.register_model(model)
+        except ValueError:
+            # If the model is NOT in the list, it is probably still usable,
+            # but warn the user.
+            logger.warning(
+                f"Model {model.identifier} is not in the model registry for this provider; there might be unexpected issues."
+            )
+        if model.provider_resource_id is None:
+            raise ValueError("Model provider_resource_id cannot be None")
+        provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
+        if provider_resource_id is None:
+            provider_resource_id = model.provider_resource_id
+        model.provider_resource_id = provider_resource_id
+
+        return model
+
+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def preference_optimize(
         self,
         job_uuid: str,
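For context, a hypothetical end-to-end flow from the client side, assuming a `LlamaStackClient` whose methods mirror the server routes in this diff; exact client signatures may differ:

```python
# Hypothetical client usage; provider_id "huggingface" and the client method
# names are assumptions, not confirmed by this diff.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Register the model against the post-training provider...
client.models.register(
    model_id="ibm-granite/granite-3.3-8b-instruct",
    provider_id="huggingface",  # assumption: the inline HF provider id
)

# ...then kick off fine-tuning; the router resolves the provider by model.
job = client.post_training.supervised_fine_tune(
    job_uuid="job-1",
    model="ibm-granite/granite-3.3-8b-instruct",
    training_config={},  # placeholder; supply a real TrainingConfig
    hyperparam_search_config={},
    logger_config={},
)
print(job.job_uuid)
```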