From cd17c62ec4a2646d25081af8252b2484cea0bfe7 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Wed, 5 Nov 2025 09:52:34 -0500 Subject: [PATCH] fix: apply tracing to routing table properly and MRO the routing table hierarchy needs to properly receive trace_protocol. since we are now applying tracing a little later in the process, this needs some special handling Signed-off-by: Charlie Doern --- src/llama_stack/core/routers/__init__.py | 30 ++++++++++++++----- .../core/telemetry/trace_protocol.py | 9 ++++++ 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/src/llama_stack/core/routers/__init__.py b/src/llama_stack/core/routers/__init__.py index ccc27a963..8f285b107 100644 --- a/src/llama_stack/core/routers/__init__.py +++ b/src/llama_stack/core/routers/__init__.py @@ -45,6 +45,16 @@ async def get_routing_table_impl( raise ValueError(f"API {api.value} not found in router map") impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy) + + # Apply tracing to routing table if any base class has __trace_protocol__ marker + # (Tracing will be no-op if telemetry is disabled) + traced_classes = [base for base in reversed(impl.__class__.__mro__) if getattr(base, "__trace_protocol__", False)] + if traced_classes: + from llama_stack.core.telemetry.trace_protocol import trace_protocol + + for cls in traced_classes: + trace_protocol(cls) + await impl.initialize() return impl @@ -92,12 +102,18 @@ async def get_auto_router_impl( api_to_dep_impl["safety_config"] = run_config.safety impl = api_to_routers[api.value](routing_table, **api_to_dep_impl) + + # Apply tracing to router implementation BEFORE initialize() if telemetry is enabled + # Apply to all classes in MRO that have __trace_protocol__ marker to ensure inherited methods are wrapped + if run_config.telemetry.enabled: + traced_classes = [ + base for base in reversed(impl.__class__.__mro__) if getattr(base, "__trace_protocol__", False) + ] + if traced_classes: + from
llama_stack.core.telemetry.trace_protocol import trace_protocol + + for cls in traced_classes: + trace_protocol(cls) + await impl.initialize() - - # Apply tracing to router implementation if telemetry is enabled and protocol wants tracing - if run_config.telemetry.enabled and getattr(impl.__class__, "__trace_protocol__", False): - from llama_stack.core.telemetry.trace_protocol import trace_protocol - - trace_protocol(impl.__class__) - return impl diff --git a/src/llama_stack/core/telemetry/trace_protocol.py b/src/llama_stack/core/telemetry/trace_protocol.py index 807b8e2a9..95b33a4bc 100644 --- a/src/llama_stack/core/telemetry/trace_protocol.py +++ b/src/llama_stack/core/telemetry/trace_protocol.py @@ -129,6 +129,15 @@ def trace_protocol[T: type[Any]](cls: T) -> T: else: return sync_wrapper + # Wrap methods on the class itself (for classes applied at runtime) + # Skip if already wrapped (indicated by __wrapped__ attribute) + for name, method in vars(cls).items(): + if inspect.isfunction(method) and not name.startswith("_"): + if not hasattr(method, "__wrapped__"): + wrapped = trace_method(method) + setattr(cls, name, wrapped) # noqa: B010 + + # Also set up __init_subclass__ for future subclasses original_init_subclass = cast(Callable[..., Any] | None, getattr(cls, "__init_subclass__", None)) def __init_subclass__(cls_child: type[Any], **kwargs: Any) -> None: # noqa: N807