diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index bd267a5f8..b4bee587e 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -23,6 +23,8 @@ from llama_stack.providers.utils.inference.routable import RoutableProviderForMo
 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
 OLLAMA_SUPPORTED_SKUS = {
+    "Llama-Guard-3-1B": "xe/llamaguard3:latest",
+    "llama_guard": "xe/llamaguard3:latest",
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
     "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
     "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
@@ -38,6 +40,7 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
         self.url = url
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
+        self.model = "Llama3.2-1B-Instruct"
 
     @property
     def client(self) -> AsyncClient:
@@ -47,6 +50,8 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
         print("Initializing Ollama, checking connectivity to server...")
         try:
             await self.client.ps()
+            print(f"Connected to Ollama server. Pre-downloading {self.model}...")
+            await self.predownload_models(ollama_model=self.model)
         except httpx.ConnectError as e:
             raise RuntimeError(
                 "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
@@ -90,6 +95,24 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
 
         return options
 
+    async def predownload_models(self, ollama_model: str):
+        res = await self.client.ps()
+
+        need_model_pull = True
+        for r in res["models"]:
+            if ollama_model == r["model"]:
+                need_model_pull = False
+                break
+
+        if need_model_pull:
+            print(f"Pulling model: {ollama_model}")
+            status = await self.client.pull(ollama_model)
+            assert (
+                status["status"] == "success"
+            ), f"Failed to pull model {ollama_model} in ollama"
+        else:
+            print(f"Model {ollama_model} is already available")
+
     async def chat_completion(
         self,
         model: str,
@@ -117,19 +140,8 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
         options = self.get_ollama_chat_options(request)
         ollama_model = self.map_to_provider_model(request.model)
 
-        res = await self.client.ps()
-        need_model_pull = True
-        for r in res["models"]:
-            if ollama_model == r["model"]:
-                need_model_pull = False
-                break
-
-        if need_model_pull:
-            print(f"Pulling model: {ollama_model}")
-            status = await self.client.pull(ollama_model)
-            assert (
-                status["status"] == "success"
-            ), f"Failed to pull model {self.model} in ollama"
+        if ollama_model != self.model:
+            await self.predownload_models(ollama_model)
 
         if not request.stream:
             r = await self.client.chat(
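
Note: below is a minimal standalone sketch of the check-then-pull flow that the new predownload_models helper follows, using the same ollama AsyncClient calls (ps/pull) the patch relies on. The ensure_model name, host URL, and model tag are illustrative assumptions, not part of the patch.

# Sketch only: mirrors the check-then-pull logic of predownload_models
# outside the adapter. The helper name, host, and model tag are
# illustrative, not taken from the patch.
import asyncio

from ollama import AsyncClient


async def ensure_model(client: AsyncClient, ollama_model: str) -> None:
    # ps() reports the models the Ollama server currently has loaded;
    # pull the target model only if it is not among them.
    res = await client.ps()
    if any(r["model"] == ollama_model for r in res["models"]):
        print(f"Model {ollama_model} is already available")
        return
    print(f"Pulling model: {ollama_model}")
    status = await client.pull(ollama_model)
    assert status["status"] == "success", f"Failed to pull model {ollama_model}"


if __name__ == "__main__":
    client = AsyncClient(host="http://localhost:11434")
    asyncio.run(ensure_model(client, "llama3.2:1b-instruct-fp16"))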