diff --git a/src/llama_stack/core/resolver.py b/src/llama_stack/core/resolver.py index 805d260fc..8bf371fed 100644 --- a/src/llama_stack/core/resolver.py +++ b/src/llama_stack/core/resolver.py @@ -397,6 +397,18 @@ async def instantiate_provider( impl.__provider_spec__ = provider_spec impl.__provider_config__ = config + # Apply tracing if telemetry is enabled and any base class has __marked_for_tracing__ marker + if run_config.telemetry.enabled: + traced_classes = [ + base for base in reversed(impl.__class__.__mro__) if getattr(base, "__marked_for_tracing__", False) + ] + + if traced_classes: + from llama_stack.core.telemetry.trace_protocol import trace_protocol + + for cls in traced_classes: + trace_protocol(cls) + protocols = api_protocol_map_for_compliance_check(run_config) additional_protocols = additional_protocols_map() # TODO: check compliance for special tool groups diff --git a/src/llama_stack/core/routers/__init__.py b/src/llama_stack/core/routers/__init__.py index 81944dae0..729d1c9ea 100644 --- a/src/llama_stack/core/routers/__init__.py +++ b/src/llama_stack/core/routers/__init__.py @@ -46,17 +46,6 @@ async def get_routing_table_impl( impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy) - # Apply tracing to routing table if any base class has __marked_for_tracing__ marker - # (Tracing will be no-op if telemetry is disabled) - traced_classes = [ - base for base in reversed(impl.__class__.__mro__) if getattr(base, "__marked_for_tracing__", False) - ] - if traced_classes: - from llama_stack.core.telemetry.trace_protocol import trace_protocol - - for cls in traced_classes: - trace_protocol(cls) - await impl.initialize() return impl @@ -105,17 +94,5 @@ async def get_auto_router_impl( impl = api_to_routers[api.value](routing_table, **api_to_dep_impl) - # Apply tracing to router implementation BEFORE initialize() if telemetry is enabled - # Apply to all classes in MRO that have __marked_for_tracing__ marker to ensure inherited methods are wrapped - if run_config.telemetry.enabled: - traced_classes = [ - base for base in reversed(impl.__class__.__mro__) if getattr(base, "__marked_for_tracing__", False) - ] - if traced_classes: - from llama_stack.core.telemetry.trace_protocol import trace_protocol - - for cls in traced_classes: - trace_protocol(cls) - await impl.initialize() return impl