diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py index 437db0176..1454bf45f 100644 --- a/llama_stack/distribution/routing_tables/models.py +++ b/llama_stack/distribution/routing_tables/models.py @@ -22,10 +22,12 @@ logger = get_logger(name=__name__, category="core") class ModelsRoutingTable(CommonRoutingTableImpl, Models): listed_providers: set[str] = set() model_refresh_interval_seconds: int = 300 + _refresh_task: asyncio.Task | None = None async def initialize(self) -> None: await super().initialize() task = asyncio.create_task(self._refresh_models()) + self._refresh_task = task def cb(task): import traceback @@ -40,6 +42,12 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): task.add_done_callback(cb) + async def shutdown(self) -> None: + await super().shutdown() + if self._refresh_task: + self._refresh_task.cancel() + self._refresh_task = None + async def _refresh_models(self) -> None: while True: for provider_id, provider in self.impls_by_provider_id.items():