Ensure models are downloaded before serving in Ollama inference

This commit is contained in:
Frieda (Jingying) Huang 2024-10-06 12:09:22 -04:00
parent 0edf24b227
commit 969a11fb8a
2 changed files with 8 additions and 9 deletions

View file

@@ -15,12 +15,6 @@ async def get_adapter_impl(config: RemoteProviderConfig, _deps):
     from .ollama import OllamaInferenceAdapter

     impl = OllamaInferenceAdapter(config.url)
-    routing_key = _deps.get("routing_key")
-    if not routing_key:
-        raise ValueError(
-            "Routing key is required for the Ollama adapter but was not found."
-        )
-    await impl.initialize(routing_key)
+    impl._deps = _deps
+    await impl.initialize()
     return impl

View file

@@ -45,10 +45,15 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
     def client(self) -> AsyncClient:
         return AsyncClient(host=self.url)

-    async def initialize(self, routing_key: str) -> None:
+    async def initialize(self) -> None:
         print("Initializing Ollama, checking connectivity to server...")
         try:
             await self.client.ps()
+            routing_key = self._deps.get("routing_key")
+            if not routing_key:
+                raise ValueError(
+                    "Routing key is required for the Ollama adapter but was not found."
+                )
             ollama_model = self.map_to_provider_model(routing_key)
             print(f"Connected to Ollama server. Pre-downloading {ollama_model}...")
             await self.predownload_models(ollama_model)