diff --git a/llama_stack/providers/remote/inference/ollama/__init__.py b/llama_stack/providers/remote/inference/ollama/__init__.py
index 9f4adc75f..491339451 100644
--- a/llama_stack/providers/remote/inference/ollama/__init__.py
+++ b/llama_stack/providers/remote/inference/ollama/__init__.py
@@ -10,6 +10,6 @@ from .config import OllamaImplConfig
 async def get_adapter_impl(config: OllamaImplConfig, _deps):
     from .ollama import OllamaInferenceAdapter
 
-    impl = OllamaInferenceAdapter(config.url, raise_on_connect_error=config.raise_on_connect_error)
+    impl = OllamaInferenceAdapter(config)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index d81d21dac..2f51920b5 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -56,6 +56,7 @@ from llama_stack.providers.datatypes import (
     HealthStatus,
     ModelsProtocolPrivate,
 )
+from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
@@ -89,10 +90,10 @@ class OllamaInferenceAdapter(
     InferenceProvider,
     ModelsProtocolPrivate,
 ):
-    def __init__(self, url: str, raise_on_connect_error: bool = True) -> None:
+    def __init__(self, config: OllamaImplConfig) -> None:
         self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
-        self.url = url
-        self.raise_on_connect_error = raise_on_connect_error
+        self.url = config.url
+        self.raise_on_connect_error = config.raise_on_connect_error
 
     @property
     def client(self) -> AsyncClient: