feat: allow dynamic model registration for ollama inference provider

implements query_available_models on OllamaInferenceAdapter
Author: Matthew Farrellee
Date:   2025-07-11 10:08:49 -04:00
Parent: d035fe93c6
Commit: 51787a93f6

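For orientation, the sketch below shows how the new query_available_models method might be exercised against a locally running Ollama server. The adapter module path, the config URL, and the standalone script scaffolding are illustrative assumptions and are not part of this commit:

    # Illustrative sketch only -- assumes a locally running Ollama server and that
    # the adapter class lives at llama_stack.providers.remote.inference.ollama.ollama.
    import asyncio

    from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
    from llama_stack.providers.remote.inference.ollama.ollama import OllamaInferenceAdapter


    async def main() -> None:
        adapter = OllamaInferenceAdapter(OllamaImplConfig(url="http://localhost:11434"))
        # query_available_models() returns provider_model_ids; a model pulled as
        # "llama3.2:latest" shows up both with and without the ":latest" suffix.
        print(await adapter.query_available_models())


    asyncio.run(main())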

@@ -19,7 +19,6 @@ from llama_stack.apis.common.content_types import (
     InterleavedContentItem,
     TextContentItem,
 )
-from llama_stack.apis.common.errors import UnsupportedModelError
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -54,7 +53,6 @@ from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
     HealthResponse,
     HealthStatus,
-    ModelsProtocolPrivate,
 )
 from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
 from llama_stack.providers.utils.inference.model_registry import (
@@ -89,10 +87,10 @@ logger = get_logger(name=__name__, category="inference")
 class OllamaInferenceAdapter(
     InferenceProvider,
-    ModelsProtocolPrivate,
+    ModelRegistryHelper,
 ):
     def __init__(self, config: OllamaImplConfig) -> None:
-        self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
+        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.url = config.url
 
     @property
@@ -123,6 +121,27 @@ class OllamaInferenceAdapter(
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
 
+    async def query_available_models(self) -> list[str]:
+        """
+        Query Ollama for available models.
+
+        Ollama allows omitting the `:latest` suffix, so a model named some-name:latest is reported both as
+        some-name:latest and some-name.
+
+        :return: A list of model identifiers (provider_model_ids).
+        """
+        available_models = []
+        try:
+            # we use list() here instead of ps() -
+            #  - ps() only lists running models, not available models
+            #  - models not currently running are run by the ollama server as needed
+            for m in (await self.client.list()).models:
+                available_models.append(m.model)
+                if m.model.endswith(":latest"):
+                    available_models.append(m.model[: -len(":latest")])
+        except Exception as e:
+            logger.warning(f"Failed to query available models from Ollama: {e}")
+        return available_models
+
     async def shutdown(self) -> None:
         pass
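Because of the suffix handling above, a single pulled tag can satisfy lookups both with and without ":latest". A small standalone illustration of the expected normalization (the example tags are made up):

    # Reproduces the suffix handling from query_available_models on example data.
    tags = ["llama3.2:latest", "all-minilm:l6-v2"]
    available = []
    for t in tags:
        available.append(t)
        if t.endswith(":latest"):
            available.append(t[: -len(":latest")])
    print(available)  # ['llama3.2:latest', 'llama3.2', 'all-minilm:l6-v2']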
@@ -237,7 +256,7 @@ class OllamaInferenceAdapter(
         input_dict: dict[str, Any] = {}
         media_present = request_has_media(request)
-        llama_model = self.register_helper.get_llama_model(request.model)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
             if media_present or not llama_model:
                 contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
@@ -345,40 +364,13 @@ class OllamaInferenceAdapter(
         return EmbeddingsResponse(embeddings=embeddings)
 
     async def register_model(self, model: Model) -> Model:
-        try:
-            model = await self.register_helper.register_model(model)
-        except ValueError:
-            pass  # Ignore statically unknown model, will check live listing
-
-        if model.provider_resource_id is None:
-            raise ValueError("Model provider_resource_id cannot be None")
-
         if model.model_type == ModelType.embedding:
             logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
             # TODO: you should pull here only if the model is not found in a list
-            response = await self.client.list()
-            if model.provider_resource_id not in [m.model for m in response.models]:
+            if model.provider_resource_id not in await self.query_available_models():
                 await self.client.pull(model.provider_resource_id)
 
-        # we use list() here instead of ps() -
-        #  - ps() only lists running models, not available models
-        #  - models not currently running are run by the ollama server as needed
-        response = await self.client.list()
-        available_models = [m.model for m in response.models]
-        provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
-        if provider_resource_id is None:
-            provider_resource_id = model.provider_resource_id
-        if provider_resource_id not in available_models:
-            available_models_latest = [m.model.split(":latest")[0] for m in response.models]
-            if provider_resource_id in available_models_latest:
-                logger.warning(
-                    f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
-                )
-                return model
-            raise UnsupportedModelError(model.provider_resource_id, available_models)
-        model.provider_resource_id = provider_resource_id
-        return model
+        return await ModelRegistryHelper.register_model(self, model)
 
     async def openai_embeddings(
         self,
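With the fuzzy ":latest" fallback removed, registration is delegated to ModelRegistryHelper.register_model. A hedged sketch of what registering an LLM through the updated adapter might look like; the Model constructor field names are assumptions based on the llama_stack models API, not taken from this diff:

    # Hypothetical registration call; Model field names are assumed, not part of this commit.
    from llama_stack.apis.models import Model, ModelType


    async def register_llama32(adapter) -> Model:
        model = Model(
            identifier="llama3.2",
            provider_id="ollama",
            provider_resource_id="llama3.2:latest",
            model_type=ModelType.llm,
        )
        # Embedding models are additionally pulled when absent from query_available_models();
        # LLMs go straight to ModelRegistryHelper.register_model.
        return await adapter.register_model(model)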