diff --git a/docs/source/providers/inference/remote_fireworks.md b/docs/source/providers/inference/remote_fireworks.md
index 351586c34..862860c29 100644
--- a/docs/source/providers/inference/remote_fireworks.md
+++ b/docs/source/providers/inference/remote_fireworks.md
@@ -8,6 +8,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key |
diff --git a/docs/source/providers/inference/remote_ollama.md b/docs/source/providers/inference/remote_ollama.md
index 23b8f87a2..f9f0a7622 100644
--- a/docs/source/providers/inference/remote_ollama.md
+++ b/docs/source/providers/inference/remote_ollama.md
@@ -9,8 +9,7 @@ Ollama inference provider for running local models through the Ollama runtime.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `url` | `` | No | http://localhost:11434 | |
-| `refresh_models` | `` | No | False | refresh and re-register models periodically |
-| `refresh_models_interval` | `` | No | 300 | interval in seconds to refresh models |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically |

 ## Sample Configuration
diff --git a/docs/source/providers/inference/remote_together.md b/docs/source/providers/inference/remote_together.md
index f33ff42f2..d1fe3e82b 100644
--- a/docs/source/providers/inference/remote_together.md
+++ b/docs/source/providers/inference/remote_together.md
@@ -8,6 +8,7 @@ Together AI inference provider for open-source models and collaborative AI devel
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key |
diff --git a/docs/source/providers/inference/remote_vllm.md b/docs/source/providers/inference/remote_vllm.md
index 5291199a4..172d35873 100644
--- a/docs/source/providers/inference/remote_vllm.md
+++ b/docs/source/providers/inference/remote_vllm.md
@@ -13,7 +13,6 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
 | `refresh_models` | `` | No | False | Whether to refresh models periodically |
-| `refresh_models_interval` | `` | No | 300 | Interval in seconds to refresh models |

 ## Sample Configuration
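Review note: a minimal sketch of the new `allowed_models` knob documented in the tables above. This is not part of the patch; the Together model id is illustrative, and the remaining fields are assumed to keep the documented defaults.

```python
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig

# Only models named here will be registered from the provider's listing;
# allowed_models=None (the default) registers everything the provider lists.
config = TogetherImplConfig(allowed_models=["meta-llama/Llama-3.3-70B-Instruct-Turbo"])
assert config.url == "https://api.together.xyz/v1"
```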
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index b2bb8a8e6..222099064 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -819,12 +819,6 @@ class OpenAIEmbeddingsResponse(BaseModel):
 class ModelStore(Protocol):
     async def get_model(self, identifier: str) -> Model: ...

-    async def update_registered_llm_models(
-        self,
-        provider_id: str,
-        models: list[Model],
-    ) -> None: ...
-

 class TextTruncation(Enum):
     """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index ead1331f3..90b269452 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -36,6 +36,11 @@ LLAMA_STACK_RUN_CONFIG_VERSION = 2
 RoutingKey = str | list[str]


+class RegistryEntrySource(StrEnum):
+    via_register_api = "via_register_api"
+    listed_from_provider = "listed_from_provider"
+
+
 class User(BaseModel):
     principal: str
     # further attributes that may be used for access control decisions
@@ -50,6 +55,7 @@ class ResourceWithOwner(Resource):
     resource. This can be used to constrain access to the resource."""

     owner: User | None = None
+    source: RegistryEntrySource = RegistryEntrySource.via_register_api


 # Use the extended Resource for all routable objects
diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index 6503c13b2..5044fd8c8 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -161,7 +161,13 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         if not self.skip_logger_removal:
             self._remove_root_logger_handlers()

-        return self.loop.run_until_complete(self.async_client.initialize())
+        # use a new event loop to avoid interfering with the main event loop
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            return loop.run_until_complete(self.async_client.initialize())
+        finally:
+            asyncio.set_event_loop(None)

     def _remove_root_logger_handlers(self):
         """
diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py
index 2f6ac90bb..caf0780fd 100644
--- a/llama_stack/distribution/routing_tables/common.py
+++ b/llama_stack/distribution/routing_tables/common.py
@@ -117,6 +117,9 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.impls_by_provider_id.values():
             await p.shutdown()

+    async def refresh(self) -> None:
+        pass
+
     async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
         from .benchmarks import BenchmarksRoutingTable
         from .datasets import DatasetsRoutingTable
@@ -206,7 +209,6 @@ class CommonRoutingTableImpl(RoutingTable):
         if obj.type == ResourceType.model.value:
             await self.dist_registry.register(registered_obj)
             return registered_obj
-        else:
-            await self.dist_registry.register(obj)
-            return obj
+        await self.dist_registry.register(obj)
+        return obj

diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py
index f2787b308..022c3dd40 100644
--- a/llama_stack/distribution/routing_tables/models.py
+++ b/llama_stack/distribution/routing_tables/models.py
@@ -10,6 +10,7 @@ from typing import Any
 from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
 from llama_stack.distribution.datatypes import (
     ModelWithOwner,
+    RegistryEntrySource,
 )
 from llama_stack.log import get_logger
@@ -19,6 +20,26 @@ logger = get_logger(name=__name__, category="core")

 class ModelsRoutingTable(CommonRoutingTableImpl, Models):
+    listed_providers: set[str] = set()
+
+    async def refresh(self) -> None:
+        for provider_id, provider in self.impls_by_provider_id.items():
+            refresh = await provider.should_refresh_models()
+            if not (refresh or provider_id in self.listed_providers):
+                continue
+
+            try:
+                models = await provider.list_models()
+            except Exception as e:
+                logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
+                continue
+
+            self.listed_providers.add(provider_id)
+            if models is None:
+                continue
+
+            await self.update_registered_models(provider_id, models)
+
     async def list_models(self) -> ListModelsResponse:
         return ListModelsResponse(data=await self.get_all_with_type("model"))
@@ -81,6 +102,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             provider_id=provider_id,
             metadata=metadata,
             model_type=model_type,
+            source=RegistryEntrySource.via_register_api,
         )
         registered_model = await self.register_object(model)
         return registered_model
@@ -91,7 +113,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             raise ValueError(f"Model {model_id} not found")
         await self.unregister_object(existing_model)

-    async def update_registered_llm_models(
+    async def update_registered_models(
         self,
         provider_id: str,
         models: list[Model],
@@ -102,18 +124,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         # from run.yaml) that we need to keep track of
         model_ids = {}
         for model in existing_models:
-            # we leave embeddings models alone because often we don't get metadata
-            # (embedding dimension, etc.) from the provider
-            if model.provider_id == provider_id and model.model_type == ModelType.llm:
+            if model.provider_id != provider_id:
+                continue
+            if model.source == RegistryEntrySource.via_register_api:
                 model_ids[model.provider_resource_id] = model.identifier
-                logger.debug(f"unregistering model {model.identifier}")
-                await self.unregister_object(model)
+                continue
+
+            logger.debug(f"unregistering model {model.identifier}")
+            await self.unregister_object(model)

         for model in models:
-            if model.model_type != ModelType.llm:
-                continue
             if model.provider_resource_id in model_ids:
-                model.identifier = model_ids[model.provider_resource_id]
+                # avoid overwriting a non-provider-registered model entry
+                continue

             logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
             await self.register_object(
@@ -123,5 +146,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
                     provider_id=provider_id,
                     metadata=model.metadata,
                     model_type=model.model_type,
+                    source=RegistryEntrySource.listed_from_provider,
                 )
             )
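Review note: the precedence rule that `update_registered_models` implements above, restated as a standalone sketch. This is a simplification (plain data, no registry I/O); user-registered entries always survive a provider refresh, while provider-listed entries are rebuilt from the new listing.

```python
from dataclasses import dataclass

from llama_stack.distribution.datatypes import RegistryEntrySource


@dataclass
class Entry:
    provider_resource_id: str
    source: RegistryEntrySource


def reconcile(existing: list[Entry], listed: list[Entry]) -> list[Entry]:
    # entries pinned via the register API win over provider-listed ones
    user_defined = [e for e in existing if e.source == RegistryEntrySource.via_register_api]
    taken = {e.provider_resource_id for e in user_defined}
    # everything else is replaced wholesale by the fresh provider listing
    return user_defined + [e for e in listed if e.provider_resource_id not in taken]
```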
diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py
index d7270156a..57bc4cd5f 100644
--- a/llama_stack/distribution/stack.py
+++ b/llama_stack/distribution/stack.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import asyncio
 import importlib.resources
 import os
 import re
@@ -38,6 +39,7 @@ from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
 from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
 from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
+from llama_stack.distribution.routing_tables.common import CommonRoutingTableImpl
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.log import get_logger
@@ -90,6 +92,9 @@ RESOURCES = [
 ]


+REGISTRY_REFRESH_INTERVAL_SECONDS = 300
+
+
 async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
     for rsrc, api, register_method, list_method in RESOURCES:
         objects = getattr(run_config, rsrc)
@@ -324,9 +329,33 @@ async def construct_stack(
     add_internal_implementations(impls, run_config)

     await register_resources(run_config, impls)
+
+    task = asyncio.create_task(refresh_registry(impls))
+
+    def cb(task):
+        import traceback
+
+        if task.cancelled():
+            logger.error("Model refresh task cancelled")
+        elif task.exception():
+            logger.error(f"Model refresh task failed: {task.exception()}")
+            traceback.print_exception(task.exception())
+        else:
+            logger.debug("Model refresh task completed")
+
+    task.add_done_callback(cb)
     return impls


+async def refresh_registry(impls: dict[Api, Any]):
+    routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
+    while True:
+        for routing_table in routing_tables:
+            await routing_table.refresh()
+
+        await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
+
+
 def get_stack_run_config_from_template(template: str) -> StackRunConfig:
     template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 424380324..005bfbab8 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -47,6 +47,17 @@ class ModelsProtocolPrivate(Protocol):

     async def unregister_model(self, model_id: str) -> None: ...

+    # the Stack periodically queries each provider for its list of models;
+    # if should_refresh_models() returns True, list_models() is called again
+    # on the global registry refresh interval to refresh the list of models
+    #
+    # NOTE: each model returned will be registered with the model registry. this means
+    # a callback to the `register_model()` method will be made. this is duplicative and
+    # may be removed in the future.
+    async def list_models(self) -> list[Model] | None: ...
+
+    async def should_refresh_models(self) -> bool: ...
+

 class ShieldsProtocolPrivate(Protocol):
     async def register_shield(self, shield: Shield) -> None: ...
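Review note: a minimal conforming implementation of the two new `ModelsProtocolPrivate` methods, to show the contract the providers below satisfy. The provider class, model id, and metadata here are hypothetical.

```python
from llama_stack.apis.models import Model, ModelType


class StaticCatalogProvider:
    __provider_id__ = "static-catalog"  # set by the resolver on real providers

    async def should_refresh_models(self) -> bool:
        # the catalog never changes, so one listing is enough;
        # return True to be re-polled on every registry refresh tick
        return False

    async def list_models(self) -> list[Model] | None:
        # returning None tells the routing table there is nothing to reconcile
        return [
            Model(
                identifier="example-llm",
                provider_resource_id="example-llm",
                provider_id=self.__provider_id__,
                metadata={},
                model_type=ModelType.llm,
            )
        ]
```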
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index e238e1b78..88d7a98ec 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl(
         if self.config.create_distributed_process_group:
             self.generator.stop()

+    async def should_refresh_models(self) -> bool:
+        return False
+
+    async def list_models(self) -> list[Model] | None:
+        return None
+
     async def unregister_model(self, model_id: str) -> None:
         pass

diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index 890c526f5..fea8a8189 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.apis.models import ModelType
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
@@ -41,6 +42,8 @@ class SentenceTransformersInferenceImpl(
     InferenceProvider,
     ModelsProtocolPrivate,
 ):
+    __provider_id__: str
+
     def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
         self.config = config

@@ -50,6 +53,22 @@ class SentenceTransformersInferenceImpl(
     async def shutdown(self) -> None:
         pass

+    async def should_refresh_models(self) -> bool:
+        return False
+
+    async def list_models(self) -> list[Model] | None:
+        return [
+            Model(
+                identifier="all-MiniLM-L6-v2",
+                provider_resource_id="all-MiniLM-L6-v2",
+                provider_id=self.__provider_id__,
+                metadata={
+                    "embedding_dimension": 384,
+                },
+                model_type=ModelType.embedding,
+            ),
+        ]
+
     async def register_model(self, model: Model) -> Model:
         return model

diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py
index 072d558f4..b23f2d31b 100644
--- a/llama_stack/providers/remote/inference/fireworks/config.py
+++ b/llama_stack/providers/remote/inference/fireworks/config.py
@@ -6,13 +6,14 @@

 from typing import Any

-from pydantic import BaseModel, Field, SecretStr
+from pydantic import Field, SecretStr

+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type


 @json_schema_type
-class FireworksImplConfig(BaseModel):
+class FireworksImplConfig(RemoteInferenceProviderConfig):
     url: str = Field(
         default="https://api.fireworks.ai/inference/v1",
         description="The URL for the Fireworks server",
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 1c82ff3a8..c76aa39f3 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -70,7 +70,7 @@ logger = get_logger(name=__name__, category="inference")

 class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: FireworksImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
+        ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
         self.config = config

     async def initialize(self) -> None:
diff --git a/llama_stack/providers/remote/inference/ollama/config.py b/llama_stack/providers/remote/inference/ollama/config.py
index ae261f47c..ce13f0d83 100644
--- a/llama_stack/providers/remote/inference/ollama/config.py
+++ b/llama_stack/providers/remote/inference/ollama/config.py
@@ -13,8 +13,10 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
 class OllamaImplConfig(BaseModel):
     url: str = DEFAULT_OLLAMA_URL
-    refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
-    refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")
+    refresh_models: bool = Field(
+        default=False,
+        description="Whether to refresh models periodically",
+    )

     @classmethod
     def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
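Review note: the Ollama adapter diff below caches one `AsyncClient` per event loop because the client binds itself to the loop that is running when it is constructed. The general shape of that pattern, as a hedged standalone sketch (the `make` factory is a stand-in for `AsyncClient(host=...)`):

```python
import asyncio
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class PerLoopCache(Generic[T]):
    """Hold one client per running event loop, mirroring the _clients dict below."""

    def __init__(self, make: Callable[[], T]) -> None:
        self._make = make
        self._cache: dict[asyncio.AbstractEventLoop, T] = {}

    def get(self) -> T:
        loop = asyncio.get_running_loop()  # raises if called outside a coroutine
        if loop not in self._cache:
            self._cache[loop] = self._make()
        return self._cache[loop]
```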
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 76d789d07..ba20185d3 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -98,14 +98,16 @@ class OllamaInferenceAdapter(
     def __init__(self, config: OllamaImplConfig) -> None:
         self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
         self.config = config
-        self._client = None
+        self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
         self._openai_client = None

     @property
     def client(self) -> AsyncClient:
-        if self._client is None:
-            self._client = AsyncClient(host=self.config.url)
-        return self._client
+        # ollama client attaches itself to the current event loop (sadly?)
+        loop = asyncio.get_running_loop()
+        if loop not in self._clients:
+            self._clients[loop] = AsyncClient(host=self.config.url)
+        return self._clients[loop]

     @property
     def openai_client(self) -> AsyncOpenAI:
@@ -121,59 +123,61 @@ class OllamaInferenceAdapter(
                 "Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
             )

-        if self.config.refresh_models:
-            logger.debug("ollama starting background model refresh task")
-            self._refresh_task = asyncio.create_task(self._refresh_models())
-
-            def cb(task):
-                if task.cancelled():
-                    import traceback
-
-                    logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
-                elif task.exception():
-                    logger.error(f"ollama background refresh task died: {task.exception()}")
-                else:
-                    logger.error("ollama background refresh task completed unexpectedly")
-
-            self._refresh_task.add_done_callback(cb)
-
-    async def _refresh_models(self) -> None:
-        # Wait for model store to be available (with timeout)
-        waited_time = 0
-        while not self.model_store and waited_time < 60:
-            await asyncio.sleep(1)
-            waited_time += 1
-
-        if not self.model_store:
-            raise ValueError("Model store not set after waiting 60 seconds")
+    async def should_refresh_models(self) -> bool:
+        return self.config.refresh_models

+    async def list_models(self) -> list[Model] | None:
         provider_id = self.__provider_id__
-        while True:
-            try:
-                response = await self.client.list()
-            except Exception as e:
-                logger.warning(f"Failed to list models: {str(e)}")
-                await asyncio.sleep(self.config.refresh_models_interval)
-                continue
-
-            models = []
-            for m in response.models:
-                model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
-                if model_type == ModelType.embedding:
-                    continue
-                models.append(
-                    Model(
-                        identifier=m.model,
-                        provider_resource_id=m.model,
-                        provider_id=provider_id,
-                        metadata={},
-                        model_type=model_type,
-                    )
-                )
-            await self.model_store.update_registered_llm_models(provider_id, models)
-            logger.debug(f"ollama refreshed model list ({len(models)} models)")
-
-            await asyncio.sleep(self.config.refresh_models_interval)
+        response = await self.client.list()
+
+        # always add the embedding models (plus an alias) which can be pulled on demand
+        models = [
+            Model(
+                identifier="all-minilm:l6-v2",
+                provider_resource_id="all-minilm:l6-v2",
+                provider_id=provider_id,
+                metadata={
+                    "embedding_dimension": 384,
+                    "context_length": 512,
+                },
+                model_type=ModelType.embedding,
+            ),
+            # add all-minilm alias
+            Model(
+                identifier="all-minilm",
+                provider_resource_id="all-minilm:l6-v2",
+                provider_id=provider_id,
+                metadata={
+                    "embedding_dimension": 384,
+                    "context_length": 512,
+                },
+                model_type=ModelType.embedding,
+            ),
+            Model(
+                identifier="nomic-embed-text",
+                provider_resource_id="nomic-embed-text",
+                provider_id=provider_id,
+                metadata={
+                    "embedding_dimension": 768,
+                    "context_length": 8192,
+                },
+                model_type=ModelType.embedding,
+            ),
+        ]
+        for m in response.models:
+            # kill embedding models since we don't know dimensions for them
+            if m.details.family in ["bert"]:
+                continue
+            models.append(
+                Model(
+                    identifier=m.model,
+                    provider_resource_id=m.model,
+                    provider_id=provider_id,
+                    metadata={},
+                    model_type=ModelType.llm,
+                )
+            )
+        return models

     async def health(self) -> HealthResponse:
         """
@@ -190,12 +194,7 @@ class OllamaInferenceAdapter(
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")

     async def shutdown(self) -> None:
-        if hasattr(self, "_refresh_task") and not self._refresh_task.done():
-            logger.debug("ollama cancelling background refresh task")
-            self._refresh_task.cancel()
-
-        self._client = None
-        self._openai_client = None
+        self._clients.clear()

     async def unregister_model(self, model_id: str) -> None:
         pass
diff --git a/llama_stack/providers/remote/inference/together/config.py b/llama_stack/providers/remote/inference/together/config.py
index f166e4277..211be7efe 100644
--- a/llama_stack/providers/remote/inference/together/config.py
+++ b/llama_stack/providers/remote/inference/together/config.py
@@ -6,13 +6,14 @@

 from typing import Any

-from pydantic import BaseModel, Field, SecretStr
+from pydantic import Field, SecretStr

+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type


 @json_schema_type
-class TogetherImplConfig(BaseModel):
+class TogetherImplConfig(RemoteInferenceProviderConfig):
     url: str = Field(
         default="https://api.together.xyz/v1",
         description="The URL for the Together AI server",
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index e1eb934c5..46094c146 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -66,7 +66,7 @@ logger = get_logger(name=__name__, category="inference")

 class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: TogetherImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
+        ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
         self.config = config

     async def initialize(self) -> None:
diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py
index ee72f974a..a5bf0e4bc 100644
--- a/llama_stack/providers/remote/inference/vllm/config.py
+++ b/llama_stack/providers/remote/inference/vllm/config.py
@@ -33,10 +33,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
         default=False,
         description="Whether to refresh models periodically",
     )
-    refresh_models_interval: int = Field(
-        default=300,
-        description="Interval in seconds to refresh models",
-    )

     @field_validator("tls_verify")
     @classmethod
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 8bdba1e88..621658a48 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import asyncio
 import json
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
@@ -293,7 +292,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
     model_store: ModelStore | None = None
-    _refresh_task: asyncio.Task | None = None

     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
@@ -301,65 +299,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self.client = None

     async def initialize(self) -> None:
-        if not self.config.url:
-            # intentionally don't raise an error here, we want to allow the provider to be "dormant"
-            # or available in distributions like "starter" without causing a ruckus
-            return
+        pass

-        if self.config.refresh_models:
-            self._refresh_task = asyncio.create_task(self._refresh_models())
-
-            def cb(task):
-                import traceback
-
-                if task.cancelled():
-                    log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
-                elif task.exception():
-                    # print the stack trace for the exception
-                    exc = task.exception()
-                    log.error(f"vLLM background refresh task died: {exc}")
-                    traceback.print_exception(exc)
-                else:
-                    log.error("vLLM background refresh task completed unexpectedly")
-
-            self._refresh_task.add_done_callback(cb)
-
-    async def _refresh_models(self) -> None:
-        provider_id = self.__provider_id__
-        waited_time = 0
-        while not self.model_store and waited_time < 60:
-            await asyncio.sleep(1)
-            waited_time += 1
-
-        if not self.model_store:
-            raise ValueError("Model store not set after waiting 60 seconds")
+    async def should_refresh_models(self) -> bool:
+        return self.config.refresh_models

+    async def list_models(self) -> list[Model] | None:
         self._lazy_initialize_client()
         assert self.client is not None  # mypy
-        while True:
-            try:
-                models = []
-                async for m in self.client.models.list():
-                    model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
-                    models.append(
-                        Model(
-                            identifier=m.id,
-                            provider_resource_id=m.id,
-                            provider_id=provider_id,
-                            metadata={},
-                            model_type=model_type,
-                        )
-                    )
-                await self.model_store.update_registered_llm_models(provider_id, models)
-                log.debug(f"vLLM refreshed model list ({len(models)} models)")
-            except Exception as e:
-                log.error(f"vLLM background refresh task failed: {e}")
-            await asyncio.sleep(self.config.refresh_models_interval)
+        models = []
+        async for m in self.client.models.list():
+            model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
+            models.append(
+                Model(
+                    identifier=m.id,
+                    provider_resource_id=m.id,
+                    provider_id=self.__provider_id__,
+                    metadata={},
+                    model_type=model_type,
+                )
+            )
+        return models

     async def shutdown(self) -> None:
-        if self._refresh_task:
-            self._refresh_task.cancel()
-            self._refresh_task = None
+        pass

     async def unregister_model(self, model_id: str) -> None:
         pass
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 651d58e2a..bceeaf198 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -20,6 +20,13 @@ from llama_stack.providers.utils.inference import (
 logger = get_logger(name=__name__, category="core")


+class RemoteInferenceProviderConfig(BaseModel):
+    allowed_models: list[str] | None = Field(
+        default=None,
+        description="List of models that should be registered with the model registry. If None, all models are allowed.",
+    )
+
+
 # TODO: this class is more confusing than useful right now. We need to make it
 # more closer to the Model class.
 class ProviderModelEntry(BaseModel):
@@ -65,7 +72,11 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider

 class ModelRegistryHelper(ModelsProtocolPrivate):
-    def __init__(self, model_entries: list[ProviderModelEntry]):
+    __provider_id__: str
+
+    def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
+        self.model_entries = model_entries
+        self.allowed_models = allowed_models
         self.alias_to_provider_id_map = {}
         self.provider_id_to_llama_model_map = {}
         for entry in model_entries:
@@ -79,6 +90,27 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
         self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model

+    async def list_models(self) -> list[Model] | None:
+        models = []
+        for entry in self.model_entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for id in ids:
+                if self.allowed_models and id not in self.allowed_models:
+                    continue
+                models.append(
+                    Model(
+                        identifier=id,
+                        provider_resource_id=entry.provider_model_id,
+                        model_type=ModelType.llm,
+                        metadata=entry.metadata,
+                        provider_id=self.__provider_id__,
+                    )
+                )
+        return models
+
+    async def should_refresh_models(self) -> bool:
+        return False
+
     def get_provider_model_id(self, identifier: str) -> str | None:
         return self.alias_to_provider_id_map.get(identifier, None)
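Review note: how the `allowed_models` filter above plays out in practice, as a hedged sketch. The `acme/...` ids are made up, and the snippet assumes the corrected `Model(identifier=...)` constructor shown in the hunk above. Note that an alias is also skipped unless it is itself listed in `allowed_models`.

```python
import asyncio

from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    ProviderModelEntry,
)


async def main() -> None:
    entries = [
        ProviderModelEntry(provider_model_id="acme/llama-8b", aliases=["llama-8b"]),
        ProviderModelEntry(provider_model_id="acme/llama-70b"),
    ]
    helper = ModelRegistryHelper(entries, allowed_models=["acme/llama-8b"])
    helper.__provider_id__ = "acme"  # normally set by the resolver

    models = await helper.list_models()
    # the 70b entry and the un-allowed "llama-8b" alias are both filtered out
    assert [m.identifier for m in models] == ["acme/llama-8b"]


asyncio.run(main())
```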
diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py
index 12b05ebff..c1b57cb4f 100644
--- a/tests/unit/distribution/routers/test_routing_tables.py
+++ b/tests/unit/distribution/routers/test_routing_tables.py
@@ -15,6 +15,7 @@ from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.shields.shields import Shield
 from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
 from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.distribution.datatypes import RegistryEntrySource
 from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
 from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
 from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
@@ -45,6 +46,30 @@ class InferenceImpl(Impl):
     async def unregister_model(self, model_id: str):
         return model_id

+    async def should_refresh_models(self):
+        return False
+
+    async def list_models(self):
+        return [
+            Model(
+                identifier="provider-model-1",
+                provider_resource_id="provider-model-1",
+                provider_id="test_provider",
+                metadata={},
+                model_type=ModelType.llm,
+            ),
+            Model(
+                identifier="provider-model-2",
+                provider_resource_id="provider-model-2",
+                provider_id="test_provider",
+                metadata={"embedding_dimension": 512},
+                model_type=ModelType.embedding,
+            ),
+        ]
+
+    async def shutdown(self):
+        pass
+

 class SafetyImpl(Impl):
     def __init__(self):
@@ -378,3 +403,170 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
         raise AssertionError("Should have raised ValueError for non-existent model")
     except ValueError as e:
         assert "not found" in str(e)
+
+
+async def test_models_source_tracking_default(cached_disk_dist_registry):
+    """Test that models registered via register_model get default source."""
+    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    # Register model via register_model (should get default source)
+    await table.register_model(model_id="user-model", provider_id="test_provider")
+
+    models = await table.list_models()
+    assert len(models.data) == 1
+    model = models.data[0]
+    assert model.source == RegistryEntrySource.via_register_api
+    assert model.identifier == "test_provider/user-model"
+
+    # Cleanup
+    await table.shutdown()
+
+
+async def test_models_source_tracking_provider(cached_disk_dist_registry):
+    """Test that models registered via update_registered_models get provider source."""
+    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    # Simulate provider refresh by calling update_registered_models
+    provider_models = [
+        Model(
+            identifier="provider-model-1",
+            provider_resource_id="provider-model-1",
+            provider_id="test_provider",
+            metadata={},
+            model_type=ModelType.llm,
+        ),
+        Model(
+            identifier="provider-model-2",
+            provider_resource_id="provider-model-2",
+            provider_id="test_provider",
+            metadata={"embedding_dimension": 512},
+            model_type=ModelType.embedding,
+        ),
+    ]
+    await table.update_registered_models("test_provider", provider_models)
+
+    models = await table.list_models()
+    assert len(models.data) == 2
+
+    # All models should have provider source
+    for model in models.data:
+        assert model.source == RegistryEntrySource.listed_from_provider
+        assert model.provider_id == "test_provider"
+
+    # Cleanup
+    await table.shutdown()
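Review note: a hypothetical companion test, not in the patch, that drives `ModelsRoutingTable.refresh()` end to end. The stub's `should_refresh_models()` returns False, so the sketch swaps in an instance-level coroutine that says True before calling `refresh()`.

```python
async def test_models_refresh_pulls_provider_list(cached_disk_dist_registry):
    impl = InferenceImpl()

    async def always_refresh():
        return True

    # instance attribute shadows the stub method; called with no args
    impl.should_refresh_models = always_refresh

    table = ModelsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
    await table.initialize()
    await table.refresh()  # pulls provider-model-1 / provider-model-2 via list_models()

    models = await table.list_models()
    assert {m.identifier for m in models.data} == {"provider-model-1", "provider-model-2"}
    assert all(m.source == RegistryEntrySource.listed_from_provider for m in models.data)

    await table.shutdown()
```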
async def test_models_source_interaction_preserves_default(cached_disk_dist_registry):
+    """Test that provider refresh preserves user-registered models with default source."""
+    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    # First register a user model with same provider_resource_id as provider will later provide
+    await table.register_model(
+        model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider"
+    )
+
+    # Verify user model is registered with default source
+    models = await table.list_models()
+    assert len(models.data) == 1
+    user_model = models.data[0]
+    assert user_model.source == RegistryEntrySource.via_register_api
+    assert user_model.identifier == "my-custom-alias"
+    assert user_model.provider_resource_id == "provider-model-1"
+
+    # Now simulate provider refresh
+    provider_models = [
+        Model(
+            identifier="provider-model-1",
+            provider_resource_id="provider-model-1",
+            provider_id="test_provider",
+            metadata={},
+            model_type=ModelType.llm,
+        ),
+        Model(
+            identifier="different-model",
+            provider_resource_id="different-model",
+            provider_id="test_provider",
+            metadata={},
+            model_type=ModelType.llm,
+        ),
+    ]
+    await table.update_registered_models("test_provider", provider_models)
+
+    # Verify user model with alias is preserved, but provider added new model
+    models = await table.list_models()
+    assert len(models.data) == 2
+
+    # Find the user model and provider model
+    user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
+    provider_model = next((m for m in models.data if m.identifier == "different-model"), None)
+
+    assert user_model is not None
+    assert user_model.source == RegistryEntrySource.via_register_api
+    assert user_model.provider_resource_id == "provider-model-1"
+
+    assert provider_model is not None
+    assert provider_model.source == RegistryEntrySource.listed_from_provider
+    assert provider_model.provider_resource_id == "different-model"
+
+    # Cleanup
+    await table.shutdown()
+
+
+async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry):
+    """Test that provider refresh removes old provider models but keeps default ones."""
+    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    # Register a user model
+    await table.register_model(model_id="user-model", provider_id="test_provider")
+
+    # Add some provider models
+    provider_models_v1 = [
+        Model(
+            identifier="provider-model-old",
+            provider_resource_id="provider-model-old",
+            provider_id="test_provider",
+            metadata={},
+            model_type=ModelType.llm,
+        ),
+    ]
+    await table.update_registered_models("test_provider", provider_models_v1)
+
+    # Verify we have both user and provider models
+    models = await table.list_models()
+    assert len(models.data) == 2
+
+    # Now update with new provider models (should remove old provider models)
+    provider_models_v2 = [
+        Model(
+            identifier="provider-model-new",
+            provider_resource_id="provider-model-new",
+            provider_id="test_provider",
+            metadata={},
+            model_type=ModelType.llm,
+        ),
+    ]
+    await table.update_registered_models("test_provider", provider_models_v2)
+
+    # Should have user model + new provider model, old provider model gone
+    models = await table.list_models()
+    assert len(models.data) == 2
+
+    identifiers = {m.identifier for m in models.data}
+    assert "test_provider/user-model" in identifiers  # User model preserved
+    assert "provider-model-new" in identifiers  # New provider model (uses provider's identifier)
+    assert "provider-model-old" not in identifiers  # Old provider model removed
+
+    # Verify sources are correct
+    user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
+    provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None)
+
+    assert user_model.source == RegistryEntrySource.via_register_api
+    assert provider_model.source == RegistryEntrySource.listed_from_provider
+
+    # Cleanup
+    await table.shutdown()
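Review note: with source tracking in place, downstream code can separate user-pinned entries from provider-discovered ones. A small hedged sketch (the helper name is ours; `table` is a `ModelsRoutingTable` as constructed in the tests above):

```python
from llama_stack.distribution.datatypes import RegistryEntrySource


async def split_by_source(table) -> tuple[list, list]:
    """Return (user-pinned, provider-discovered) registry entries."""
    models = await table.list_models()
    pinned = [m for m in models.data if m.source == RegistryEntrySource.via_register_api]
    discovered = [m for m in models.data if m.source == RegistryEntrySource.listed_from_provider]
    return pinned, discovered
```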