From d035fe93c66c76d67d57dde1aad61ea7c3d4e87f Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 9 Jul 2025 12:16:03 -0400 Subject: [PATCH 1/3] feat: add infrastructure to allow inference model discovery inference providers each have a static list of supported / known models. some also have access to a dynamic list of currently available models. this change gives providers using the ModelRegistryHelper the ability to combine their static and dynamic lists. for instance, OpenAIInferenceAdapter can implement ``` def query_available_models(self) -> list[str]: return [entry.model for entry in self.openai_client.models.list()] ``` to augment its static list w/ a current list from openai. --- .../utils/inference/model_registry.py | 33 ++++++- .../providers/utils/test_model_registry.py | 94 +++++++++++++++++++ 2 files changed, 125 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 46c0ca7b5..f4f28c1f3 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -82,9 +82,37 @@ class ModelRegistryHelper(ModelsProtocolPrivate): def get_llama_model(self, provider_model_id: str) -> str | None: return self.provider_id_to_llama_model_map.get(provider_model_id, None) + async def query_available_models(self) -> list[str]: + """ + Return a list of available models. + + This is for subclassing purposes, so providers can look up a list + of currently available models. + + This is combined with the statically configured model entries in + `self.alias_to_provider_id_map` to determine which models are + available for registration. + + Default implementation returns no models. + + :return: A list of model identifiers (provider_model_ids). 
+ """ + return [] + async def register_model(self, model: Model) -> Model: - if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)): - raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys()) + # Check if model is supported in static configuration + supported_model_id = self.get_provider_model_id(model.provider_resource_id) + + # If not found in static config, check if it's available from provider + if not supported_model_id: + available_models = await self.query_available_models() + if model.provider_resource_id in available_models: + supported_model_id = model.provider_resource_id + else: + # Combine static and dynamic models for error message + all_supported_models = list(self.alias_to_provider_id_map.keys()) + available_models + raise UnsupportedModelError(model.provider_resource_id, all_supported_models) + provider_resource_id = self.get_provider_model_id(model.model_id) if model.model_type == ModelType.embedding: # embedding models are always registered by their provider model id and does not need to be mapped to a llama model @@ -113,6 +141,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate): ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model] ) + # Register the model alias, ensuring it maps to the correct provider model id self.alias_to_provider_id_map[model.model_id] = supported_model_id return model diff --git a/tests/unit/providers/utils/test_model_registry.py b/tests/unit/providers/utils/test_model_registry.py index 10fa1e075..c7f7eb299 100644 --- a/tests/unit/providers/utils/test_model_registry.py +++ b/tests/unit/providers/utils/test_model_registry.py @@ -87,6 +87,37 @@ def helper(known_provider_model: ProviderModelEntry, known_provider_model2: Prov return ModelRegistryHelper([known_provider_model, known_provider_model2]) +class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper): + """Test helper that simulates a provider with dynamically available models.""" + + def 
__init__(self, model_entries: list[ProviderModelEntry], available_models: list[str]): + super().__init__(model_entries) + self._available_models = available_models + + async def query_available_models(self) -> list[str]: + return self._available_models + + +@pytest.fixture +def dynamic_model() -> Model: + """A model that's not in static config but available dynamically.""" + return Model( + provider_id="provider", + identifier="dynamic-model", + provider_resource_id="dynamic-provider-id", + ) + + +@pytest.fixture +def helper_with_dynamic_models( + known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry, dynamic_model: Model +) -> MockModelRegistryHelperWithDynamicModels: + """Helper that includes dynamically available models.""" + return MockModelRegistryHelperWithDynamicModels( + [known_provider_model, known_provider_model2], [dynamic_model.provider_resource_id] + ) + + @pytest.mark.asyncio async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None: assert helper.get_provider_model_id(unknown_model.model_id) is None @@ -161,3 +192,66 @@ async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_m assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id await helper.unregister_model(known_model.provider_resource_id) assert helper.get_provider_model_id(known_model.provider_resource_id) is None + + +@pytest.mark.asyncio +async def test_register_model_from_query_available_models( + helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model +) -> None: + """Test that models returned by query_available_models can be registered.""" + # Verify the model is not in static config + assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None + + # But it should be available via query_available_models + available_models = await helper_with_dynamic_models.query_available_models() 
+ assert dynamic_model.provider_resource_id in available_models + + # Registration should succeed + registered_model = await helper_with_dynamic_models.register_model(dynamic_model) + assert registered_model == dynamic_model + + # Model should now be registered and accessible + assert ( + helper_with_dynamic_models.get_provider_model_id(dynamic_model.model_id) == dynamic_model.provider_resource_id + ) + + +@pytest.mark.asyncio +async def test_register_model_not_in_static_or_dynamic( + helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, unknown_model: Model +) -> None: + """Test that models not in static config or dynamic models are rejected.""" + # Verify the model is not in static config + assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None + + # And not in dynamic models + available_models = await helper_with_dynamic_models.query_available_models() + assert unknown_model.provider_resource_id not in available_models + + # Registration should fail with comprehensive error message + with pytest.raises(Exception) as exc_info: # UnsupportedModelError + await helper_with_dynamic_models.register_model(unknown_model) + + # Error should include both static and dynamic models + error_str = str(exc_info.value) + assert "dynamic-provider-id" in error_str # dynamic model should be in error + + +@pytest.mark.asyncio +async def test_register_alias_for_dynamic_model( + helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model +) -> None: + """Test that we can register an alias that maps to a dynamically available model.""" + # Create a model with a different identifier but same provider_resource_id + alias_model = Model( + provider_id=dynamic_model.provider_id, + identifier="dynamic-model-alias", + provider_resource_id=dynamic_model.provider_resource_id, + ) + + # Registration should succeed since the provider_resource_id is available dynamically + registered_model = await 
helper_with_dynamic_models.register_model(alias_model) + assert registered_model == alias_model + + # Both the original provider_resource_id and the new alias should work + assert helper_with_dynamic_models.get_provider_model_id(alias_model.model_id) == dynamic_model.provider_resource_id From 51787a93f62a44a910e5727d76e3155ba9c0d960 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 11 Jul 2025 10:08:49 -0400 Subject: [PATCH 2/3] feat: allow dynamic model registration for ollama inference provider implements query_available_models on OllamaInferenceAdapter --- .../remote/inference/ollama/ollama.py | 60 ++++++++----------- 1 file changed, 26 insertions(+), 34 deletions(-) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 010e346bd..125ce7ac8 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -19,7 +19,6 @@ from llama_stack.apis.common.content_types import ( InterleavedContentItem, TextContentItem, ) -from llama_stack.apis.common.errors import UnsupportedModelError from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -54,7 +53,6 @@ from llama_stack.log import get_logger from llama_stack.providers.datatypes import ( HealthResponse, HealthStatus, - ModelsProtocolPrivate, ) from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig from llama_stack.providers.utils.inference.model_registry import ( @@ -89,10 +87,10 @@ logger = get_logger(name=__name__, category="inference") class OllamaInferenceAdapter( InferenceProvider, - ModelsProtocolPrivate, + ModelRegistryHelper, ): def __init__(self, config: OllamaImplConfig) -> None: - self.register_helper = ModelRegistryHelper(MODEL_ENTRIES) + ModelRegistryHelper.__init__(self, MODEL_ENTRIES) self.url = config.url @property @@ -123,6 +121,27 @@ class OllamaInferenceAdapter( except Exception as e: 
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}") + async def query_available_models(self) -> list[str]: + """ + Query Ollama for available models. + + Ollama allows omitting the `:latest` suffix, so we include some-name:latest as some-name and some-name:latest. + + :return: A list of model identifiers (provider_model_ids). + """ + available_models = [] + try: + # we use list() here instead of ps() - + # - ps() only lists running models, not available models + # - models not currently running are run by the ollama server as needed + for m in (await self.client.list()).models: + available_models.append(m.model) + if m.model.endswith(":latest"): + available_models.append(m.model[: -len(":latest")]) + except Exception as e: + logger.warning(f"Failed to query available models from Ollama: {e}") + return available_models + async def shutdown(self) -> None: pass @@ -237,7 +256,7 @@ class OllamaInferenceAdapter( input_dict: dict[str, Any] = {} media_present = request_has_media(request) - llama_model = self.register_helper.get_llama_model(request.model) + llama_model = self.get_llama_model(request.model) if isinstance(request, ChatCompletionRequest): if media_present or not llama_model: contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages] @@ -345,40 +364,13 @@ class OllamaInferenceAdapter( return EmbeddingsResponse(embeddings=embeddings) async def register_model(self, model: Model) -> Model: - try: - model = await self.register_helper.register_model(model) - except ValueError: - pass # Ignore statically unknown model, will check live listing - - if model.provider_resource_id is None: - raise ValueError("Model provider_resource_id cannot be None") - if model.model_type == ModelType.embedding: logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...") # TODO: you should pull here only if the model is not found in a list - response = await self.client.list() - if 
model.provider_resource_id not in [m.model for m in response.models]: + if model.provider_resource_id not in await self.query_available_models(): await self.client.pull(model.provider_resource_id) - # we use list() here instead of ps() - - # - ps() only lists running models, not available models - # - models not currently running are run by the ollama server as needed - response = await self.client.list() - available_models = [m.model for m in response.models] - provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id) - if provider_resource_id is None: - provider_resource_id = model.provider_resource_id - if provider_resource_id not in available_models: - available_models_latest = [m.model.split(":latest")[0] for m in response.models] - if provider_resource_id in available_models_latest: - logger.warning( - f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'" - ) - return model - raise UnsupportedModelError(model.provider_resource_id, available_models) - model.provider_resource_id = provider_resource_id - - return model + return await ModelRegistryHelper.register_model(self, model) async def openai_embeddings( self, From 89b10528069dde8d4cdbfb5dd5521c931836644a Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Mon, 14 Jul 2025 17:42:37 -0400 Subject: [PATCH 3/3] query_available_models() -> list[str] -> check_model_availability(model) -> bool --- .../remote/inference/ollama/ollama.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 125ce7ac8..bd36482df 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -121,7 +121,21 @@ class OllamaInferenceAdapter( except Exception as e: return 
HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}") - async def query_available_models(self) -> list[str]: + async def check_model_availability(self, model: str) -> bool: + """ + Check if a specific model is available in Ollama. + + :param model: The model identifier to check. + :return: True if the model is available, False otherwise. + """ + try: + available_models = await self._query_available_models() + return model in available_models + except Exception as e: + logger.error(f"Error checking model availability: {e}") + return False + + async def _query_available_models(self) -> list[str]: """ Query Ollama for available models. @@ -365,9 +379,10 @@ class OllamaInferenceAdapter( async def register_model(self, model: Model) -> Model: if model.model_type == ModelType.embedding: + assert model.provider_resource_id, "Embedding models must have a provider_resource_id set" logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...") # TODO: you should pull here only if the model is not found in a list - if model.provider_resource_id not in await self.query_available_models(): + if not await self.check_model_availability(model.provider_resource_id): await self.client.pull(model.provider_resource_id) return await ModelRegistryHelper.register_model(self, model)