Fix issue #183: Pre-download models during server initialization to prevent HTTP timeouts

This commit moves the model-download logic from the `chat_completion` method to the `initialize` method of `OllamaInferenceAdapter`. By pulling the required models during server startup, large models (e.g., a 16 GB download) are fetched before any requests are served, preventing HTTP request timeouts and aborted downloads on the first inference request.

Closes #183.
Ezreal 2024-10-08 01:21:31 +08:00 committed by GitHub
parent de80f66470
commit 3e3b096071

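For orientation before the raw diff: the pre-download step reduces to asking the Ollama server what it already has and pulling anything missing via the official `ollama` Python client. A minimal standalone sketch of that loop (the `preload` helper and its framing are illustrative, not code from this commit):

```python
from ollama import AsyncClient


async def preload(client: AsyncClient, ollama_models: list[str]) -> None:
    # ps() reports the models the Ollama server currently has loaded;
    # anything not in that list gets pulled before the server takes traffic.
    loaded = {m["model"] for m in (await client.ps())["models"]}
    for name in ollama_models:
        if name in loaded:
            continue
        print(f"Pulling model: {name}")
        # A non-streaming pull() returns a final status payload once the
        # download finishes, so a failed pull surfaces at startup, not mid-request.
        status = await client.pull(name)
        assert status["status"] == "success", f"Failed to pull model {name} in ollama"
```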

@@ -9,7 +9,6 @@ from typing import AsyncGenerator
import httpx
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient
@@ -31,11 +30,12 @@ OLLAMA_SUPPORTED_SKUS = {
 class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
-    def __init__(self, url: str) -> None:
+    def __init__(self, url: str, preload_models: List[str] = None) -> None:
         RoutableProviderForModels.__init__(
             self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
         )
         self.url = url
+        self.preload_models = preload_models or []
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
@@ -52,6 +52,23 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
                 "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
             ) from e
+
+        # Pre-download models
+        for model in self.preload_models:
+            ollama_model = self.map_to_provider_model(model)
+            res = await self.client.ps()
+            need_model_pull = True
+            for r in res["models"]:
+                if ollama_model == r["model"]:
+                    need_model_pull = False
+                    break
+            if need_model_pull:
+                print(f"Pulling model: {ollama_model}")
+                status = await self.client.pull(ollama_model)
+                assert (
+                    status["status"] == "success"
+                ), f"Failed to pull model {ollama_model} in ollama"
+
     async def shutdown(self) -> None:
         pass
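A usage sketch of the new constructor parameter, assuming a running Ollama server; the import path and the model identifier below are illustrative rather than taken from this commit:

```python
import asyncio

# Illustrative import path; the provider's real module layout may differ.
from llama_stack.providers.adapters.inference.ollama import OllamaInferenceAdapter


async def main() -> None:
    adapter = OllamaInferenceAdapter(
        url="http://localhost:11434",
        # Stack-level names; initialize() maps each one to an Ollama tag via
        # map_to_provider_model() and pulls it if the server doesn't have it yet.
        preload_models=["Llama3.1-8B-Instruct"],
    )
    # The health check and the pre-download loop now both run here, so the
    # first chat_completion request no longer waits on a multi-GB download.
    await adapter.initialize()


if __name__ == "__main__":
    asyncio.run(main())
```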