llama-stack-mirror/llama_stack/distribution/routing_tables/models.py
Charlie Doern 71caa271ad feat: associated models API with post_training
there are likely scenarios where admins of a stack only want to allow clients to fine-tune certain models, register certain models to be fine-tuned, etc.
introduce the post_training router and post_training_models as the associated type. A different model type needs to be used for inference vs post_training due to the current structure of the router.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
2025-05-30 13:32:11 -04:00
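
A minimal sketch of the intended behavior, assuming a stack configured with one inference provider and one post-training provider (the provider IDs below are hypothetical, and the calls run inside an async context):

    # registering against an inference provider lands in this table
    await models_table.register_model(model_id="llama-3-8b", provider_id="ollama")
    # registering against a post-training provider falls through to the
    # post_training models table (see register_model below)
    await models_table.register_model(model_id="llama-3-8b-ft", provider_id="torchtune")
    # merged view across both tables, deduplicated by identifier
    listing = await models_table.list_models()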

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from typing import Any

from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import ModelWithACL
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")

class InferenceModelsRoutingTable(CommonRoutingTableImpl, Models):
    """Routing table for inference models."""

    def __init__(
        self,
        impls_by_provider_id: dict[str, Any],
        dist_registry: DistributionRegistry,
    ) -> None:
        super().__init__(impls_by_provider_id, dist_registry)
        # Optionally attached after construction; when set, listings, lookups,
        # and registration fall through to the post-training models table.
        self.post_training_models_table = None

    async def initialize(self) -> None:
        await super().initialize()
    async def list_models(self) -> ListModelsResponse:
        """List all inference models."""
        models = await self.get_all_with_type("model")
        if self.post_training_models_table:
            post_training_models = await self.post_training_models_table.get_all_with_type("model")
            # Create a set of existing model identifiers to avoid duplicates
            existing_ids = {model.identifier for model in models}
            # Only add models that don't already exist
            models.extend([model for model in post_training_models if model.identifier not in existing_ids])
        return ListModelsResponse(data=models)
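
    # Note: the same identifier can appear in both tables (e.g. when a model is
    # registered for both inference and fine-tuning), so merged listings dedupe
    # on `identifier` and prefer the inference-side entry.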
    async def openai_list_models(self) -> OpenAIListModelsResponse:
        """List all inference models in OpenAI format."""
        models = await self.get_all_with_type("model")
        if self.post_training_models_table:
            post_training_models = await self.post_training_models_table.get_all_with_type("model")
            # Mirror list_models: skip post-training entries whose identifier is
            # already present, so the OpenAI listing contains no duplicates
            existing_ids = {model.identifier for model in models}
            models.extend([model for model in post_training_models if model.identifier not in existing_ids])
        openai_models = [
            OpenAIModel(
                id=model.identifier,
                object="model",
                created=int(time.time()),
                owned_by="llama_stack",
            )
            for model in models
        ]
        return OpenAIListModelsResponse(data=openai_models)
    async def get_model(self, model_id: str) -> Model:
        """Get an inference model by ID."""
        model = await self.get_object_by_identifier("model", model_id)
        if model is None and self.post_training_models_table:
            model = await self.post_training_models_table.get_object_by_identifier("model", model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        return model
    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        """Register an inference model with the routing table."""
        if provider_model_id is None:
            provider_model_id = model_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this model
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {list(self.impls_by_provider_id.keys())}"
                )
        if metadata is None:
            metadata = {}
        if model_type is None:
            model_type = ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")
        # Check if the provider exists in either routing table
        if provider_id not in self.impls_by_provider_id:
            if self.post_training_models_table and provider_id in self.post_training_models_table.impls_by_provider_id:
                # If provider exists in post-training table, use that instead
                return await self.post_training_models_table.register_model(
                    model_id=model_id,
                    provider_model_id=provider_model_id,
                    provider_id=provider_id,
                    metadata=metadata,
                    model_type=model_type,
                )
            else:
                # Get all available providers from both tables
                available_providers = list(self.impls_by_provider_id.keys())
                if self.post_training_models_table:
                    available_providers.extend(self.post_training_models_table.impls_by_provider_id.keys())
                raise ValueError(f"Provider `{provider_id}` not found. Available providers: {available_providers}")
        model = ModelWithACL(
            identifier=model_id,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
        )
        registered_model = await self.register_object(model)
        return registered_model
    async def unregister_model(self, model_id: str) -> None:
        """Unregister an inference model from the routing table."""
        try:
            existing_model = await self.get_model(model_id)
            if existing_model is None:
                raise ValueError(f"Model {model_id} not found")
            await self.unregister_object(existing_model)
        except ValueError:
            if self.post_training_models_table:
                await self.post_training_models_table.unregister_model(model_id)
            else:
                raise
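
# A minimal wiring sketch (hypothetical; in practice the stack's startup code
# constructs both tables and connects them, e.g. inside an async entry point):
#
#   table = InferenceModelsRoutingTable(impls_by_provider_id, dist_registry)
#   await table.initialize()
#   table.post_training_models_table = post_training_table  # enables fallthrough
#   await table.register_model(model_id="my-model", provider_id="some-provider")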