mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
181 lines
6.8 KiB
Python
181 lines
6.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import time
|
|
from typing import Any
|
|
|
|
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
|
from llama_stack.distribution.datatypes import (
|
|
ModelWithOwner,
|
|
RegistryEntrySource,
|
|
)
|
|
from llama_stack.log import get_logger
|
|
|
|
from .common import CommonRoutingTableImpl, lookup_model
|
|
|
|
# Module-level logger shared by everything in this file; registered under the "core" category.
logger = get_logger(name=__name__, category="core")
|
|
|
|
|
|
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    """Routing table for ``model`` resources.

    Besides the register / unregister / lookup API, ``initialize`` starts a
    background task that periodically asks each provider for its models and
    syncs the answer into the registry (see ``_refresh_models``).
    """

    # Provider ids whose ``list_models()`` call has succeeded at least once.
    # The class-level value is only a declaration/default: ``initialize`` binds a
    # fresh set on the instance so this state is never shared between tables.
    listed_providers: set[str] = set()
    # Seconds to sleep between two refresh passes over all providers.
    model_refresh_interval_seconds: int = 300
    # Handle of the background refresh task; None while no task is running.
    _refresh_task: asyncio.Task | None = None

    async def initialize(self) -> None:
        """Initialize the base routing table and start the background model refresh task."""
        await super().initialize()

        # Rebind per instance. Mutating the class-level set in place would leak the
        # "already listed" state into every other ModelsRoutingTable in the process.
        self.listed_providers = set()

        refresh_task = asyncio.create_task(self._refresh_models())
        self._refresh_task = refresh_task

        def cb(task):
            # Done-callback: report how the refresh task ended.
            import traceback

            # Check cancellation first: Task.exception() raises CancelledError on a cancelled task.
            if task.cancelled():
                logger.error("Model refresh task cancelled")
                return

            exc = task.exception()
            if exc is not None:
                logger.error(f"Model refresh task failed: {exc}")
                traceback.print_exception(exc)
            else:
                logger.debug("Model refresh task completed")

        refresh_task.add_done_callback(cb)

    async def shutdown(self) -> None:
        """Shut down the base routing table, then cancel the refresh task if one is running."""
        await super().shutdown()
        if self._refresh_task:
            self._refresh_task.cancel()
            self._refresh_task = None

    async def _refresh_models(self) -> None:
        """Loop forever, syncing every provider's model list into the registry.

        On each pass a provider is queried when it asks for a refresh
        (``should_refresh_models()``) or when it has not been listed successfully
        yet. A failing ``list_models()`` is logged and the provider is skipped for
        this pass. The loop then sleeps ``model_refresh_interval_seconds``.
        """
        while True:
            for provider_id, provider in self.impls_by_provider_id.items():
                wants_refresh = await provider.should_refresh_models()
                # Every provider must be listed at least once, even if it never asks
                # for a refresh; after that, only providers that opt in are re-listed.
                if not wants_refresh and provider_id in self.listed_providers:
                    continue

                try:
                    models = await provider.list_models()
                except Exception as e:
                    logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
                    continue

                # The call succeeded, so the provider counts as listed even when it
                # returned None (nothing to sync).
                self.listed_providers.add(provider_id)
                if models is None:
                    continue

                await self.update_registered_models(provider_id, models)

            await asyncio.sleep(self.model_refresh_interval_seconds)

    async def list_models(self) -> ListModelsResponse:
        """Return every registered object of type ``"model"``."""
        return ListModelsResponse(data=await self.get_all_with_type("model"))

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        """Return every registered model in the OpenAI list-models shape.

        ``created`` is filled with the time of this call and ``owned_by`` is
        always ``"llama_stack"``.
        """
        models = await self.get_all_with_type("model")
        openai_models = [
            OpenAIModel(
                id=model.identifier,
                object="model",
                created=int(time.time()),
                owned_by="llama_stack",
            )
            for model in models
        ]
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
        """Look up a model by id via ``lookup_model``."""
        return await lookup_model(self, model_id)

    async def get_provider_impl(self, model_id: str) -> Any:
        """Return the provider implementation registered for the model's ``provider_id``."""
        model = await lookup_model(self, model_id)
        return self.impls_by_provider_id[model.provider_id]

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        """Register a model and return the registered object.

        :param model_id: Identifier requested by the caller.
        :param provider_model_id: Provider-side id; defaults to ``model_id``.
        :param provider_id: Owning provider; may be omitted only when exactly one provider is configured.
        :param metadata: Model metadata; defaults to an empty dict.
        :param model_type: Defaults to ``ModelType.llm``.
        :raises ValueError: If ``provider_id`` is omitted and there is not exactly one provider, or if
            an embedding model has no ``embedding_dimension`` in its metadata.
        """
        if provider_id is None:
            # Without an explicit provider we can only proceed when the choice is unambiguous.
            if len(self.impls_by_provider_id) != 1:
                raise ValueError(
                    f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n"
                    "Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
                )
            provider_id = next(iter(self.impls_by_provider_id))

        provider_model_id = provider_model_id or model_id
        metadata = metadata or {}
        model_type = model_type or ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")

        # an identifier different than provider_model_id implies it is an alias, so that
        # becomes the globally unique identifier. otherwise provider_model_ids can conflict,
        # so as a general rule we must use the provider_id to disambiguate.
        if model_id != provider_model_id:
            identifier = model_id
        else:
            identifier = f"{provider_id}/{provider_model_id}"

        model = ModelWithOwner(
            identifier=identifier,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
            source=RegistryEntrySource.default,
        )
        registered_model = await self.register_object(model)
        return registered_model

    async def unregister_model(self, model_id: str) -> None:
        """Unregister the model with the given id.

        :raises ValueError: If ``get_model`` returns ``None`` for ``model_id``.
        """
        existing_model = await self.get_model(model_id)
        if existing_model is None:
            raise ValueError(f"Model {model_id} not found")
        await self.unregister_object(existing_model)

    async def update_registered_models(
        self,
        provider_id: str,
        models: list[Model],
    ) -> None:
        """Replace the provider-sourced registry entries of ``provider_id`` with ``models``.

        Entries of this provider whose source is ``RegistryEntrySource.default`` are
        kept untouched; all its other entries are unregistered. Each model in
        ``models`` is then registered with source ``RegistryEntrySource.provider``,
        unless a kept entry already covers the same ``provider_resource_id``.

        :param provider_id: Provider whose entries are being synced.
        :param models: Models currently reported by that provider.
        """
        existing_models = await self.get_all_with_type("model")

        # we may have an alias for the model registered by the user (or during initialization
        # from run.yaml) that we need to keep track of
        kept_resource_ids: set[str] = set()
        for model in existing_models:
            if model.provider_id != provider_id:
                continue
            if model.source == RegistryEntrySource.default:
                kept_resource_ids.add(model.provider_resource_id)
                continue

            logger.debug(f"unregistering model {model.identifier}")
            await self.unregister_object(model)

        for model in models:
            if model.provider_resource_id in kept_resource_ids:
                # avoid overwriting a non-provider-registered model entry
                continue

            logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
            await self.register_object(
                ModelWithOwner(
                    identifier=model.identifier,
                    provider_resource_id=model.provider_resource_id,
                    provider_id=provider_id,
                    metadata=model.metadata,
                    model_type=model.model_type,
                    source=RegistryEntrySource.provider,
                )
            )
|