diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index ba20185d3..c2fdc74e1 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -20,7 +20,6 @@ from llama_stack.apis.common.content_types import (
     InterleavedContentItem,
     TextContentItem,
 )
-from llama_stack.apis.common.errors import UnsupportedModelError
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -55,7 +54,6 @@ from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
     HealthResponse,
     HealthStatus,
-    ModelsProtocolPrivate,
 )
 from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
 from llama_stack.providers.utils.inference.model_registry import (
@@ -90,13 +88,13 @@ logger = get_logger(name=__name__, category="inference")
 
 class OllamaInferenceAdapter(
     InferenceProvider,
-    ModelsProtocolPrivate,
+    ModelRegistryHelper,
 ):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
 
     def __init__(self, config: OllamaImplConfig) -> None:
-        self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
+        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.config = config
         self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
         self._openai_client = None
@@ -193,6 +191,41 @@ class OllamaInferenceAdapter(
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
 
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available in Ollama.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available, False otherwise.
+        """
+        try:
+            available_models = await self._query_available_models()
+            return model in available_models
+        except Exception as e:
+            logger.error(f"Error checking model availability: {e}")
+            return False
+
+    async def _query_available_models(self) -> list[str]:
+        """
+        Query Ollama for available models.
+
+        Ollama allows omitting the `:latest` suffix, so every `some-name:latest` model is also included as `some-name`.
+
+        :return: A list of model identifiers (provider_model_ids).
+        """
+        available_models = []
+        try:
+            # we use list() here instead of ps() -
+            #  - ps() only lists running models, not available models
+            #  - models not currently running are run by the ollama server as needed
+            for m in (await self.client.list()).models:
+                available_models.append(m.model)
+                if m.model.endswith(":latest"):
+                    available_models.append(m.model[: -len(":latest")])
+        except Exception as e:
+            logger.warning(f"Failed to query available models from Ollama: {e}")
+        return available_models
+
     async def shutdown(self) -> None:
         self._clients.clear()
 
@@ -307,7 +340,7 @@ class OllamaInferenceAdapter(
 
         input_dict: dict[str, Any] = {}
         media_present = request_has_media(request)
-        llama_model = self.register_helper.get_llama_model(request.model)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
             if media_present or not llama_model:
                 contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
@@ -415,38 +448,14 @@ class OllamaInferenceAdapter(
         return EmbeddingsResponse(embeddings=embeddings)
 
     async def register_model(self, model: Model) -> Model:
-        try:
-            model = await self.register_helper.register_model(model)
-        except ValueError:
-            pass  # Ignore statically unknown model, will check live listing
-
-        if model.provider_resource_id is None:
-            raise ValueError("Model provider_resource_id cannot be None")
-
         if model.model_type == ModelType.embedding:
-            response = await self.client.list()
-            if model.provider_resource_id not in [m.model for m in response.models]:
+            assert model.provider_resource_id, "Embedding models must have a provider_resource_id set"
+            logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
+            # TODO: you should pull here only if the model is not found in a list
+            if not await self.check_model_availability(model.provider_resource_id):
                 await self.client.pull(model.provider_resource_id)
-            # we use list() here instead of ps() -
-            #  - ps() only lists running models, not available models
-            #  - models not currently running are run by the ollama server as needed
-            response = await self.client.list()
-            available_models = [m.model for m in response.models]
-
-        provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
-        if provider_resource_id is None:
-            provider_resource_id = model.provider_resource_id
-        if provider_resource_id not in available_models:
-            available_models_latest = [m.model.split(":latest")[0] for m in response.models]
-            if provider_resource_id in available_models_latest:
-                logger.warning(
-                    f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
-                )
-                return model
-            raise UnsupportedModelError(model.provider_resource_id, available_models)
-        model.provider_resource_id = provider_resource_id
-
-        return model
+        return await ModelRegistryHelper.register_model(self, model)
 
     async def openai_embeddings(
         self,