mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
Merge 71caa271ad
into 76dcf47320
This commit is contained in:
commit
4a7bdf1b87
11 changed files with 393 additions and 23 deletions
|
@ -27,6 +27,7 @@ class Api(Enum):
|
||||||
telemetry = "telemetry"
|
telemetry = "telemetry"
|
||||||
|
|
||||||
models = "models"
|
models = "models"
|
||||||
|
post_training_models = "post_training_models"
|
||||||
shields = "shields"
|
shields = "shields"
|
||||||
vector_dbs = "vector_dbs"
|
vector_dbs = "vector_dbs"
|
||||||
datasets = "datasets"
|
datasets = "datasets"
|
||||||
|
|
|
@ -13,6 +13,7 @@ from pydantic import BaseModel, Field
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.common.job_types import JobStatus
|
from llama_stack.apis.common.job_types import JobStatus
|
||||||
from llama_stack.apis.common.training_types import Checkpoint
|
from llama_stack.apis.common.training_types import Checkpoint
|
||||||
|
from llama_stack.apis.models import Model
|
||||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
|
|
||||||
|
@ -168,7 +169,13 @@ class PostTrainingJobArtifactsResponse(BaseModel):
|
||||||
# TODO(ashwin): metrics, evals
|
# TODO(ashwin): metrics, evals
|
||||||
|
|
||||||
|
|
||||||
|
class ModelStore(Protocol):
    """Structural type for objects that can resolve a model by identifier."""

    async def get_model(self, identifier: str) -> Model:
        """Look up and return the ``Model`` registered under ``identifier``."""
        ...
|
||||||
|
|
||||||
|
|
||||||
class PostTraining(Protocol):
|
class PostTraining(Protocol):
|
||||||
|
model_store: ModelStore | None = None
|
||||||
|
|
||||||
@webmethod(route="/post-training/supervised-fine-tune", method="POST")
|
@webmethod(route="/post-training/supervised-fine-tune", method="POST")
|
||||||
async def supervised_fine_tune(
|
async def supervised_fine_tune(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -39,6 +39,10 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
|
||||||
routing_table_api=Api.models,
|
routing_table_api=Api.models,
|
||||||
router_api=Api.inference,
|
router_api=Api.inference,
|
||||||
),
|
),
|
||||||
|
AutoRoutedApiInfo(
|
||||||
|
routing_table_api=Api.post_training_models,
|
||||||
|
router_api=Api.post_training,
|
||||||
|
),
|
||||||
AutoRoutedApiInfo(
|
AutoRoutedApiInfo(
|
||||||
routing_table_api=Api.shields,
|
routing_table_api=Api.shields,
|
||||||
router_api=Api.safety,
|
router_api=Api.safety,
|
||||||
|
|
|
@ -67,6 +67,7 @@ def api_protocol_map() -> dict[Api, Any]:
|
||||||
Api.vector_io: VectorIO,
|
Api.vector_io: VectorIO,
|
||||||
Api.vector_dbs: VectorDBs,
|
Api.vector_dbs: VectorDBs,
|
||||||
Api.models: Models,
|
Api.models: Models,
|
||||||
|
Api.post_training_models: Models,
|
||||||
Api.safety: Safety,
|
Api.safety: Safety,
|
||||||
Api.shields: Shields,
|
Api.shields: Shields,
|
||||||
Api.telemetry: Telemetry,
|
Api.telemetry: Telemetry,
|
||||||
|
@ -93,6 +94,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
|
||||||
def additional_protocols_map() -> dict[Api, Any]:
|
def additional_protocols_map() -> dict[Api, Any]:
|
||||||
return {
|
return {
|
||||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||||
|
Api.post_training: (ModelsProtocolPrivate, Models, Api.post_training_models),
|
||||||
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
|
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
|
||||||
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
|
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
|
||||||
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
||||||
|
@ -251,6 +253,8 @@ async def instantiate_providers(
|
||||||
"""Instantiates providers asynchronously while managing dependencies."""
|
"""Instantiates providers asynchronously while managing dependencies."""
|
||||||
impls: dict[Api, Any] = {}
|
impls: dict[Api, Any] = {}
|
||||||
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||||
|
|
||||||
|
# First pass: instantiate all providers
|
||||||
for api_str, provider in sorted_providers:
|
for api_str, provider in sorted_providers:
|
||||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||||
for a in provider.spec.optional_api_dependencies:
|
for a in provider.spec.optional_api_dependencies:
|
||||||
|
@ -269,6 +273,10 @@ async def instantiate_providers(
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
impls[api] = impl
|
impls[api] = impl
|
||||||
|
|
||||||
|
# Second pass: connect routing tables
|
||||||
|
if Api.models in impls and Api.post_training_models in impls:
|
||||||
|
impls[Api.models].post_training_models_table = impls[Api.post_training_models]
|
||||||
|
|
||||||
return impls
|
return impls
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,8 @@ async def get_routing_table_impl(
|
||||||
) -> Any:
|
) -> Any:
|
||||||
from ..routing_tables.benchmarks import BenchmarksRoutingTable
|
from ..routing_tables.benchmarks import BenchmarksRoutingTable
|
||||||
from ..routing_tables.datasets import DatasetsRoutingTable
|
from ..routing_tables.datasets import DatasetsRoutingTable
|
||||||
from ..routing_tables.models import ModelsRoutingTable
|
from ..routing_tables.models import InferenceModelsRoutingTable
|
||||||
|
from ..routing_tables.post_training_models import PostTrainingModelsRoutingTable
|
||||||
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
|
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
|
||||||
from ..routing_tables.shields import ShieldsRoutingTable
|
from ..routing_tables.shields import ShieldsRoutingTable
|
||||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||||
|
@ -29,7 +30,8 @@ async def get_routing_table_impl(
|
||||||
|
|
||||||
api_to_tables = {
|
api_to_tables = {
|
||||||
"vector_dbs": VectorDBsRoutingTable,
|
"vector_dbs": VectorDBsRoutingTable,
|
||||||
"models": ModelsRoutingTable,
|
"models": InferenceModelsRoutingTable,
|
||||||
|
"post_training_models": PostTrainingModelsRoutingTable,
|
||||||
"shields": ShieldsRoutingTable,
|
"shields": ShieldsRoutingTable,
|
||||||
"datasets": DatasetsRoutingTable,
|
"datasets": DatasetsRoutingTable,
|
||||||
"scoring_functions": ScoringFunctionsRoutingTable,
|
"scoring_functions": ScoringFunctionsRoutingTable,
|
||||||
|
@ -40,7 +42,12 @@ async def get_routing_table_impl(
|
||||||
if api.value not in api_to_tables:
|
if api.value not in api_to_tables:
|
||||||
raise ValueError(f"API {api.value} not found in router map")
|
raise ValueError(f"API {api.value} not found in router map")
|
||||||
|
|
||||||
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
|
# For post-training API, we want to use the post-training models routing table
|
||||||
|
if api == Api.post_training:
|
||||||
|
impl = PostTrainingModelsRoutingTable(impls_by_provider_id, dist_registry)
|
||||||
|
else:
|
||||||
|
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
|
||||||
|
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
@ -51,6 +58,7 @@ async def get_auto_router_impl(
|
||||||
from .datasets import DatasetIORouter
|
from .datasets import DatasetIORouter
|
||||||
from .eval_scoring import EvalRouter, ScoringRouter
|
from .eval_scoring import EvalRouter, ScoringRouter
|
||||||
from .inference import InferenceRouter
|
from .inference import InferenceRouter
|
||||||
|
from .post_training import PostTrainingRouter
|
||||||
from .safety import SafetyRouter
|
from .safety import SafetyRouter
|
||||||
from .tool_runtime import ToolRuntimeRouter
|
from .tool_runtime import ToolRuntimeRouter
|
||||||
from .vector_io import VectorIORouter
|
from .vector_io import VectorIORouter
|
||||||
|
@ -63,6 +71,7 @@ async def get_auto_router_impl(
|
||||||
"scoring": ScoringRouter,
|
"scoring": ScoringRouter,
|
||||||
"eval": EvalRouter,
|
"eval": EvalRouter,
|
||||||
"tool_runtime": ToolRuntimeRouter,
|
"tool_runtime": ToolRuntimeRouter,
|
||||||
|
"post_training": PostTrainingRouter,
|
||||||
}
|
}
|
||||||
api_to_deps = {
|
api_to_deps = {
|
||||||
"inference": {"telemetry": Api.telemetry},
|
"inference": {"telemetry": Api.telemetry},
|
||||||
|
|
101
llama_stack/distribution/routers/post_training.py
Normal file
101
llama_stack/distribution/routers/post_training.py
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.apis.models import Model
|
||||||
|
from llama_stack.apis.post_training import (
|
||||||
|
AlgorithmConfig,
|
||||||
|
DPOAlignmentConfig,
|
||||||
|
ListPostTrainingJobsResponse,
|
||||||
|
PostTraining,
|
||||||
|
PostTrainingJob,
|
||||||
|
PostTrainingJobArtifactsResponse,
|
||||||
|
PostTrainingJobStatusResponse,
|
||||||
|
TrainingConfig,
|
||||||
|
)
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.providers.datatypes import RoutingTable
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
|
class PostTrainingRouter(PostTraining):
    """Routes post-training requests to a provider based on the model.

    The router keeps a routing table and, for ``supervised_fine_tune``, asks it
    for the provider implementation associated with the requested model.
    """

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        """Store the routing table used to resolve providers.

        Args:
            routing_table: Table queried via ``get_provider_impl(model)``.
        """
        logger.debug("Initializing PostTrainingRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        """No-op; the router has no start-up work of its own."""
        pass

    async def supervised_fine_tune(
        self,
        job_uuid: str,
        training_config: TrainingConfig,
        hyperparam_search_config: dict[str, Any],
        logger_config: dict[str, Any],
        model: str,
        checkpoint_dir: str | None = None,
        algorithm_config: AlgorithmConfig | None = None,
    ) -> PostTrainingJob:
        """Forward a supervised fine-tuning request to the model's provider.

        The provider is resolved from the routing table using ``model`` as the
        routing key; every argument is passed through unchanged.

        Returns:
            The ``PostTrainingJob`` produced by the provider.
        """
        provider = self.routing_table.get_provider_impl(model)
        # The provider method is a coroutine function: it must be awaited here,
        # otherwise callers get a coroutine object instead of the job.
        return await provider.supervised_fine_tune(
            job_uuid=job_uuid,
            training_config=training_config,
            hyperparam_search_config=hyperparam_search_config,
            logger_config=logger_config,
            model=model,
            checkpoint_dir=checkpoint_dir,
            algorithm_config=algorithm_config,
        )

    async def register_model(self, model: Model) -> Model:
        """Normalise ``model`` against the registry helper's static model list.

        A model unknown to the helper is tolerated (a warning is logged); its
        ``provider_resource_id`` is then resolved through the helper, falling
        back to the value already present on the model.

        Raises:
            ValueError: If ``model.provider_resource_id`` is ``None``.
        """
        # NOTE(review): ``self.register_helper`` is never assigned in this class,
        # so this method raises AttributeError unless something attaches the
        # helper first -- confirm the intended wiring.
        try:
            # get static list of models
            model = await self.register_helper.register_model(model)
        except ValueError:
            # If the model is NOT in the list it is probably ok, but warn the user.
            logger.warning(
                f"Model {model.identifier} is not in the model registry for this provider, there might be unexpected issues."
            )

        if model.provider_resource_id is None:
            raise ValueError("Model provider_resource_id cannot be None")

        resolved_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
        if resolved_id is not None:
            model.provider_resource_id = resolved_id

        return model

    async def preference_optimize(
        self,
        job_uuid: str,
        finetuned_model: str,
        algorithm_config: DPOAlignmentConfig,
        training_config: TrainingConfig,
        hyperparam_search_config: dict[str, Any],
        logger_config: dict[str, Any],
    ) -> PostTrainingJob:
        """Placeholder: not routed yet; currently does nothing and returns ``None``."""
        pass

    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
        """Placeholder: not routed yet; currently does nothing and returns ``None``."""
        pass

    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
        """Placeholder: not routed yet; currently does nothing and returns ``None``."""
        pass

    async def cancel_training_job(self, job_uuid: str) -> None:
        """Placeholder: not routed yet; currently does nothing."""
        pass

    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
        """Placeholder: not routed yet; currently does nothing and returns ``None``."""
        pass
|
|
@ -33,7 +33,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
||||||
|
|
||||||
assert obj.provider_id != "remote", "Remote provider should not be registered"
|
assert obj.provider_id != "remote", "Remote provider should not be registered"
|
||||||
|
|
||||||
if api == Api.inference:
|
if api == Api.inference or api == Api.post_training:
|
||||||
return await p.register_model(obj)
|
return await p.register_model(obj)
|
||||||
elif api == Api.safety:
|
elif api == Api.safety:
|
||||||
return await p.register_shield(obj)
|
return await p.register_shield(obj)
|
||||||
|
@ -55,7 +55,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
api = get_impl_api(p)
|
api = get_impl_api(p)
|
||||||
if api == Api.vector_io:
|
if api == Api.vector_io:
|
||||||
return await p.unregister_vector_db(obj.identifier)
|
return await p.unregister_vector_db(obj.identifier)
|
||||||
elif api == Api.inference:
|
elif api == Api.inference or api == Api.post_training:
|
||||||
return await p.unregister_model(obj.identifier)
|
return await p.unregister_model(obj.identifier)
|
||||||
elif api == Api.datasetio:
|
elif api == Api.datasetio:
|
||||||
return await p.unregister_dataset(obj.identifier)
|
return await p.unregister_dataset(obj.identifier)
|
||||||
|
@ -89,11 +89,18 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
obj = cls(**model_data)
|
obj = cls(**model_data)
|
||||||
await self.dist_registry.register(obj)
|
await self.dist_registry.register(obj)
|
||||||
|
|
||||||
|
# Import routing table classes here to avoid circular imports
|
||||||
|
from .models import InferenceModelsRoutingTable
|
||||||
|
from .post_training_models import PostTrainingModelsRoutingTable
|
||||||
|
|
||||||
# Register all objects from providers
|
# Register all objects from providers
|
||||||
for pid, p in self.impls_by_provider_id.items():
|
for pid, p in self.impls_by_provider_id.items():
|
||||||
api = get_impl_api(p)
|
api = get_impl_api(p)
|
||||||
if api == Api.inference:
|
if api == Api.inference or api == Api.post_training:
|
||||||
p.model_store = self
|
# For models, we need to handle both inference and post-training providers
|
||||||
|
if isinstance(self, InferenceModelsRoutingTable | PostTrainingModelsRoutingTable):
|
||||||
|
# Set the model store for both types of providers
|
||||||
|
p.model_store = self
|
||||||
elif api == Api.safety:
|
elif api == Api.safety:
|
||||||
p.shield_store = self
|
p.shield_store = self
|
||||||
elif api == Api.vector_io:
|
elif api == Api.vector_io:
|
||||||
|
@ -116,15 +123,16 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||||
from .benchmarks import BenchmarksRoutingTable
|
from .benchmarks import BenchmarksRoutingTable
|
||||||
from .datasets import DatasetsRoutingTable
|
from .datasets import DatasetsRoutingTable
|
||||||
from .models import ModelsRoutingTable
|
from .models import InferenceModelsRoutingTable
|
||||||
|
from .post_training_models import PostTrainingModelsRoutingTable
|
||||||
from .scoring_functions import ScoringFunctionsRoutingTable
|
from .scoring_functions import ScoringFunctionsRoutingTable
|
||||||
from .shields import ShieldsRoutingTable
|
from .shields import ShieldsRoutingTable
|
||||||
from .toolgroups import ToolGroupsRoutingTable
|
from .toolgroups import ToolGroupsRoutingTable
|
||||||
from .vector_dbs import VectorDBsRoutingTable
|
from .vector_dbs import VectorDBsRoutingTable
|
||||||
|
|
||||||
def apiname_object():
|
def apiname_object():
|
||||||
if isinstance(self, ModelsRoutingTable):
|
if isinstance(self, InferenceModelsRoutingTable | PostTrainingModelsRoutingTable):
|
||||||
return ("Inference", "model")
|
return ("Models", "model")
|
||||||
elif isinstance(self, ShieldsRoutingTable):
|
elif isinstance(self, ShieldsRoutingTable):
|
||||||
return ("Safety", "shield")
|
return ("Safety", "shield")
|
||||||
elif isinstance(self, VectorDBsRoutingTable):
|
elif isinstance(self, VectorDBsRoutingTable):
|
||||||
|
@ -155,7 +163,25 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
)
|
)
|
||||||
|
|
||||||
if not provider_id or provider_id == obj.provider_id:
|
if not provider_id or provider_id == obj.provider_id:
|
||||||
return self.impls_by_provider_id[obj.provider_id]
|
provider = self.impls_by_provider_id[obj.provider_id]
|
||||||
|
# Check if the provider supports the requested API
|
||||||
|
if not hasattr(provider, "__provider_spec__"):
|
||||||
|
return provider
|
||||||
|
api = provider.__provider_spec__.api
|
||||||
|
|
||||||
|
# Only check API compatibility for model routing tables
|
||||||
|
if isinstance(self, InferenceModelsRoutingTable | PostTrainingModelsRoutingTable):
|
||||||
|
if api not in [Api.inference, Api.post_training]:
|
||||||
|
raise ValueError(f"Provider {obj.provider_id} does not support the requested API")
|
||||||
|
# If we have both inference and post-training providers, prefer inference for model registration
|
||||||
|
if api == Api.post_training and Api.inference in [
|
||||||
|
p.__provider_spec__.api for p in self.impls_by_provider_id.values()
|
||||||
|
]:
|
||||||
|
# Try to find an inference provider first
|
||||||
|
for _, p in self.impls_by_provider_id.items():
|
||||||
|
if hasattr(p, "__provider_spec__") and p.__provider_spec__.api == Api.inference:
|
||||||
|
return p
|
||||||
|
return provider
|
||||||
|
|
||||||
raise ValueError(f"Provider not found for `{routing_key}`")
|
raise ValueError(f"Provider not found for `{routing_key}`")
|
||||||
|
|
||||||
|
@ -198,7 +224,6 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
if obj.type == ResourceType.model.value:
|
if obj.type == ResourceType.model.value:
|
||||||
await self.dist_registry.register(registered_obj)
|
await self.dist_registry.register(registered_obj)
|
||||||
return registered_obj
|
return registered_obj
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await self.dist_registry.register(obj)
|
await self.dist_registry.register(obj)
|
||||||
return obj
|
return obj
|
||||||
|
|
|
@ -8,9 +8,8 @@ import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import ModelWithACL
|
||||||
ModelWithACL,
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
)
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl
|
||||||
|
@ -18,12 +17,37 @@ from .common import CommonRoutingTableImpl
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
class InferenceModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
|
"""Routing table for inference models."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
impls_by_provider_id: dict[str, Any],
|
||||||
|
dist_registry: DistributionRegistry,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(impls_by_provider_id, dist_registry)
|
||||||
|
self.post_training_models_table = None
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
await super().initialize()
|
||||||
|
|
||||||
async def list_models(self) -> ListModelsResponse:
|
async def list_models(self) -> ListModelsResponse:
|
||||||
return ListModelsResponse(data=await self.get_all_with_type("model"))
|
"""List all inference models."""
|
||||||
|
models = await self.get_all_with_type("model")
|
||||||
|
if self.post_training_models_table:
|
||||||
|
post_training_models = await self.post_training_models_table.get_all_with_type("model")
|
||||||
|
# Create a set of existing model identifiers to avoid duplicates
|
||||||
|
existing_ids = {model.identifier for model in models}
|
||||||
|
# Only add models that don't already exist
|
||||||
|
models.extend([model for model in post_training_models if model.identifier not in existing_ids])
|
||||||
|
return ListModelsResponse(data=models)
|
||||||
|
|
||||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||||
|
"""List all inference models in OpenAI format."""
|
||||||
models = await self.get_all_with_type("model")
|
models = await self.get_all_with_type("model")
|
||||||
|
if self.post_training_models_table:
|
||||||
|
post_training_models = await self.post_training_models_table.get_all_with_type("model")
|
||||||
|
models.extend(post_training_models)
|
||||||
openai_models = [
|
openai_models = [
|
||||||
OpenAIModel(
|
OpenAIModel(
|
||||||
id=model.identifier,
|
id=model.identifier,
|
||||||
|
@ -36,7 +60,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
return OpenAIListModelsResponse(data=openai_models)
|
return OpenAIListModelsResponse(data=openai_models)
|
||||||
|
|
||||||
async def get_model(self, model_id: str) -> Model:
|
async def get_model(self, model_id: str) -> Model:
|
||||||
|
"""Get an inference model by ID."""
|
||||||
model = await self.get_object_by_identifier("model", model_id)
|
model = await self.get_object_by_identifier("model", model_id)
|
||||||
|
if model is None and self.post_training_models_table:
|
||||||
|
model = await self.post_training_models_table.get_object_by_identifier("model", model_id)
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError(f"Model '{model_id}' not found")
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
return model
|
return model
|
||||||
|
@ -49,6 +76,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
model_type: ModelType | None = None,
|
model_type: ModelType | None = None,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
|
"""Register an inference model with the routing table."""
|
||||||
if provider_model_id is None:
|
if provider_model_id is None:
|
||||||
provider_model_id = model_id
|
provider_model_id = model_id
|
||||||
if provider_id is None:
|
if provider_id is None:
|
||||||
|
@ -65,6 +93,25 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
model_type = ModelType.llm
|
model_type = ModelType.llm
|
||||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||||
|
|
||||||
|
# Check if the provider exists in either routing table
|
||||||
|
if provider_id not in self.impls_by_provider_id:
|
||||||
|
if self.post_training_models_table and provider_id in self.post_training_models_table.impls_by_provider_id:
|
||||||
|
# If provider exists in post-training table, use that instead
|
||||||
|
return await self.post_training_models_table.register_model(
|
||||||
|
model_id=model_id,
|
||||||
|
provider_model_id=provider_model_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
metadata=metadata,
|
||||||
|
model_type=model_type,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Get all available providers from both tables
|
||||||
|
available_providers = list(self.impls_by_provider_id.keys())
|
||||||
|
if self.post_training_models_table:
|
||||||
|
available_providers.extend(self.post_training_models_table.impls_by_provider_id.keys())
|
||||||
|
raise ValueError(f"Provider `{provider_id}` not found. Available providers: {available_providers}")
|
||||||
|
|
||||||
model = ModelWithACL(
|
model = ModelWithACL(
|
||||||
identifier=model_id,
|
identifier=model_id,
|
||||||
provider_resource_id=provider_model_id,
|
provider_resource_id=provider_model_id,
|
||||||
|
@ -76,7 +123,14 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
return registered_model
|
return registered_model
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None:
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
existing_model = await self.get_model(model_id)
|
"""Unregister an inference model from the routing table."""
|
||||||
if existing_model is None:
|
try:
|
||||||
raise ValueError(f"Model {model_id} not found")
|
existing_model = await self.get_model(model_id)
|
||||||
await self.unregister_object(existing_model)
|
if existing_model is None:
|
||||||
|
raise ValueError(f"Model {model_id} not found")
|
||||||
|
await self.unregister_object(existing_model)
|
||||||
|
except ValueError:
|
||||||
|
if self.post_training_models_table:
|
||||||
|
await self.post_training_models_table.unregister_model(model_id)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||||
|
from llama_stack.distribution.datatypes import ModelWithACL
|
||||||
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
from .common import CommonRoutingTableImpl
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
|
class PostTrainingModelsRoutingTable(CommonRoutingTableImpl, Models):
    """Routing table for post-training models."""

    def __init__(
        self,
        impls_by_provider_id: dict[str, Any],
        dist_registry: DistributionRegistry,
    ) -> None:
        """Pass the provider map and distribution registry to the base table."""
        super().__init__(impls_by_provider_id, dist_registry)

    async def list_models(self) -> ListModelsResponse:
        """List all post-training models."""
        return ListModelsResponse(data=await self.get_all_with_type("model"))

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        """List all post-training models in OpenAI format."""
        models = await self.get_all_with_type("model")
        # Read the clock once so every entry in one listing shares a timestamp.
        created_at = int(time.time())
        return OpenAIListModelsResponse(
            data=[
                OpenAIModel(
                    id=model.identifier,
                    object="model",
                    created=created_at,
                    owned_by="llama_stack",
                )
                for model in models
            ]
        )

    async def get_model(self, model_id: str) -> Model:
        """Get a post-training model by ID.

        Raises:
            ValueError: If no model with ``model_id`` is registered.
        """
        model = await self.get_object_by_identifier("model", model_id)
        if model is None:
            raise ValueError(f"Post-training model '{model_id}' not found")
        return model

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        """Register a post-training model with the routing table.

        Defaults: ``provider_model_id`` falls back to ``model_id``,
        ``metadata`` to an empty dict and ``model_type`` to ``ModelType.llm``.
        When ``provider_id`` is omitted, the sole configured provider is used.

        Raises:
            ValueError: If ``provider_id`` is omitted and there is not exactly
                one provider, or if an embedding model lacks
                ``embedding_dimension`` in its metadata.
        """
        if provider_model_id is None:
            provider_model_id = model_id

        if provider_id is None:
            # Without an explicit provider we can only pick one unambiguously
            # when exactly one provider is configured.
            if len(self.impls_by_provider_id) != 1:
                # Render the ids as a plain list, not a ``dict_keys(...)`` view.
                available_providers = list(self.impls_by_provider_id)
                raise ValueError(
                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {available_providers}"
                )
            provider_id = next(iter(self.impls_by_provider_id))

        if metadata is None:
            metadata = {}
        if model_type is None:
            model_type = ModelType.llm
        if model_type == ModelType.embedding and "embedding_dimension" not in metadata:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")

        model = ModelWithACL(
            identifier=model_id,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
        )
        return await self.register_object(model)

    async def unregister_model(self, model_id: str) -> None:
        """Unregister a post-training model from the routing table.

        Raises:
            ValueError: If the model is not registered (raised by ``get_model``).
        """
        # ``get_model`` raises for unknown ids, so no extra ``None`` check is needed.
        existing_model = await self.get_model(model_id)
        await self.unregister_object(existing_model)
|
|
@ -0,0 +1,23 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from llama_stack.apis.models.models import ModelType
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
|
ProviderModelEntry,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Static list of known models: each entry maps a provider model id to its
# aliases and model type. The original list repeated this exact entry twice;
# the verbatim duplicate is dropped.
model_entries = [
    ProviderModelEntry(
        provider_model_id="ibm-granite/granite-3.3-8b-instruct",
        aliases=["ibm-granite/granite-3.3-8b-instruct"],
        model_type=ModelType.llm,
    ),
]
|
|
@ -8,27 +8,35 @@ from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Datasets
|
from llama_stack.apis.datasets import Datasets
|
||||||
|
from llama_stack.apis.models import Model
|
||||||
from llama_stack.apis.post_training import (
|
from llama_stack.apis.post_training import (
|
||||||
AlgorithmConfig,
|
AlgorithmConfig,
|
||||||
Checkpoint,
|
Checkpoint,
|
||||||
DPOAlignmentConfig,
|
DPOAlignmentConfig,
|
||||||
JobStatus,
|
JobStatus,
|
||||||
ListPostTrainingJobsResponse,
|
ListPostTrainingJobsResponse,
|
||||||
|
PostTraining,
|
||||||
PostTrainingJob,
|
PostTrainingJob,
|
||||||
PostTrainingJobArtifactsResponse,
|
PostTrainingJobArtifactsResponse,
|
||||||
PostTrainingJobStatusResponse,
|
PostTrainingJobStatusResponse,
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
)
|
)
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.inline.post_training.huggingface.config import (
|
from llama_stack.providers.inline.post_training.huggingface.config import (
|
||||||
HuggingFacePostTrainingConfig,
|
HuggingFacePostTrainingConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
|
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
|
||||||
HFFinetuningSingleDevice,
|
HFFinetuningSingleDevice,
|
||||||
)
|
)
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
|
ModelRegistryHelper,
|
||||||
|
)
|
||||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||||
from llama_stack.schema_utils import webmethod
|
from llama_stack.schema_utils import webmethod
|
||||||
|
|
||||||
|
from .models import model_entries
|
||||||
|
|
||||||
|
|
||||||
class TrainingArtifactType(Enum):
|
class TrainingArtifactType(Enum):
|
||||||
CHECKPOINT = "checkpoint"
|
CHECKPOINT = "checkpoint"
|
||||||
|
@ -37,14 +45,17 @@ class TrainingArtifactType(Enum):
|
||||||
|
|
||||||
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
|
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="post_training")
|
||||||
|
|
||||||
class HuggingFacePostTrainingImpl:
|
|
||||||
|
class HuggingFacePostTrainingImpl(PostTraining):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: HuggingFacePostTrainingConfig,
|
config: HuggingFacePostTrainingConfig,
|
||||||
datasetio_api: DatasetIO,
|
datasetio_api: DatasetIO,
|
||||||
datasets: Datasets,
|
datasets: Datasets,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self.register_helper = ModelRegistryHelper(model_entries)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.datasetio_api = datasetio_api
|
self.datasetio_api = datasetio_api
|
||||||
self.datasets_api = datasets
|
self.datasets_api = datasets
|
||||||
|
@ -80,6 +91,10 @@ class HuggingFacePostTrainingImpl:
|
||||||
checkpoint_dir: str | None = None,
|
checkpoint_dir: str | None = None,
|
||||||
algorithm_config: AlgorithmConfig | None = None,
|
algorithm_config: AlgorithmConfig | None = None,
|
||||||
) -> PostTrainingJob:
|
) -> PostTrainingJob:
|
||||||
|
model = await self._get_model(model)
|
||||||
|
if model.provider_resource_id is None:
|
||||||
|
raise ValueError(f"Model {model} has no provider_resource_id set")
|
||||||
|
|
||||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||||
on_log_message_cb("Starting HF finetuning")
|
on_log_message_cb("Starting HF finetuning")
|
||||||
|
|
||||||
|
@ -90,7 +105,7 @@ class HuggingFacePostTrainingImpl:
|
||||||
)
|
)
|
||||||
|
|
||||||
resources_allocated, checkpoints = await recipe.train(
|
resources_allocated, checkpoints = await recipe.train(
|
||||||
model=model,
|
model=model.identifier,
|
||||||
output_dir=checkpoint_dir,
|
output_dir=checkpoint_dir,
|
||||||
job_uuid=job_uuid,
|
job_uuid=job_uuid,
|
||||||
lora_config=algorithm_config,
|
lora_config=algorithm_config,
|
||||||
|
@ -110,6 +125,30 @@ class HuggingFacePostTrainingImpl:
|
||||||
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
|
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
|
||||||
return PostTrainingJob(job_uuid=job_uuid)
|
return PostTrainingJob(job_uuid=job_uuid)
|
||||||
|
|
||||||
|
async def register_model(self, model: Model) -> Model:
    """Register ``model`` with this provider and normalise its provider id.

    The model is first offered to ``self.register_helper``. If the helper
    rejects it with ``ValueError`` a warning is logged and registration
    continues with the model as passed in.

    NOTE(review): the diff this was recovered from lost its indentation; the
    validation below is assumed to run after the try/except, not inside the
    ``except`` branch -- confirm against the repository.

    Args:
        model: The model being registered.

    Returns:
        The model, with ``provider_resource_id`` replaced by the helper's
        provider model id when the helper returns one.

    Raises:
        ValueError: If ``model.provider_resource_id`` is ``None``.
    """
    try:
        # Check the model against the provider's static model list.
        model = await self.register_helper.register_model(model)
    except ValueError:
        # A model outside the static list is probably fine, but warn the user.
        logger.warning(
            f"Model {model.identifier} is not in the model registry for this provider, there might be unexpected issues."
        )

    requested_id = model.provider_resource_id
    if requested_id is None:
        raise ValueError("Model provider_resource_id cannot be None")

    # Use the helper's provider model id when it has one; otherwise keep the
    # id the caller supplied.
    resolved_id = self.register_helper.get_provider_model_id(requested_id)
    model.provider_resource_id = requested_id if resolved_id is None else resolved_id
    return model
|
||||||
|
|
||||||
|
async def _get_model(self, model_id: str) -> Model:
    """Look up ``model_id`` in the model store attached to this provider.

    Args:
        model_id: Identifier passed through to ``model_store.get_model``.

    Returns:
        Whatever ``self.model_store.get_model`` returns for ``model_id``.

    Raises:
        ValueError: If ``self.model_store`` is unset (falsy).
    """
    store = self.model_store
    if store:
        return await store.get_model(model_id)
    raise ValueError("Model store not set")
|
||||||
|
|
||||||
async def preference_optimize(
|
async def preference_optimize(
|
||||||
self,
|
self,
|
||||||
job_uuid: str,
|
job_uuid: str,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue