mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 10:10:36 +00:00
fix: apply tracing to routing table properly and MRO
the routing table hierarchy needs to properly recieve trace_protocol. since we are not applying tracing a little later in the process, this needs some special handling Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
53da6bf3d8
commit
cd17c62ec4
2 changed files with 32 additions and 7 deletions
|
|
@ -45,6 +45,16 @@ async def get_routing_table_impl(
|
||||||
raise ValueError(f"API {api.value} not found in router map")
|
raise ValueError(f"API {api.value} not found in router map")
|
||||||
|
|
||||||
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
|
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
|
||||||
|
|
||||||
|
# Apply tracing to routing table if any base class has __trace_protocol__ marker
|
||||||
|
# (Tracing will be no-op if telemetry is disabled)
|
||||||
|
traced_classes = [base for base in reversed(impl.__class__.__mro__) if getattr(base, "__trace_protocol__", False)]
|
||||||
|
if traced_classes:
|
||||||
|
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||||
|
|
||||||
|
for cls in traced_classes:
|
||||||
|
trace_protocol(cls)
|
||||||
|
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
|
@ -92,12 +102,18 @@ async def get_auto_router_impl(
|
||||||
api_to_dep_impl["safety_config"] = run_config.safety
|
api_to_dep_impl["safety_config"] = run_config.safety
|
||||||
|
|
||||||
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
||||||
|
|
||||||
|
# Apply tracing to router implementation BEFORE initialize() if telemetry is enabled
|
||||||
|
# Apply to all classes in MRO that have __trace_protocol__ marker to ensure inherited methods are wrapped
|
||||||
|
if run_config.telemetry.enabled:
|
||||||
|
traced_classes = [
|
||||||
|
base for base in reversed(impl.__class__.__mro__) if getattr(base, "__trace_protocol__", False)
|
||||||
|
]
|
||||||
|
if traced_classes:
|
||||||
|
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||||
|
|
||||||
|
for cls in traced_classes:
|
||||||
|
trace_protocol(cls)
|
||||||
|
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
|
|
||||||
# Apply tracing to router implementation if telemetry is enabled and protocol wants tracing
|
|
||||||
if run_config.telemetry.enabled and getattr(impl.__class__, "__trace_protocol__", False):
|
|
||||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
|
||||||
|
|
||||||
trace_protocol(impl.__class__)
|
|
||||||
|
|
||||||
return impl
|
return impl
|
||||||
|
|
|
||||||
|
|
@ -129,6 +129,15 @@ def trace_protocol[T: type[Any]](cls: T) -> T:
|
||||||
else:
|
else:
|
||||||
return sync_wrapper
|
return sync_wrapper
|
||||||
|
|
||||||
|
# Wrap methods on the class itself (for classes applied at runtime)
|
||||||
|
# Skip if already wrapped (indicated by __wrapped__ attribute)
|
||||||
|
for name, method in vars(cls).items():
|
||||||
|
if inspect.isfunction(method) and not name.startswith("_"):
|
||||||
|
if not hasattr(method, "__wrapped__"):
|
||||||
|
wrapped = trace_method(method)
|
||||||
|
setattr(cls, name, wrapped) # noqa: B010
|
||||||
|
|
||||||
|
# Also set up __init_subclass__ for future subclasses
|
||||||
original_init_subclass = cast(Callable[..., Any] | None, getattr(cls, "__init_subclass__", None))
|
original_init_subclass = cast(Callable[..., Any] | None, getattr(cls, "__init_subclass__", None))
|
||||||
|
|
||||||
def __init_subclass__(cls_child: type[Any], **kwargs: Any) -> None: # noqa: N807
|
def __init_subclass__(cls_child: type[Any], **kwargs: Any) -> None: # noqa: N807
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue