# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import time
from typing import Any

from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
    ModelWithOwner,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    """Routing table for objects of type ``"model"``.

    Listing, lookup, registration and unregistration are delegated to the
    generic object helpers inherited from ``CommonRoutingTableImpl``
    (``get_all_with_type``, ``get_object_by_identifier``, ``register_object``,
    ``unregister_object``).
    """

    async def list_models(self) -> ListModelsResponse:
        """Return every registered object of type ``"model"``."""
        return ListModelsResponse(data=await self.get_all_with_type("model"))

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        """Return the registered models in the OpenAI list-models shape.

        Every entry reports ``object="model"`` and ``owned_by="llama_stack"``.
        ``created`` is not a real creation time: it is the time of this call,
        shared by all entries of one response.
        """
        models = await self.get_all_with_type("model")
        # Take the timestamp once so every entry in a single response agrees.
        created = int(time.time())
        openai_models = [
            OpenAIModel(
                id=model.identifier,
                object="model",
                created=created,
                owned_by="llama_stack",
            )
            for model in models
        ]
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
        """Look up a registered model by its identifier.

        Raises:
            ValueError: if no model is registered under ``model_id``.
        """
        model = await self.get_object_by_identifier("model", model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        return model

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        """Register a model and return the object produced by ``register_object``.

        Args:
            model_id: Identifier to register the model under.
            provider_model_id: The provider's own id for the model. Defaults
                to ``model_id``.
            provider_id: Provider that serves the model. May be omitted only
                when exactly one provider is configured.
            metadata: Arbitrary model metadata. Defaults to ``{}``. Embedding
                models must include an ``"embedding_dimension"`` key.
            model_type: Defaults to ``ModelType.llm``.

        Raises:
            ValueError: if ``provider_id`` is omitted and there is not exactly
                one provider, or if an embedding model's metadata lacks
                ``"embedding_dimension"``.
        """
        if provider_model_id is None:
            provider_model_id = model_id
        if provider_id is None:
            # Fall back to the sole configured provider. Note: no check is made
            # here that this provider actually supports the model.
            available_providers = list(self.impls_by_provider_id.keys())
            if len(available_providers) == 1:
                provider_id = available_providers[0]
            elif not available_providers:
                raise ValueError("No provider specified and no providers are available to serve the model.")
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id. "
                    f"Available providers: {available_providers}"
                )
        if metadata is None:
            metadata = {}
        if model_type is None:
            model_type = ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")
        model = ModelWithOwner(
            identifier=model_id,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
        )
        registered_model = await self.register_object(model)
        return registered_model

    async def unregister_model(self, model_id: str) -> None:
        """Unregister the model registered under ``model_id``.

        Raises:
            ValueError: propagated from ``get_model`` when ``model_id`` is not
                registered.
        """
        # get_model already raises ValueError for unknown ids, so the result
        # is never None here.
        existing_model = await self.get_model(model_id)
        await self.unregister_object(existing_model)

    async def update_registered_llm_models(
        self,
        provider_id: str,
        models: list[Model],
    ) -> None:
        """Replace the LLM models registered for ``provider_id`` with ``models``.

        Every registered model of type ``ModelType.llm`` owned by
        ``provider_id`` is unregistered. Then each ``ModelType.llm`` entry of
        ``models`` is registered for that provider. If a replaced model had the
        same ``provider_resource_id``, its identifier is reused, so aliases
        registered by the user (or during initialization from run.yaml)
        survive the refresh.

        Non-LLM models are neither removed nor added. Providers often do not
        report their metadata (e.g. the embedding dimension). The objects in
        ``models`` are not modified.
        """
        existing_models = await self.get_all_with_type("model")

        # provider_resource_id -> identifier of the model being replaced.
        model_ids: dict[str, str] = {}
        for model in existing_models:
            # Embedding models are left alone: providers often do not give us
            # their metadata (embedding dimension, etc.).
            if model.provider_id == provider_id and model.model_type == ModelType.llm:
                model_ids[model.provider_resource_id] = model.identifier
                logger.debug(f"unregistering model {model.identifier}")
                await self.unregister_object(model)

        for model in models:
            if model.model_type != ModelType.llm:
                continue
            # Resolve any preserved alias locally instead of mutating the
            # caller's Model object.
            identifier = model_ids.get(model.provider_resource_id, model.identifier)
            logger.debug(f"registering model {identifier} ({model.provider_resource_id})")
            await self.register_object(
                ModelWithOwner(
                    identifier=identifier,
                    provider_resource_id=model.provider_resource_id,
                    provider_id=provider_id,
                    metadata=model.metadata,
                    model_type=model.model_type,
                )
            )