# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import time
from typing import Any

from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import ModelWithACL
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class InferenceModelsRoutingTable(CommonRoutingTableImpl, Models):
    """Routing table for inference models.

    Models registered with this table's own providers are stored in this table.
    An optional secondary table, ``post_training_models_table``, is consulted as
    a fallback for listing, lookup, registration and unregistration.
    """

    def __init__(
        self,
        impls_by_provider_id: dict[str, Any],
        dist_registry: DistributionRegistry,
    ) -> None:
        super().__init__(impls_by_provider_id, dist_registry)
        # Optional fallback table. It is never assigned elsewhere in this class;
        # presumably the caller attaches it after construction (verify). It must
        # expose get_all_with_type, get_object_by_identifier, register_model,
        # unregister_model and impls_by_provider_id, all of which are used below.
        self.post_training_models_table: Any = None

    async def initialize(self) -> None:
        await super().initialize()

    async def _list_all_models(self) -> list[Model]:
        """Return this table's models plus post-training models not already present.

        A model in this table shadows a post-training model with the same
        identifier, so each identifier from this table appears only once.
        """
        models = await self.get_all_with_type("model")
        if self.post_training_models_table:
            post_training_models = await self.post_training_models_table.get_all_with_type("model")
            existing_ids = {model.identifier for model in models}
            models.extend(model for model in post_training_models if model.identifier not in existing_ids)
        return models

    async def list_models(self) -> ListModelsResponse:
        """List all inference models."""
        return ListModelsResponse(data=await self._list_all_models())

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        """List all inference models in OpenAI format.

        Uses the same de-duplicated model set as ``list_models``. Every entry's
        ``created`` field is the time of this call, shared by all entries.
        """
        models = await self._list_all_models()
        # Compute once so all entries in a single response carry the same timestamp.
        created = int(time.time())
        openai_models = [
            OpenAIModel(
                id=model.identifier,
                object="model",
                created=created,
                owned_by="llama_stack",
            )
            for model in models
        ]
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
        """Get an inference model by ID, falling back to the post-training table.

        Raises:
            ValueError: if neither table has a model with this identifier.
        """
        model = await self.get_object_by_identifier("model", model_id)
        if model is None and self.post_training_models_table:
            model = await self.post_training_models_table.get_object_by_identifier("model", model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        return model

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        """Register an inference model with the routing table.

        Args:
            model_id: Identifier the model is registered under.
            provider_model_id: Provider-side resource id; defaults to ``model_id``.
            provider_id: Provider to register with. May be omitted only when this
                table has exactly one provider.
            metadata: Extra model metadata; embedding models must include
                ``embedding_dimension``. Defaults to an empty dict.
            model_type: Defaults to ``ModelType.llm``.

        Returns:
            The registered model. If ``provider_id`` is unknown to this table
            but known to the post-training table, registration is delegated there.

        Raises:
            ValueError: if no provider was given and this table does not have
                exactly one; if an embedding model lacks ``embedding_dimension``;
                or if ``provider_id`` is unknown to both tables.
        """
        if provider_model_id is None:
            provider_model_id = model_id
        if provider_id is None:
            # Without an explicit provider, only an unambiguous choice is accepted.
            if len(self.impls_by_provider_id) == 1:
                provider_id = next(iter(self.impls_by_provider_id))
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id. "
                    f"Available providers: {list(self.impls_by_provider_id.keys())}"
                )
        if metadata is None:
            metadata = {}
        if model_type is None:
            model_type = ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")

        if provider_id not in self.impls_by_provider_id:
            # The provider may belong to the post-training table; delegate there if so.
            if self.post_training_models_table and provider_id in self.post_training_models_table.impls_by_provider_id:
                return await self.post_training_models_table.register_model(
                    model_id=model_id,
                    provider_model_id=provider_model_id,
                    provider_id=provider_id,
                    metadata=metadata,
                    model_type=model_type,
                )
            available_providers = list(self.impls_by_provider_id.keys())
            if self.post_training_models_table:
                available_providers.extend(self.post_training_models_table.impls_by_provider_id.keys())
            raise ValueError(f"Provider `{provider_id}` not found. Available providers: {available_providers}")

        model = ModelWithACL(
            identifier=model_id,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
        )
        return await self.register_object(model)

    async def unregister_model(self, model_id: str) -> None:
        """Unregister an inference model.

        A model held by this table is unregistered here. Otherwise the request is
        delegated to the post-training table when one is attached.

        Raises:
            ValueError: if the model is not in this table and no post-training
                table is attached. When delegating, whatever the post-training
                table's ``unregister_model`` raises propagates unchanged.
        """
        # Look only in this table. get_model() would also return post-training
        # models, whose provider is presumably not one of this table's providers,
        # and they would wrongly be handed to this table's unregister_object.
        existing_model = await self.get_object_by_identifier("model", model_id)
        if existing_model is not None:
            # Errors from unregistering are propagated, not rerouted to the
            # post-training table.
            await self.unregister_object(existing_model)
            return
        if self.post_training_models_table:
            await self.post_training_models_table.unregister_model(model_id)
            return
        raise ValueError(f"Model '{model_id}' not found")