mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-05 13:40:30 +00:00
feat: associated models API with post_training
there are likely scenarios where admins of a stack only want to allow clients to fine-tune certain models, register certain models to be fine-tuned. etc introduce the post_training router and post_training_models as the associated type. A different model type needs to be used for inference vs post_training due to the structure of the router currently. Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
63a9f08c9e
commit
71caa271ad
11 changed files with 393 additions and 23 deletions
|
@ -33,7 +33,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
|
||||
assert obj.provider_id != "remote", "Remote provider should not be registered"
|
||||
|
||||
if api == Api.inference:
|
||||
if api == Api.inference or api == Api.post_training:
|
||||
return await p.register_model(obj)
|
||||
elif api == Api.safety:
|
||||
return await p.register_shield(obj)
|
||||
|
@ -55,7 +55,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
|||
api = get_impl_api(p)
|
||||
if api == Api.vector_io:
|
||||
return await p.unregister_vector_db(obj.identifier)
|
||||
elif api == Api.inference:
|
||||
elif api == Api.inference or api == Api.post_training:
|
||||
return await p.unregister_model(obj.identifier)
|
||||
elif api == Api.datasetio:
|
||||
return await p.unregister_dataset(obj.identifier)
|
||||
|
@ -89,11 +89,18 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
obj = cls(**model_data)
|
||||
await self.dist_registry.register(obj)
|
||||
|
||||
# Import routing table classes here to avoid circular imports
|
||||
from .models import InferenceModelsRoutingTable
|
||||
from .post_training_models import PostTrainingModelsRoutingTable
|
||||
|
||||
# Register all objects from providers
|
||||
for pid, p in self.impls_by_provider_id.items():
|
||||
api = get_impl_api(p)
|
||||
if api == Api.inference:
|
||||
p.model_store = self
|
||||
if api == Api.inference or api == Api.post_training:
|
||||
# For models, we need to handle both inference and post-training providers
|
||||
if isinstance(self, InferenceModelsRoutingTable | PostTrainingModelsRoutingTable):
|
||||
# Set the model store for both types of providers
|
||||
p.model_store = self
|
||||
elif api == Api.safety:
|
||||
p.shield_store = self
|
||||
elif api == Api.vector_io:
|
||||
|
@ -116,15 +123,16 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
from .benchmarks import BenchmarksRoutingTable
|
||||
from .datasets import DatasetsRoutingTable
|
||||
from .models import ModelsRoutingTable
|
||||
from .models import InferenceModelsRoutingTable
|
||||
from .post_training_models import PostTrainingModelsRoutingTable
|
||||
from .scoring_functions import ScoringFunctionsRoutingTable
|
||||
from .shields import ShieldsRoutingTable
|
||||
from .toolgroups import ToolGroupsRoutingTable
|
||||
from .vector_dbs import VectorDBsRoutingTable
|
||||
|
||||
def apiname_object():
|
||||
if isinstance(self, ModelsRoutingTable):
|
||||
return ("Inference", "model")
|
||||
if isinstance(self, InferenceModelsRoutingTable | PostTrainingModelsRoutingTable):
|
||||
return ("Models", "model")
|
||||
elif isinstance(self, ShieldsRoutingTable):
|
||||
return ("Safety", "shield")
|
||||
elif isinstance(self, VectorDBsRoutingTable):
|
||||
|
@ -155,7 +163,25 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
)
|
||||
|
||||
if not provider_id or provider_id == obj.provider_id:
|
||||
return self.impls_by_provider_id[obj.provider_id]
|
||||
provider = self.impls_by_provider_id[obj.provider_id]
|
||||
# Check if the provider supports the requested API
|
||||
if not hasattr(provider, "__provider_spec__"):
|
||||
return provider
|
||||
api = provider.__provider_spec__.api
|
||||
|
||||
# Only check API compatibility for model routing tables
|
||||
if isinstance(self, InferenceModelsRoutingTable | PostTrainingModelsRoutingTable):
|
||||
if api not in [Api.inference, Api.post_training]:
|
||||
raise ValueError(f"Provider {obj.provider_id} does not support the requested API")
|
||||
# If we have both inference and post-training providers, prefer inference for model registration
|
||||
if api == Api.post_training and Api.inference in [
|
||||
p.__provider_spec__.api for p in self.impls_by_provider_id.values()
|
||||
]:
|
||||
# Try to find an inference provider first
|
||||
for _, p in self.impls_by_provider_id.items():
|
||||
if hasattr(p, "__provider_spec__") and p.__provider_spec__.api == Api.inference:
|
||||
return p
|
||||
return provider
|
||||
|
||||
raise ValueError(f"Provider not found for `{routing_key}`")
|
||||
|
||||
|
@ -198,7 +224,6 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if obj.type == ResourceType.model.value:
|
||||
await self.dist_registry.register(registered_obj)
|
||||
return registered_obj
|
||||
|
||||
else:
|
||||
await self.dist_registry.register(obj)
|
||||
return obj
|
||||
|
|
|
@ -8,9 +8,8 @@ import time
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ModelWithACL,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import ModelWithACL
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
@ -18,12 +17,37 @@ from .common import CommonRoutingTableImpl
|
|||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
class InferenceModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
"""Routing table for inference models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
impls_by_provider_id: dict[str, Any],
|
||||
dist_registry: DistributionRegistry,
|
||||
) -> None:
|
||||
super().__init__(impls_by_provider_id, dist_registry)
|
||||
self.post_training_models_table = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
return ListModelsResponse(data=await self.get_all_with_type("model"))
|
||||
"""List all inference models."""
|
||||
models = await self.get_all_with_type("model")
|
||||
if self.post_training_models_table:
|
||||
post_training_models = await self.post_training_models_table.get_all_with_type("model")
|
||||
# Create a set of existing model identifiers to avoid duplicates
|
||||
existing_ids = {model.identifier for model in models}
|
||||
# Only add models that don't already exist
|
||||
models.extend([model for model in post_training_models if model.identifier not in existing_ids])
|
||||
return ListModelsResponse(data=models)
|
||||
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||
"""List all inference models in OpenAI format."""
|
||||
models = await self.get_all_with_type("model")
|
||||
if self.post_training_models_table:
|
||||
post_training_models = await self.post_training_models_table.get_all_with_type("model")
|
||||
models.extend(post_training_models)
|
||||
openai_models = [
|
||||
OpenAIModel(
|
||||
id=model.identifier,
|
||||
|
@ -36,7 +60,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
return OpenAIListModelsResponse(data=openai_models)
|
||||
|
||||
async def get_model(self, model_id: str) -> Model:
|
||||
"""Get an inference model by ID."""
|
||||
model = await self.get_object_by_identifier("model", model_id)
|
||||
if model is None and self.post_training_models_table:
|
||||
model = await self.post_training_models_table.get_object_by_identifier("model", model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
return model
|
||||
|
@ -49,6 +76,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> Model:
|
||||
"""Register an inference model with the routing table."""
|
||||
if provider_model_id is None:
|
||||
provider_model_id = model_id
|
||||
if provider_id is None:
|
||||
|
@ -65,6 +93,25 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
model_type = ModelType.llm
|
||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
|
||||
# Check if the provider exists in either routing table
|
||||
if provider_id not in self.impls_by_provider_id:
|
||||
if self.post_training_models_table and provider_id in self.post_training_models_table.impls_by_provider_id:
|
||||
# If provider exists in post-training table, use that instead
|
||||
return await self.post_training_models_table.register_model(
|
||||
model_id=model_id,
|
||||
provider_model_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
metadata=metadata,
|
||||
model_type=model_type,
|
||||
)
|
||||
else:
|
||||
# Get all available providers from both tables
|
||||
available_providers = list(self.impls_by_provider_id.keys())
|
||||
if self.post_training_models_table:
|
||||
available_providers.extend(self.post_training_models_table.impls_by_provider_id.keys())
|
||||
raise ValueError(f"Provider `{provider_id}` not found. Available providers: {available_providers}")
|
||||
|
||||
model = ModelWithACL(
|
||||
identifier=model_id,
|
||||
provider_resource_id=provider_model_id,
|
||||
|
@ -76,7 +123,14 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
return registered_model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
existing_model = await self.get_model(model_id)
|
||||
if existing_model is None:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
await self.unregister_object(existing_model)
|
||||
"""Unregister an inference model from the routing table."""
|
||||
try:
|
||||
existing_model = await self.get_model(model_id)
|
||||
if existing_model is None:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
await self.unregister_object(existing_model)
|
||||
except ValueError:
|
||||
if self.post_training_models_table:
|
||||
await self.post_training_models_table.unregister_model(model_id)
|
||||
else:
|
||||
raise
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||
from llama_stack.distribution.datatypes import ModelWithACL
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class PostTrainingModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
"""Routing table for post-training models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
impls_by_provider_id: dict[str, Any],
|
||||
dist_registry: DistributionRegistry,
|
||||
) -> None:
|
||||
super().__init__(impls_by_provider_id, dist_registry)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
"""List all post-training models."""
|
||||
models = await self.get_all_with_type("model")
|
||||
return ListModelsResponse(data=models)
|
||||
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||
"""List all post-training models in OpenAI format."""
|
||||
models = await self.get_all_with_type("model")
|
||||
openai_models = [
|
||||
OpenAIModel(
|
||||
id=model.identifier,
|
||||
object="model",
|
||||
created=int(time.time()),
|
||||
owned_by="llama_stack",
|
||||
)
|
||||
for model in models
|
||||
]
|
||||
return OpenAIListModelsResponse(data=openai_models)
|
||||
|
||||
async def get_model(self, model_id: str) -> Model:
|
||||
"""Get a post-training model by ID."""
|
||||
model = await self.get_object_by_identifier("model", model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Post-training model '{model_id}' not found")
|
||||
return model
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
provider_model_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> Model:
|
||||
"""Register a post-training model with the routing table."""
|
||||
if provider_model_id is None:
|
||||
provider_model_id = model_id
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this model
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
|
||||
)
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
if model_type is None:
|
||||
model_type = ModelType.llm
|
||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
model = ModelWithACL(
|
||||
identifier=model_id,
|
||||
provider_resource_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
metadata=metadata,
|
||||
model_type=model_type,
|
||||
)
|
||||
registered_model = await self.register_object(model)
|
||||
return registered_model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
"""Unregister a post-training model from the routing table."""
|
||||
existing_model = await self.get_model(model_id)
|
||||
if existing_model is None:
|
||||
raise ValueError(f"Post-training model {model_id} not found")
|
||||
await self.unregister_object(existing_model)
|
Loading…
Add table
Add a link
Reference in a new issue