refactor: rename trace_protocol marker to mark_as_traced

Rename the marker decorator in apis/common/tracing.py from trace_protocol
to mark_as_traced to disambiguate it from the actual tracing implementation
decorator in core/telemetry/trace_protocol.py.

Changes:
- Rename decorator: trace_protocol -> mark_as_traced
- Rename attribute: __trace_protocol__ -> __marked_for_tracing__
- Update all API protocol files to use new decorator name
- Update router logic to check for new attribute name

This makes it clear that the marker decorator is metadata-only and doesn't
perform actual tracing, while the core decorator does the implementation.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-11-05 16:21:48 -05:00
parent cd17c62ec4
commit 29f93a6391
11 changed files with 28 additions and 26 deletions

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
def trace_protocol(cls):
def mark_as_traced(cls):
"""
Mark a protocol for automatic tracing when telemetry is enabled.
@ -14,9 +14,9 @@ def trace_protocol(cls):
Usage:
@runtime_checkable
@trace_protocol
@mark_as_traced
class MyProtocol(Protocol):
...
"""
cls.__trace_protocol__ = True
cls.__marked_for_tracing__ = True
return cls

View file

@ -20,7 +20,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.common.tracing import trace_protocol
from llama_stack.apis.common.tracing import mark_as_traced
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -183,7 +183,7 @@ class ConversationItemDeletedResource(BaseModel):
@runtime_checkable
@trace_protocol
@mark_as_traced
class Conversations(Protocol):
"""Conversations

View file

@ -11,7 +11,7 @@ from fastapi import File, Form, Response, UploadFile
from pydantic import BaseModel, Field
from llama_stack.apis.common.responses import Order
from llama_stack.apis.common.tracing import trace_protocol
from llama_stack.apis.common.tracing import mark_as_traced
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
@ -102,7 +102,7 @@ class OpenAIFileDeleteResponse(BaseModel):
@runtime_checkable
@trace_protocol
@mark_as_traced
class Files(Protocol):
"""Files

View file

@ -20,7 +20,7 @@ from typing_extensions import TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import MetricResponseMixin, Order
from llama_stack.apis.common.tracing import trace_protocol
from llama_stack.apis.common.tracing import mark_as_traced
from llama_stack.apis.models import Model
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.models.llama.datatypes import (
@ -1159,7 +1159,7 @@ class OpenAIEmbeddingsRequestWithExtraBody(BaseModel, extra="allow"):
@runtime_checkable
@trace_protocol
@mark_as_traced
class InferenceProvider(Protocol):
"""
This protocol defines the interface that should be implemented by all inference providers.

View file

@ -9,7 +9,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field, field_validator
from llama_stack.apis.common.tracing import trace_protocol
from llama_stack.apis.common.tracing import mark_as_traced
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
@ -105,7 +105,7 @@ class OpenAIListModelsResponse(BaseModel):
@runtime_checkable
@trace_protocol
@mark_as_traced
class Models(Protocol):
async def list_models(self) -> ListModelsResponse:
"""List all models.

View file

@ -10,7 +10,7 @@ from typing import Protocol, runtime_checkable
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.apis.common.tracing import trace_protocol
from llama_stack.apis.common.tracing import mark_as_traced
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
@ -92,7 +92,7 @@ class ListPromptsResponse(BaseModel):
@runtime_checkable
@trace_protocol
@mark_as_traced
class Prompts(Protocol):
"""Prompts

View file

@ -9,7 +9,7 @@ from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.common.tracing import trace_protocol
from llama_stack.apis.common.tracing import mark_as_traced
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.shields import Shield
from llama_stack.apis.version import LLAMA_STACK_API_V1
@ -94,7 +94,7 @@ class ShieldStore(Protocol):
@runtime_checkable
@trace_protocol
@mark_as_traced
class Safety(Protocol):
"""Safety

View file

@ -8,7 +8,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.common.tracing import trace_protocol
from llama_stack.apis.common.tracing import mark_as_traced
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
@ -48,7 +48,7 @@ class ListShieldsResponse(BaseModel):
@runtime_checkable
@trace_protocol
@mark_as_traced
class Shields(Protocol):
@webmethod(route="/shields", method="GET", level=LLAMA_STACK_API_V1)
async def list_shields(self) -> ListShieldsResponse:

View file

@ -11,7 +11,7 @@ from pydantic import BaseModel
from typing_extensions import runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.common.tracing import trace_protocol
from llama_stack.apis.common.tracing import mark_as_traced
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
@ -107,7 +107,7 @@ class ListToolDefsResponse(BaseModel):
@runtime_checkable
@trace_protocol
@mark_as_traced
class ToolGroups(Protocol):
@webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1)
async def register_tool_group(
@ -189,7 +189,7 @@ class SpecialToolGroup(Enum):
@runtime_checkable
@trace_protocol
@mark_as_traced
class ToolRuntime(Protocol):
tool_store: ToolStore | None = None

View file

@ -13,7 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from fastapi import Body
from pydantic import BaseModel, Field
from llama_stack.apis.common.tracing import trace_protocol
from llama_stack.apis.common.tracing import mark_as_traced
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.apis.version import LLAMA_STACK_API_V1
@ -502,7 +502,7 @@ class VectorStoreTable(Protocol):
@runtime_checkable
@trace_protocol
@mark_as_traced
class VectorIO(Protocol):
vector_store_table: VectorStoreTable | None = None

View file

@ -46,9 +46,11 @@ async def get_routing_table_impl(
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
# Apply tracing to routing table if any base class has __trace_protocol__ marker
# Apply tracing to routing table if any base class has __marked_for_tracing__ marker
# (Tracing will be no-op if telemetry is disabled)
traced_classes = [base for base in reversed(impl.__class__.__mro__) if getattr(base, "__trace_protocol__", False)]
traced_classes = [
base for base in reversed(impl.__class__.__mro__) if getattr(base, "__marked_for_tracing__", False)
]
if traced_classes:
from llama_stack.core.telemetry.trace_protocol import trace_protocol
@ -104,10 +106,10 @@ async def get_auto_router_impl(
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
# Apply tracing to router implementation BEFORE initialize() if telemetry is enabled
# Apply to all classes in MRO that have __trace_protocol__ marker to ensure inherited methods are wrapped
# Apply to all classes in MRO that have __marked_for_tracing__ marker to ensure inherited methods are wrapped
if run_config.telemetry.enabled:
traced_classes = [
base for base in reversed(impl.__class__.__mro__) if getattr(base, "__trace_protocol__", False)
base for base in reversed(impl.__class__.__mro__) if getattr(base, "__marked_for_tracing__", False)
]
if traced_classes:
from llama_stack.core.telemetry.trace_protocol import trace_protocol