From 3e3b096071296c5c8ce4c96abf323d0ca907a717 Mon Sep 17 00:00:00 2001
From: Ezreal
Date: Tue, 8 Oct 2024 01:21:31 +0800
Subject: [PATCH] Fix issue #183: Pre-download models during server
 initialization to prevent HTTP timeouts

This commit moves the model download step out of the first `chat_completion`
request and into the `initialize` method of `OllamaInferenceAdapter`. By
pre-pulling the required models during server startup, large models (e.g.,
16GB) are fully downloaded before the server begins serving requests, which
prevents HTTP request timeouts and aborted downloads during the first
inference request.

Closes #183.
---
 .../adapters/inference/ollama/ollama.py       | 21 +++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index bd267a5f8..402417ed3 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -9,7 +9,6 @@ from typing import AsyncGenerator
 import httpx
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from ollama import AsyncClient
@@ -31,11 +30,12 @@ OLLAMA_SUPPORTED_SKUS = {
 
 
 class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
-    def __init__(self, url: str) -> None:
+    def __init__(self, url: str, preload_models: list[str] | None = None) -> None:
         RoutableProviderForModels.__init__(
             self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
         )
         self.url = url
+        self.preload_models = preload_models or []
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
 
@@ -52,6 +52,23 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
                 "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
             ) from e
 
+        # Pre-download requested models so the first inference request does not block on a pull
+        for model in self.preload_models:
+            ollama_model = self.map_to_provider_model(model)
+            res = await self.client.ps()
+            need_model_pull = True
+            for r in res["models"]:
+                if ollama_model == r["model"]:
+                    need_model_pull = False
+                    break
+
+            if need_model_pull:
+                print(f"Pulling model: {ollama_model}")
+                status = await self.client.pull(ollama_model)
+                assert (
+                    status["status"] == "success"
+                ), f"Failed to pull model {ollama_model} in ollama"
+
     async def shutdown(self) -> None:
         pass
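
Usage sketch (illustrative only, not part of the patch): one way to exercise the
new `preload_models` argument, assuming a local Ollama server on its default
port and a stack model identifier that appears in `OLLAMA_SUPPORTED_SKUS`
(e.g. "Llama3.1-8B-Instruct"). In a real deployment the adapter is constructed
by the llama-stack provider machinery from its run config rather than
instantiated directly.

    import asyncio

    from llama_stack.providers.adapters.inference.ollama.ollama import (
        OllamaInferenceAdapter,
    )


    async def main() -> None:
        # Hypothetical direct instantiation; the URL is Ollama's default endpoint.
        adapter = OllamaInferenceAdapter(
            url="http://localhost:11434",
            preload_models=["Llama3.1-8B-Instruct"],
        )
        # initialize() now verifies connectivity and pulls any missing preload
        # models, so the first chat_completion call does not block on a download.
        await adapter.initialize()


    if __name__ == "__main__":
        asyncio.run(main())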