make refreshing happen for all routing tables, naming changes, ollama fixes

This commit is contained in:
Ashwin Bharambe 2025-07-24 10:24:10 -07:00
parent 487e073378
commit 0fe110d94a
6 changed files with 67 additions and 63 deletions

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import importlib.resources
import os
import re
@ -38,6 +39,7 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.routing_tables.common import CommonRoutingTableImpl
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
@ -90,6 +92,9 @@ RESOURCES = [
]
REGISTRY_REFRESH_INTERVAL_SECONDS = 300
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc)
@ -324,9 +329,33 @@ async def construct_stack(
add_internal_implementations(impls, run_config)
await register_resources(run_config, impls)
task = asyncio.create_task(refresh_registry(impls))
def cb(task):
import traceback
if task.cancelled():
logger.error("Model refresh task cancelled")
elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception())
else:
logger.debug("Model refresh task completed")
task.add_done_callback(cb)
return impls
async def refresh_registry(impls: dict[Api, Any]):
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
while True:
for routing_table in routing_tables:
await routing_table.refresh()
await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"