diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index bd267a5f8..402417ed3 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -7,9 +7,8 @@
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional
 
 import httpx
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from ollama import AsyncClient
@@ -31,11 +30,12 @@ OLLAMA_SUPPORTED_SKUS = {
 
 
 class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
-    def __init__(self, url: str) -> None:
+    def __init__(self, url: str, preload_models: Optional[List[str]] = None) -> None:
         RoutableProviderForModels.__init__(
            self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
         )
         self.url = url
+        self.preload_models = preload_models or []
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
 
@@ -52,6 +52,22 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
                 "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
             ) from e
 
+        # Pre-download models. Note that client.ps() reports models loaded
+        # in memory, so a model that is downloaded but not currently running
+        # will be pulled again; for an up-to-date model the pull is cheap.
+        res = await self.client.ps()
+        loaded_models = {r["model"] for r in res["models"]}
+        for model in self.preload_models:
+            ollama_model = self.map_to_provider_model(model)
+            if ollama_model in loaded_models:
+                continue
+
+            print(f"Pulling model: {ollama_model}")
+            status = await self.client.pull(ollama_model)
+            assert (
+                status["status"] == "success"
+            ), f"Failed to pull model {ollama_model} in ollama"
+
     async def shutdown(self) -> None:
         pass
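
A minimal usage sketch (illustrative, not part of the patch): it assumes a local Ollama server on the default port and that "Llama3.1-8B-Instruct" is one of the stack model names in OLLAMA_SUPPORTED_SKUS, which the adapter maps to the matching Ollama tag via map_to_provider_model().

    import asyncio

    from llama_stack.providers.adapters.inference.ollama.ollama import (
        OllamaInferenceAdapter,
    )


    async def main() -> None:
        adapter = OllamaInferenceAdapter(
            url="http://localhost:11434",
            preload_models=["Llama3.1-8B-Instruct"],
        )
        # initialize() verifies connectivity, then pulls any preload model
        # that Ollama does not already report as loaded.
        await adapter.initialize()


    asyncio.run(main())

Because the pulls happen in initialize(), the download latency is paid once at provider startup rather than on the first chat_completion request.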