From 5ae5a4da8aa3ec8bae89a5ff0fa4aa7754a2e350 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 25 Jul 2025 14:01:21 -0700 Subject: [PATCH] more fixes --- llama_stack/distribution/stack.py | 16 ++++++++++++---- llama_stack/providers/datatypes.py | 3 +++ .../providers/remote/inference/ollama/ollama.py | 13 ++++++------- .../providers/utils/inference/model_registry.py | 10 +++++----- 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 4b12cafcc..811e188f9 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -318,8 +318,10 @@ async def construct_stack( await register_resources(run_config, impls) + await refresh_registry_once(impls) + global REGISTRY_REFRESH_TASK - REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry(impls)) + REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls)) def cb(task): import traceback @@ -355,11 +357,17 @@ async def shutdown_stack(impls: dict[Api, Any]): REGISTRY_REFRESH_TASK.cancel() -async def refresh_registry(impls: dict[Api, Any]): +async def refresh_registry_once(impls: dict[Api, Any]): + logger.info("refreshing registry") routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)] + for routing_table in routing_tables: + await routing_table.refresh() + + +async def refresh_registry_task(impls: dict[Api, Any]): + logger.info("starting registry refresh task") while True: - for routing_table in routing_tables: - await routing_table.refresh() + await refresh_registry_once(impls) await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index faf7ff18c..f9f463bf9 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -43,6 +43,9 @@ class ModelsProtocolPrivate(Protocol): -> Provider uses provider-model-id for inference """ + # this should be called `on_model_register` or something like that. + # the provider should _not_ be able to change the object in this + # callback async def register_model(self, model: Model) -> Model: ... async def unregister_model(self, model_id: str) -> None: ... diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index cb026bb94..098e4d324 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -420,9 +420,6 @@ class OllamaInferenceAdapter( except ValueError: pass # Ignore statically unknown model, will check live listing - if model.provider_resource_id is None: - raise ValueError("Model provider_resource_id cannot be None") - if model.model_type == ModelType.embedding: response = await self.client.list() if model.provider_resource_id not in [m.model for m in response.models]: @@ -433,9 +430,9 @@ class OllamaInferenceAdapter( # - models not currently running are run by the ollama server as needed response = await self.client.list() available_models = [m.model for m in response.models] - provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id) - if provider_resource_id is None: - provider_resource_id = model.provider_resource_id + + provider_resource_id = model.provider_resource_id + assert provider_resource_id is not None # mypy if provider_resource_id not in available_models: available_models_latest = [m.model.split(":latest")[0] for m in response.models] if provider_resource_id in available_models_latest: @@ -443,7 +440,9 @@ class OllamaInferenceAdapter( f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'" ) return model - raise UnsupportedModelError(model.provider_resource_id, available_models) + raise UnsupportedModelError(provider_resource_id, available_models) + + # mutating this should be considered an anti-pattern model.provider_resource_id = provider_resource_id return model diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index a79e4b6ae..ddb3bda8c 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -188,8 +188,8 @@ class ModelRegistryHelper(ModelsProtocolPrivate): return model async def unregister_model(self, model_id: str) -> None: - # TODO: should we block unregistering base supported provider model IDs? - if model_id not in self.alias_to_provider_id_map: - raise ValueError(f"Model id '{model_id}' is not registered.") - - del self.alias_to_provider_id_map[model_id] + # model_id is the identifier, not the provider_resource_id + # unfortunately, this ID can be of the form provider_id/model_id which + # we never registered. TODO: fix this by significantly rewriting + # registration and registry helper + pass