diff --git a/llama_stack/providers/remote/inference/ollama/models.py b/llama_stack/providers/remote/inference/ollama/models.py
index d01c3a54d..e0bf269db 100644
--- a/llama_stack/providers/remote/inference/ollama/models.py
+++ b/llama_stack/providers/remote/inference/ollama/models.py
@@ -84,7 +84,8 @@ model_entries = [
         CoreModelId.llama_guard_3_1b.value,
     ),
     ProviderModelEntry(
-        provider_model_id="all-minilm",
+        provider_model_id="all-minilm:latest",
+        aliases=["all-minilm"],
         model_type=ModelType.embedding,
         metadata={
             "embedding_dimensions": 384,
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 2cba3c668..1dbcbc294 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -274,22 +274,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return EmbeddingsResponse(embeddings=embeddings)
 
     async def register_model(self, model: Model) -> Model:
-        async def check_model_availability(model_id: str):
-            response = await self.client.ps()
-            available_models = [m["model"] for m in response["models"]]
-            if model_id not in available_models:
-                raise ValueError(
-                    f"Model '{model_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
-                )
-
         if model.model_type == ModelType.embedding:
-            await check_model_availability(model.provider_resource_id)
-            return model
+            response = await self.client.list()
+        else:
+            response = await self.client.ps()
+        available_models = [m["model"] for m in response["models"]]
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
+            )
 
-        model = await self.register_helper.register_model(model)
-        await check_model_availability(model.provider_resource_id)
-
-        return model
+        return await self.register_helper.register_model(model)
 
 
 async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 2026f7f8a..0882019e3 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -56,14 +56,16 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
     def __init__(self, model_entries: List[ProviderModelEntry]):
         self.alias_to_provider_id_map = {}
         self.provider_id_to_llama_model_map = {}
-        for alias_obj in model_entries:
-            for alias in alias_obj.aliases:
-                self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
+        for entry in model_entries:
+            for alias in entry.aliases:
+                self.alias_to_provider_id_map[alias] = entry.provider_model_id
+
             # also add a mapping from provider model id to itself for easy lookup
-            self.alias_to_provider_id_map[alias_obj.provider_model_id] = alias_obj.provider_model_id
-            # ensure we can go from llama model to provider model id
-            self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id
-            self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model
+            self.alias_to_provider_id_map[entry.provider_model_id] = entry.provider_model_id
+
+            if entry.llama_model:
+                self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
+                self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
 
     def get_provider_model_id(self, identifier: str) -> Optional[str]:
         return self.alias_to_provider_id_map.get(identifier, None)
diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py
index 79825e990..31119e040 100644
--- a/llama_stack/templates/ollama/ollama.py
+++ b/llama_stack/templates/ollama/ollama.py
@@ -72,7 +72,7 @@ def get_distribution_template() -> DistributionTemplate:
     embedding_model = ModelInput(
         model_id="all-MiniLM-L6-v2",
         provider_id="ollama",
-        provider_model_id="all-minilm",
+        provider_model_id="all-minilm:latest",
         model_type=ModelType.embedding,
         metadata={
             "embedding_dimension": 384,
diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml
index f5ae29cfc..7cf527c04 100644
--- a/llama_stack/templates/ollama/run-with-safety.yaml
+++ b/llama_stack/templates/ollama/run-with-safety.yaml
@@ -111,7 +111,7 @@ models:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: ollama
-  provider_model_id: all-minilm
+  provider_model_id: all-minilm:latest
   model_type: embedding
 shields:
 - shield_id: ${env.SAFETY_MODEL}
diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml
index a6ba60ce7..ab292c5e0 100644
--- a/llama_stack/templates/ollama/run.yaml
+++ b/llama_stack/templates/ollama/run.yaml
@@ -104,7 +104,7 @@ models:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: ollama
-  provider_model_id: all-minilm
+  provider_model_id: all-minilm:latest
   model_type: embedding
 shields: []
 vector_dbs: []
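Note on the registry change: the `model_registry.py` hunk renames `alias_obj` to `entry` and guards the llama-model mappings, so embedding entries (which carry no `llama_model`) no longer insert `None` keys into the alias map; meanwhile the `all-minilm` alias keeps the old id resolving to the new canonical `all-minilm:latest`, which matches how `ollama list` reports tagged models. (The `ollama.py` hunk similarly switches embedding models to `client.list()`, which reports locally pulled models, rather than `client.ps()`, which reports only loaded ones.) Below is a minimal, self-contained sketch of the alias-resolution behavior; the `ProviderModelEntry` dataclass here is a stand-in for llama_stack's real class, and `build_alias_map` is a hypothetical helper for illustration only.

# Sketch of the alias mapping built in ModelRegistryHelper.__init__,
# using an illustrative stub instead of llama_stack's ProviderModelEntry.
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class ProviderModelEntry:  # stub for illustration, not the real class
    provider_model_id: str
    aliases: List[str] = field(default_factory=list)
    llama_model: Optional[str] = None


def build_alias_map(model_entries: List[ProviderModelEntry]) -> Dict[str, str]:
    alias_to_provider_id_map: Dict[str, str] = {}
    for entry in model_entries:
        for alias in entry.aliases:
            alias_to_provider_id_map[alias] = entry.provider_model_id
        # the provider model id maps to itself for easy lookup
        alias_to_provider_id_map[entry.provider_model_id] = entry.provider_model_id
        # embedding entries have no llama_model; this guard keeps a None key
        # out of the map, mirroring the `if entry.llama_model:` fix above
        if entry.llama_model:
            alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
    return alias_to_provider_id_map


entries = [ProviderModelEntry(provider_model_id="all-minilm:latest", aliases=["all-minilm"])]
alias_map = build_alias_map(entries)
assert alias_map["all-minilm"] == "all-minilm:latest"        # old id still resolves
assert alias_map["all-minilm:latest"] == "all-minilm:latest" # canonical id maps to itself
assert None not in alias_map                                 # no spurious None key

With this in place, existing configs that register `all-minilm` keep working unchanged, while new templates reference the fully tagged `all-minilm:latest` directly.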