diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index ae7d9ab40..8247eb594 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -156,6 +156,11 @@ async def instantiate_provider(
         assert isinstance(provider_config, GenericProviderConfig)
         config_type = instantiate_class_type(provider_spec.config_class)
         config = config_type(**provider_config.config)
+
+        if hasattr(provider_config, "routing_key"):
+            routing_key = provider_config.routing_key
+            deps["routing_key"] = routing_key
+
         args = [config, deps]
     elif isinstance(provider_spec, AutoRoutedProviderSpec):
         method = "get_auto_router_impl"
diff --git a/llama_stack/providers/adapters/inference/ollama/__init__.py b/llama_stack/providers/adapters/inference/ollama/__init__.py
index 7763af8d1..013b6c8d3 100644
--- a/llama_stack/providers/adapters/inference/ollama/__init__.py
+++ b/llama_stack/providers/adapters/inference/ollama/__init__.py
@@ -15,5 +15,12 @@ async def get_adapter_impl(config: RemoteProviderConfig, _deps):
     from .ollama import OllamaInferenceAdapter
 
     impl = OllamaInferenceAdapter(config.url)
-    await impl.initialize()
+
+    routing_key = _deps.get("routing_key")
+    if not routing_key:
+        raise ValueError(
+            "Routing key is required for the Ollama adapter but was not found."
+        )
+
+    await impl.initialize(routing_key)
     return impl
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index b4bee587e..47abb9a98 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -23,7 +23,7 @@ from llama_stack.providers.utils.inference.routable import RoutableProviderForMo
 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
 OLLAMA_SUPPORTED_SKUS = {
-    "Llama-Guard-3-1B": "xe/llamaguard3:latest",
+    "Llama-Guard-3-8B": "xe/llamaguard3:latest",
     "llama_guard": "xe/llamaguard3:latest",
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
     "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
@@ -40,18 +40,18 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
         self.url = url
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
-        self.model = "Llama3.2-1B-Instruct"
 
     @property
     def client(self) -> AsyncClient:
         return AsyncClient(host=self.url)
 
-    async def initialize(self) -> None:
+    async def initialize(self, routing_key: str) -> None:
         print("Initializing Ollama, checking connectivity to server...")
         try:
             await self.client.ps()
-            print(f"Connected to Ollama server. Pre-downloading {self.model}...")
-            await self.predownload_models(ollama_model=self.model)
+            ollama_model = self.map_to_provider_model(routing_key)
+            print(f"Connected to Ollama server. Pre-downloading {ollama_model}...")
+            await self.predownload_models(ollama_model)
         except httpx.ConnectError as e:
             raise RuntimeError(
                 "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
@@ -140,8 +140,7 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
         options = self.get_ollama_chat_options(request)
         ollama_model = self.map_to_provider_model(request.model)
 
-        if ollama_model != self.model:
-            self.predownload_models(ollama_model)
+        await self.predownload_models(ollama_model)
 
         if not request.stream:
             r = await self.client.chat(
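
Note (not part of the patch): a minimal sketch of how the new initialize(routing_key) signature could be exercised against a locally running Ollama server. The URL and the "Llama3.1-8B-Instruct" routing key below are illustrative assumptions; the key must be one of the SKUs in OLLAMA_SUPPORTED_SKUS so map_to_provider_model() can resolve it to an ollama model tag, and initialize() raises RuntimeError if no Ollama server is reachable.

    import asyncio

    from llama_stack.providers.adapters.inference.ollama.ollama import OllamaInferenceAdapter


    async def main() -> None:
        # Point the adapter at a local Ollama server (default port 11434).
        impl = OllamaInferenceAdapter("http://localhost:11434")
        # initialize() now resolves the routing key to an ollama model tag and
        # pre-downloads it, instead of always pulling a hard-coded model.
        await impl.initialize(routing_key="Llama3.1-8B-Instruct")


    if __name__ == "__main__":
        asyncio.run(main())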