mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
fix: move trace_protocol to instantiate_provider
since impl is define in instantiate_provider, makes more sense to apply trace_protocol there Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
d00a085aed
commit
d1adc5a6eb
2 changed files with 12 additions and 23 deletions
|
|
@ -397,6 +397,18 @@ async def instantiate_provider(
|
||||||
impl.__provider_spec__ = provider_spec
|
impl.__provider_spec__ = provider_spec
|
||||||
impl.__provider_config__ = config
|
impl.__provider_config__ = config
|
||||||
|
|
||||||
|
# Apply tracing if telemetry is enabled and any base class has __marked_for_tracing__ marker
|
||||||
|
if run_config.telemetry.enabled:
|
||||||
|
traced_classes = [
|
||||||
|
base for base in reversed(impl.__class__.__mro__) if getattr(base, "__marked_for_tracing__", False)
|
||||||
|
]
|
||||||
|
|
||||||
|
if traced_classes:
|
||||||
|
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||||
|
|
||||||
|
for cls in traced_classes:
|
||||||
|
trace_protocol(cls)
|
||||||
|
|
||||||
protocols = api_protocol_map_for_compliance_check(run_config)
|
protocols = api_protocol_map_for_compliance_check(run_config)
|
||||||
additional_protocols = additional_protocols_map()
|
additional_protocols = additional_protocols_map()
|
||||||
# TODO: check compliance for special tool groups
|
# TODO: check compliance for special tool groups
|
||||||
|
|
|
||||||
|
|
@ -46,17 +46,6 @@ async def get_routing_table_impl(
|
||||||
|
|
||||||
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
|
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
|
||||||
|
|
||||||
# Apply tracing to routing table if any base class has __marked_for_tracing__ marker
|
|
||||||
# (Tracing will be no-op if telemetry is disabled)
|
|
||||||
traced_classes = [
|
|
||||||
base for base in reversed(impl.__class__.__mro__) if getattr(base, "__marked_for_tracing__", False)
|
|
||||||
]
|
|
||||||
if traced_classes:
|
|
||||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
|
||||||
|
|
||||||
for cls in traced_classes:
|
|
||||||
trace_protocol(cls)
|
|
||||||
|
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
|
@ -105,17 +94,5 @@ async def get_auto_router_impl(
|
||||||
|
|
||||||
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
||||||
|
|
||||||
# Apply tracing to router implementation BEFORE initialize() if telemetry is enabled
|
|
||||||
# Apply to all classes in MRO that have __marked_for_tracing__ marker to ensure inherited methods are wrapped
|
|
||||||
if run_config.telemetry.enabled:
|
|
||||||
traced_classes = [
|
|
||||||
base for base in reversed(impl.__class__.__mro__) if getattr(base, "__marked_for_tracing__", False)
|
|
||||||
]
|
|
||||||
if traced_classes:
|
|
||||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
|
||||||
|
|
||||||
for cls in traced_classes:
|
|
||||||
trace_protocol(cls)
|
|
||||||
|
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue