diff --git a/src/llama_stack/apis/common/tracing.py b/src/llama_stack/apis/common/tracing.py index f82eadeb0..98cf4ede6 100644 --- a/src/llama_stack/apis/common/tracing.py +++ b/src/llama_stack/apis/common/tracing.py @@ -5,7 +5,7 @@ # the root directory of this source tree. -def trace_protocol(cls): +def mark_as_traced(cls): """ Mark a protocol for automatic tracing when telemetry is enabled. @@ -14,9 +14,9 @@ def trace_protocol(cls): Usage: @runtime_checkable - @trace_protocol + @mark_as_traced class MyProtocol(Protocol): ... """ - cls.__trace_protocol__ = True + cls.__marked_for_tracing__ = True return cls diff --git a/src/llama_stack/apis/conversations/conversations.py b/src/llama_stack/apis/conversations/conversations.py index 85119177e..e7f9b78be 100644 --- a/src/llama_stack/apis/conversations/conversations.py +++ b/src/llama_stack/apis/conversations/conversations.py @@ -20,7 +20,7 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseOutputMessageMCPListTools, OpenAIResponseOutputMessageWebSearchToolCall, ) -from llama_stack.apis.common.tracing import trace_protocol +from llama_stack.apis.common.tracing import mark_as_traced from llama_stack.apis.version import LLAMA_STACK_API_V1 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod @@ -183,7 +183,7 @@ class ConversationItemDeletedResource(BaseModel): @runtime_checkable -@trace_protocol +@mark_as_traced class Conversations(Protocol): """Conversations diff --git a/src/llama_stack/apis/files/files.py b/src/llama_stack/apis/files/files.py index fe9de7ad4..7ed9235ff 100644 --- a/src/llama_stack/apis/files/files.py +++ b/src/llama_stack/apis/files/files.py @@ -11,7 +11,7 @@ from fastapi import File, Form, Response, UploadFile from pydantic import BaseModel, Field from llama_stack.apis.common.responses import Order -from llama_stack.apis.common.tracing import trace_protocol +from llama_stack.apis.common.tracing import mark_as_traced from llama_stack.apis.version import LLAMA_STACK_API_V1 from llama_stack.schema_utils import json_schema_type, webmethod @@ -102,7 +102,7 @@ class OpenAIFileDeleteResponse(BaseModel): @runtime_checkable -@trace_protocol +@mark_as_traced class Files(Protocol): """Files diff --git a/src/llama_stack/apis/inference/inference.py b/src/llama_stack/apis/inference/inference.py index 21dc3efc1..e83f8d82a 100644 --- a/src/llama_stack/apis/inference/inference.py +++ b/src/llama_stack/apis/inference/inference.py @@ -20,7 +20,7 @@ from typing_extensions import TypedDict from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent from llama_stack.apis.common.responses import MetricResponseMixin, Order -from llama_stack.apis.common.tracing import trace_protocol +from llama_stack.apis.common.tracing import mark_as_traced from llama_stack.apis.models import Model from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA from llama_stack.models.llama.datatypes import ( @@ -1159,7 +1159,7 @@ class OpenAIEmbeddingsRequestWithExtraBody(BaseModel, extra="allow"): @runtime_checkable -@trace_protocol +@mark_as_traced class InferenceProvider(Protocol): """ This protocol defines the interface that should be implemented by all inference providers. diff --git a/src/llama_stack/apis/models/models.py b/src/llama_stack/apis/models/models.py index bb451dd55..c7fb08488 100644 --- a/src/llama_stack/apis/models/models.py +++ b/src/llama_stack/apis/models/models.py @@ -9,7 +9,7 @@ from typing import Any, Literal, Protocol, runtime_checkable from pydantic import BaseModel, ConfigDict, Field, field_validator -from llama_stack.apis.common.tracing import trace_protocol +from llama_stack.apis.common.tracing import mark_as_traced from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.version import LLAMA_STACK_API_V1 from llama_stack.schema_utils import json_schema_type, webmethod @@ -105,7 +105,7 @@ class OpenAIListModelsResponse(BaseModel): @runtime_checkable -@trace_protocol +@mark_as_traced class Models(Protocol): async def list_models(self) -> ListModelsResponse: """List all models. diff --git a/src/llama_stack/apis/prompts/prompts.py b/src/llama_stack/apis/prompts/prompts.py index 3c5adb81f..0e808aa5f 100644 --- a/src/llama_stack/apis/prompts/prompts.py +++ b/src/llama_stack/apis/prompts/prompts.py @@ -10,7 +10,7 @@ from typing import Protocol, runtime_checkable from pydantic import BaseModel, Field, field_validator, model_validator -from llama_stack.apis.common.tracing import trace_protocol +from llama_stack.apis.common.tracing import mark_as_traced from llama_stack.apis.version import LLAMA_STACK_API_V1 from llama_stack.schema_utils import json_schema_type, webmethod @@ -92,7 +92,7 @@ class ListPromptsResponse(BaseModel): @runtime_checkable -@trace_protocol +@mark_as_traced class Prompts(Protocol): """Prompts diff --git a/src/llama_stack/apis/safety/safety.py b/src/llama_stack/apis/safety/safety.py index c800fb22b..a40257538 100644 --- a/src/llama_stack/apis/safety/safety.py +++ b/src/llama_stack/apis/safety/safety.py @@ -9,7 +9,7 @@ from typing import Any, Protocol, runtime_checkable from pydantic import BaseModel, Field -from llama_stack.apis.common.tracing import trace_protocol +from llama_stack.apis.common.tracing import mark_as_traced from llama_stack.apis.inference import OpenAIMessageParam from llama_stack.apis.shields import Shield from llama_stack.apis.version import LLAMA_STACK_API_V1 @@ -94,7 +94,7 @@ class ShieldStore(Protocol): @runtime_checkable -@trace_protocol +@mark_as_traced class Safety(Protocol): """Safety diff --git a/src/llama_stack/apis/shields/shields.py b/src/llama_stack/apis/shields/shields.py index 66cd297f5..0b0437bbb 100644 --- a/src/llama_stack/apis/shields/shields.py +++ b/src/llama_stack/apis/shields/shields.py @@ -8,7 +8,7 @@ from typing import Any, Literal, Protocol, runtime_checkable from pydantic import BaseModel -from llama_stack.apis.common.tracing import trace_protocol +from llama_stack.apis.common.tracing import mark_as_traced from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.version import LLAMA_STACK_API_V1 from llama_stack.schema_utils import json_schema_type, webmethod @@ -48,7 +48,7 @@ class ListShieldsResponse(BaseModel): @runtime_checkable -@trace_protocol +@mark_as_traced class Shields(Protocol): @webmethod(route="/shields", method="GET", level=LLAMA_STACK_API_V1) async def list_shields(self) -> ListShieldsResponse: diff --git a/src/llama_stack/apis/tools/tools.py b/src/llama_stack/apis/tools/tools.py index 3ca0fa4b4..59fde1439 100644 --- a/src/llama_stack/apis/tools/tools.py +++ b/src/llama_stack/apis/tools/tools.py @@ -11,7 +11,7 @@ from pydantic import BaseModel from typing_extensions import runtime_checkable from llama_stack.apis.common.content_types import URL, InterleavedContent -from llama_stack.apis.common.tracing import trace_protocol +from llama_stack.apis.common.tracing import mark_as_traced from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.version import LLAMA_STACK_API_V1 from llama_stack.schema_utils import json_schema_type, webmethod @@ -107,7 +107,7 @@ class ListToolDefsResponse(BaseModel): @runtime_checkable -@trace_protocol +@mark_as_traced class ToolGroups(Protocol): @webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1) async def register_tool_group( @@ -189,7 +189,7 @@ class SpecialToolGroup(Enum): @runtime_checkable -@trace_protocol +@mark_as_traced class ToolRuntime(Protocol): tool_store: ToolStore | None = None diff --git a/src/llama_stack/apis/vector_io/vector_io.py b/src/llama_stack/apis/vector_io/vector_io.py index 50dd8a785..f39212733 100644 --- a/src/llama_stack/apis/vector_io/vector_io.py +++ b/src/llama_stack/apis/vector_io/vector_io.py @@ -13,7 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable from fastapi import Body from pydantic import BaseModel, Field -from llama_stack.apis.common.tracing import trace_protocol +from llama_stack.apis.common.tracing import mark_as_traced from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_stores import VectorStore from llama_stack.apis.version import LLAMA_STACK_API_V1 @@ -502,7 +502,7 @@ class VectorStoreTable(Protocol): @runtime_checkable -@trace_protocol +@mark_as_traced class VectorIO(Protocol): vector_store_table: VectorStoreTable | None = None diff --git a/src/llama_stack/core/routers/__init__.py b/src/llama_stack/core/routers/__init__.py index 8f285b107..81944dae0 100644 --- a/src/llama_stack/core/routers/__init__.py +++ b/src/llama_stack/core/routers/__init__.py @@ -46,9 +46,11 @@ async def get_routing_table_impl( impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy) - # Apply tracing to routing table if any base class has __trace_protocol__ marker + # Apply tracing to routing table if any base class has __marked_for_tracing__ marker # (Tracing will be no-op if telemetry is disabled) - traced_classes = [base for base in reversed(impl.__class__.__mro__) if getattr(base, "__trace_protocol__", False)] + traced_classes = [ + base for base in reversed(impl.__class__.__mro__) if getattr(base, "__marked_for_tracing__", False) + ] if traced_classes: from llama_stack.core.telemetry.trace_protocol import trace_protocol @@ -104,10 +106,10 @@ async def get_auto_router_impl( impl = api_to_routers[api.value](routing_table, **api_to_dep_impl) # Apply tracing to router implementation BEFORE initialize() if telemetry is enabled - # Apply to all classes in MRO that have __trace_protocol__ marker to ensure inherited methods are wrapped + # Apply to all classes in MRO that have __marked_for_tracing__ marker to ensure inherited methods are wrapped if run_config.telemetry.enabled: traced_classes = [ - base for base in reversed(impl.__class__.__mro__) if getattr(base, "__trace_protocol__", False) + base for base in reversed(impl.__class__.__mro__) if getattr(base, "__marked_for_tracing__", False) ] if traced_classes: from llama_stack.core.telemetry.trace_protocol import trace_protocol