Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 15:23:51 +00:00
Fix issue #183
Pre-download models during server initialization to prevent HTTP timeouts

This commit moves the model-downloading logic from the `chat_completion` method to the `initialize` method in `OllamaInferenceAdapter`. By pre-loading the required models during server startup, large models (e.g., 16 GB) are downloaded before any requests are served, preventing HTTP request timeouts and aborted downloads during the first inference request.

Closes #183.
This commit is contained in:
parent de80f66470
commit 3e3b096071

1 changed file with 19 additions and 2 deletions
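For context before the diff: the pattern being adopted is to ask the running Ollama server which models it already has and to pull any missing ones during startup rather than on the first request. Below is a minimal standalone sketch of that pattern using the `ollama` package's `AsyncClient`; the `MODELS_TO_PRELOAD` list and `ensure_models` helper are illustrative names, not part of this commit, and the dict-style `res["models"]` access matches what the diff itself assumes.

import asyncio

from ollama import AsyncClient

# Illustrative list of Ollama model tags; the adapter in this commit receives
# these through its new `preload_models` constructor argument instead.
MODELS_TO_PRELOAD = ["llama3.1:8b-instruct-fp16"]


async def ensure_models(client: AsyncClient, models: list[str]) -> None:
    # ps() reports the models the Ollama server currently has loaded.
    res = await client.ps()
    present = {r["model"] for r in res["models"]}
    for model in models:
        if model not in present:
            # pull() with the default stream=False awaits the full download,
            # which is why the commit runs it at startup rather than inside
            # the first chat_completion request.
            status = await client.pull(model)
            assert status["status"] == "success", f"Failed to pull {model}"


asyncio.run(
    ensure_models(AsyncClient(host="http://localhost:11434"), MODELS_TO_PRELOAD)
)

Note that `ps()` lists models currently loaded into memory; models that are downloaded but not running appear under `list()` instead, so a check like this can re-pull a model that is already on disk (typically a quick digest verification rather than a full download).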
@@ -9,7 +9,6 @@ from typing import AsyncGenerator
 import httpx
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer

 from ollama import AsyncClient
@@ -31,11 +30,12 @@ OLLAMA_SUPPORTED_SKUS = {


 class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
-    def __init__(self, url: str) -> None:
+    def __init__(self, url: str, preload_models: List[str] = None) -> None:
         RoutableProviderForModels.__init__(
             self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
         )
         self.url = url
+        self.preload_models = preload_models or []
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)

@@ -52,6 +52,23 @@
                 "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
             ) from e

+        # Pre-download models
+        for model in self.preload_models:
+            ollama_model = self.map_to_provider_model(model)
+            res = await self.client.ps()
+            need_model_pull = True
+            for r in res["models"]:
+                if ollama_model == r["model"]:
+                    need_model_pull = False
+                    break
+
+            if need_model_pull:
+                print(f"Pulling model: {ollama_model}")
+                status = await self.client.pull(ollama_model)
+                assert (
+                    status["status"] == "success"
+                ), f"Failed to pull model {ollama_model} in ollama"
+
     async def shutdown(self) -> None:
         pass

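With the change applied, a deployment can hand the adapter its model list up front and pay the download cost once, at startup. A hypothetical usage sketch follows; the import path is an assumption based on this era of the tree, and the URL and stack model name are placeholders.

import asyncio

# Assumed import path for this revision of the repository.
from llama_stack.providers.adapters.inference.ollama.ollama import (
    OllamaInferenceAdapter,
)


async def main() -> None:
    adapter = OllamaInferenceAdapter(
        url="http://localhost:11434",  # placeholder Ollama endpoint
        preload_models=["Llama3.1-8B-Instruct"],  # placeholder stack model name
    )
    # initialize() now checks connectivity and pulls any missing models,
    # so the first inference request no longer triggers a multi-GB download.
    await adapter.initialize()


asyncio.run(main())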