diff --git a/src/llama_stack/apis/common/responses.py b/src/llama_stack/apis/common/responses.py
index 616bee73a..53a290eea 100644
--- a/src/llama_stack/apis/common/responses.py
+++ b/src/llama_stack/apis/common/responses.py
@@ -34,3 +34,44 @@ class PaginatedResponse(BaseModel):
     data: list[dict[str, Any]]
     has_more: bool
     url: str | None = None
+
+
+# This is a short-term solution to allow the inference API to return metrics.
+# The ideal way to do this is for all response types to include metrics, with
+# every metric event logged to the telemetry API also included in the response.
+# To do this, we would need to augment all response types with a metrics field.
+# We have hit a blocker from the Stainless SDK that prevents us from doing this.
+# The blocker is that if we were to augment the response types that have a data field
+# in them like so
+#     class ListModelsResponse(BaseModel):
+#         metrics: Optional[List[MetricEvent]] = None
+#         data: List[Models]
+#         ...
+# the client SDK would need to access the data through a .data field, which is not
+# ergonomic. The Stainless SDK does support unwrapping the response type, but it
+# requires that the response type only have a single field.
+
+# We will need a way in the client SDK to signal that metrics are needed and,
+# if they are, the client SDK has to return the full response type without
+# unwrapping it.
+
+
+@json_schema_type
+class MetricInResponse(BaseModel):
+    """A metric value included in API responses.
+    :param metric: The name of the metric
+    :param value: The numeric value of the metric
+    :param unit: (Optional) The unit of measurement for the metric value
+    """
+
+    metric: str
+    value: int | float
+    unit: str | None = None
+
+
+class MetricResponseMixin(BaseModel):
+    """Mixin class for API responses that can include metrics.
+    :param metrics: (Optional) List of metrics associated with the API response
+    """
+
+    metrics: list[MetricInResponse] | None = None
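Note (not part of the patch): a minimal sketch of how a response model could opt into response-level metrics by inheriting `MetricResponseMixin`. `ChatCompletionDemo` is a hypothetical class used only for illustration, and the import path assumes this change is applied.

```python
# Hypothetical example -- not part of this patch.
from llama_stack.apis.common.responses import MetricInResponse, MetricResponseMixin


class ChatCompletionDemo(MetricResponseMixin):
    """Illustrative response model: a payload field plus optional metrics."""

    completion_text: str


resp = ChatCompletionDemo(
    completion_text="Hello!",
    metrics=[
        MetricInResponse(metric="prompt_tokens", value=12, unit="tokens"),
        MetricInResponse(metric="completion_tokens", value=4, unit="tokens"),
    ],
)

# metrics defaults to None, so callers that do not care about it can ignore the field.
for m in resp.metrics or []:
    print(f"{m.metric}={m.value} {m.unit or ''}")
```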
diff --git a/src/llama_stack/apis/common/tracing.py b/src/llama_stack/apis/common/tracing.py
new file mode 100644
index 000000000..830c2945a
--- /dev/null
+++ b/src/llama_stack/apis/common/tracing.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+def telemetry_traceable(cls):
+    """
+    Mark a protocol for automatic tracing when telemetry is enabled.
+
+    This is a metadata-only decorator with no dependencies on core.
+    Actual tracing is applied by core routers at runtime if telemetry is enabled.
+
+    Usage:
+        @runtime_checkable
+        @telemetry_traceable
+        class MyProtocol(Protocol):
+            ...
+    """
+    cls.__marked_for_tracing__ = True
+    return cls
diff --git a/src/llama_stack/apis/conversations/conversations.py b/src/llama_stack/apis/conversations/conversations.py
index 6ec7e67d6..3fdd3b47e 100644
--- a/src/llama_stack/apis/conversations/conversations.py
+++ b/src/llama_stack/apis/conversations/conversations.py
@@ -20,8 +20,8 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseOutputMessageWebSearchToolCall,
 )
+from llama_stack.apis.common.tracing import telemetry_traceable
 from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.core.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

 Metadata = dict[str, str]
@@ -157,7 +157,7 @@ class ConversationItemDeletedResource(BaseModel):


 @runtime_checkable
-@trace_protocol
+@telemetry_traceable
 class Conversations(Protocol):
     """Conversations
diff --git a/src/llama_stack/apis/files/files.py b/src/llama_stack/apis/files/files.py
index 657e9f500..f0ea2f892 100644
--- a/src/llama_stack/apis/files/files.py
+++ b/src/llama_stack/apis/files/files.py
@@ -11,8 +11,8 @@ from fastapi import File, Form, Response, UploadFile
 from pydantic import BaseModel, Field

 from llama_stack.apis.common.responses import Order
+from llama_stack.apis.common.tracing import telemetry_traceable
 from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.core.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, webmethod
@@ -102,7 +102,7 @@ class OpenAIFileDeleteResponse(BaseModel):


 @runtime_checkable
-@trace_protocol
+@telemetry_traceable
 class Files(Protocol):
     """Files
diff --git a/src/llama_stack/apis/inference/inference.py b/src/llama_stack/apis/inference/inference.py
index f39957190..1a865ce5f 100644
--- a/src/llama_stack/apis/inference/inference.py
+++ b/src/llama_stack/apis/inference/inference.py
@@ -19,11 +19,10 @@ from pydantic import BaseModel, Field, field_validator
 from typing_extensions import TypedDict

 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
-from llama_stack.apis.common.responses import Order
+from llama_stack.apis.common.responses import MetricResponseMixin, Order
+from llama_stack.apis.common.tracing import telemetry_traceable
 from llama_stack.apis.models import Model
 from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
-from llama_stack.core.telemetry.telemetry import MetricResponseMixin
-from llama_stack.core.telemetry.trace_protocol import trace_protocol
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     StopReason,
@@ -1160,7 +1159,7 @@ class OpenAIEmbeddingsRequestWithExtraBody(BaseModel, extra="allow"):


 @runtime_checkable
-@trace_protocol
+@telemetry_traceable
 class InferenceProvider(Protocol):
     """
     This protocol defines the interface that should be implemented by all inference providers.
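Note (not part of the patch): the hunks above only swap `@trace_protocol` for the new marker decorator. A minimal sketch of what the marker gives you, assuming this change is applied; `DemoProtocol` and `DemoImpl` are hypothetical names used for illustration.

```python
# Hypothetical example -- not part of this patch.
from typing import Protocol, runtime_checkable

from llama_stack.apis.common.tracing import telemetry_traceable


@runtime_checkable
@telemetry_traceable
class DemoProtocol(Protocol):
    async def ping(self) -> str: ...


class DemoImpl:
    async def ping(self) -> str:
        return "pong"


# The decorator only sets a marker attribute; no wrapping happens at import time.
assert getattr(DemoProtocol, "__marked_for_tracing__", False)
# Stacking it under @runtime_checkable keeps structural isinstance checks working.
assert isinstance(DemoImpl(), DemoProtocol)
```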
diff --git a/src/llama_stack/apis/models/models.py b/src/llama_stack/apis/models/models.py
index 552f47c30..5c976886c 100644
--- a/src/llama_stack/apis/models/models.py
+++ b/src/llama_stack/apis/models/models.py
@@ -9,9 +9,9 @@ from typing import Any, Literal, Protocol, runtime_checkable

 from pydantic import BaseModel, ConfigDict, Field, field_validator

+from llama_stack.apis.common.tracing import telemetry_traceable
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.core.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, webmethod
@@ -105,7 +105,7 @@ class OpenAIListModelsResponse(BaseModel):


 @runtime_checkable
-@trace_protocol
+@telemetry_traceable
 class Models(Protocol):
     async def list_models(self) -> ListModelsResponse:
         """List all models.
diff --git a/src/llama_stack/apis/prompts/prompts.py b/src/llama_stack/apis/prompts/prompts.py
index 4651b9294..406ae529c 100644
--- a/src/llama_stack/apis/prompts/prompts.py
+++ b/src/llama_stack/apis/prompts/prompts.py
@@ -10,8 +10,8 @@ from typing import Protocol, runtime_checkable

 from pydantic import BaseModel, Field, field_validator, model_validator

+from llama_stack.apis.common.tracing import telemetry_traceable
 from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.core.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, webmethod
@@ -92,7 +92,7 @@ class ListPromptsResponse(BaseModel):


 @runtime_checkable
-@trace_protocol
+@telemetry_traceable
 class Prompts(Protocol):
     """Prompts
diff --git a/src/llama_stack/apis/safety/safety.py b/src/llama_stack/apis/safety/safety.py
index 97fffcff1..8872cc518 100644
--- a/src/llama_stack/apis/safety/safety.py
+++ b/src/llama_stack/apis/safety/safety.py
@@ -9,10 +9,10 @@ from typing import Any, Protocol, runtime_checkable

 from pydantic import BaseModel, Field

+from llama_stack.apis.common.tracing import telemetry_traceable
 from llama_stack.apis.inference import OpenAIMessageParam
 from llama_stack.apis.shields import Shield
 from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.core.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, webmethod
@@ -94,7 +94,7 @@ class ShieldStore(Protocol):


 @runtime_checkable
-@trace_protocol
+@telemetry_traceable
 class Safety(Protocol):
     """Safety
diff --git a/src/llama_stack/apis/shields/shields.py b/src/llama_stack/apis/shields/shields.py
index 565e1db15..ca4483828 100644
--- a/src/llama_stack/apis/shields/shields.py
+++ b/src/llama_stack/apis/shields/shields.py
@@ -8,9 +8,9 @@ from typing import Any, Literal, Protocol, runtime_checkable

 from pydantic import BaseModel

+from llama_stack.apis.common.tracing import telemetry_traceable
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.core.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, webmethod
@@ -48,7 +48,7 @@ class ListShieldsResponse(BaseModel):


 @runtime_checkable
-@trace_protocol
+@telemetry_traceable
 class Shields(Protocol):
     @webmethod(route="/shields", method="GET", level=LLAMA_STACK_API_V1)
     async def list_shields(self) -> ListShieldsResponse:
diff --git a/src/llama_stack/apis/tools/tools.py b/src/llama_stack/apis/tools/tools.py
index 29065a713..c9bdfcfb6 100644
--- a/src/llama_stack/apis/tools/tools.py
+++ b/src/llama_stack/apis/tools/tools.py
@@ -11,9 +11,9 @@ from pydantic import BaseModel
 from typing_extensions import runtime_checkable

 from llama_stack.apis.common.content_types import URL, InterleavedContent
+from llama_stack.apis.common.tracing import telemetry_traceable
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.core.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, webmethod
@@ -107,7 +107,7 @@ class ListToolDefsResponse(BaseModel):


 @runtime_checkable
-@trace_protocol
+@telemetry_traceable
 class ToolGroups(Protocol):
     @webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1)
     async def register_tool_group(
@@ -189,7 +189,7 @@ class SpecialToolGroup(Enum):


 @runtime_checkable
-@trace_protocol
+@telemetry_traceable
 class ToolRuntime(Protocol):
     tool_store: ToolStore | None = None
diff --git a/src/llama_stack/apis/vector_io/vector_io.py b/src/llama_stack/apis/vector_io/vector_io.py
index 9148d10e5..26c961db3 100644
--- a/src/llama_stack/apis/vector_io/vector_io.py
+++ b/src/llama_stack/apis/vector_io/vector_io.py
@@ -13,10 +13,10 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
 from fastapi import Body
 from pydantic import BaseModel, Field

+from llama_stack.apis.common.tracing import telemetry_traceable
 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.vector_stores import VectorStore
 from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.core.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, webmethod
 from llama_stack.strong_typing.schema import register_schema
@@ -502,7 +502,7 @@ class VectorStoreTable(Protocol):


 @runtime_checkable
-@trace_protocol
+@telemetry_traceable
 class VectorIO(Protocol):
     vector_store_table: VectorStoreTable | None = None
diff --git a/src/llama_stack/core/resolver.py b/src/llama_stack/core/resolver.py
index 805d260fc..8bf371fed 100644
--- a/src/llama_stack/core/resolver.py
+++ b/src/llama_stack/core/resolver.py
@@ -397,6 +397,18 @@ async def instantiate_provider(
     impl.__provider_spec__ = provider_spec
     impl.__provider_config__ = config

+    # Apply tracing if telemetry is enabled and any base class has the __marked_for_tracing__ marker
+    if run_config.telemetry.enabled:
+        traced_classes = [
+            base for base in reversed(impl.__class__.__mro__) if getattr(base, "__marked_for_tracing__", False)
+        ]
+
+        if traced_classes:
+            from llama_stack.core.telemetry.trace_protocol import trace_protocol
+
+            for cls in traced_classes:
+                trace_protocol(cls)
+
     protocols = api_protocol_map_for_compliance_check(run_config)
     additional_protocols = additional_protocols_map()
     # TODO: check compliance for special tool groups
diff --git a/src/llama_stack/core/routers/__init__.py b/src/llama_stack/core/routers/__init__.py
index 204cbb87f..729d1c9ea 100644
--- a/src/llama_stack/core/routers/__init__.py
+++ b/src/llama_stack/core/routers/__init__.py
@@ -45,6 +45,7 @@ async def get_routing_table_impl(
         raise ValueError(f"API {api.value} not found in router map")

     impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
+    await impl.initialize()
     return impl
@@ -92,5 +93,6 @@ async def get_auto_router_impl(
         api_to_dep_impl["safety_config"] = run_config.safety

     impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
+    await impl.initialize()
     return impl
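Note (not part of the patch): the resolver change above walks the implementation's MRO and applies `trace_protocol` to every base carrying the marker. A rough, self-contained sketch of that pattern follows; `apply_tracing_if_enabled` and the demo classes are hypothetical stand-ins, and the real code takes the flag from `run_config.telemetry.enabled` and instruments with `trace_protocol`.

```python
# Hypothetical sketch -- mirrors the pattern used in resolver.py, not the exact code.
from typing import Callable


def apply_tracing_if_enabled(
    impl: object, telemetry_enabled: bool, instrument: Callable[[type], None]
) -> list[type]:
    """Instrument every class in impl's MRO that carries the tracing marker."""
    if not telemetry_enabled:
        return []
    traced_classes = [
        base
        for base in reversed(type(impl).__mro__)
        if getattr(base, "__marked_for_tracing__", False)
    ]
    for cls in traced_classes:
        instrument(cls)  # resolver.py calls trace_protocol(cls) here
    return traced_classes


class MarkedProtocolDemo:
    __marked_for_tracing__ = True


class ProviderImplDemo(MarkedProtocolDemo):
    pass


seen: list[type] = []
apply_tracing_if_enabled(ProviderImplDemo(), True, seen.append)
# Subclasses inherit the marker attribute, so both classes are collected here.
assert MarkedProtocolDemo in seen and ProviderImplDemo in seen
```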
diff --git a/src/llama_stack/core/telemetry/telemetry.py b/src/llama_stack/core/telemetry/telemetry.py
index 9476c961a..459c1aa1a 100644
--- a/src/llama_stack/core/telemetry/telemetry.py
+++ b/src/llama_stack/core/telemetry/telemetry.py
@@ -163,47 +163,6 @@ class MetricEvent(EventCommon):
     unit: str


-@json_schema_type
-class MetricInResponse(BaseModel):
-    """A metric value included in API responses.
-    :param metric: The name of the metric
-    :param value: The numeric value of the metric
-    :param unit: (Optional) The unit of measurement for the metric value
-    """
-
-    metric: str
-    value: int | float
-    unit: str | None = None
-
-
-# This is a short term solution to allow inference API to return metrics
-# The ideal way to do this is to have a way for all response types to include metrics
-# and all metric events logged to the telemetry API to be included with the response
-# To do this, we will need to augment all response types with a metrics field.
-# We have hit a blocker from stainless SDK that prevents us from doing this.
-# The blocker is that if we were to augment the response types that have a data field
-# in them like so
-# class ListModelsResponse(BaseModel):
-#     metrics: Optional[List[MetricEvent]] = None
-#     data: List[Models]
-#     ...
-# The client SDK will need to access the data by using a .data field, which is not
-# ergonomic. Stainless SDK does support unwrapping the response type, but it
-# requires that the response type to only have a single field.
-
-# We will need a way in the client SDK to signal that the metrics are needed
-# and if they are needed, the client SDK has to return the full response type
-# without unwrapping it.
-
-
-class MetricResponseMixin(BaseModel):
-    """Mixin class for API responses that can include metrics.
-    :param metrics: (Optional) List of metrics associated with the API response
-    """
-
-    metrics: list[MetricInResponse] | None = None
-
-
 @json_schema_type
 class StructuredLogType(Enum):
     """The type of structured log event payload.
diff --git a/src/llama_stack/core/telemetry/trace_protocol.py b/src/llama_stack/core/telemetry/trace_protocol.py
index 807b8e2a9..95b33a4bc 100644
--- a/src/llama_stack/core/telemetry/trace_protocol.py
+++ b/src/llama_stack/core/telemetry/trace_protocol.py
@@ -129,6 +129,15 @@ def trace_protocol[T: type[Any]](cls: T) -> T:
         else:
             return sync_wrapper

+    # Wrap methods defined on the class itself (needed when tracing is applied to an existing class at runtime)
+    # Skip methods that are already wrapped (indicated by the __wrapped__ attribute)
+    for name, method in vars(cls).items():
+        if inspect.isfunction(method) and not name.startswith("_"):
+            if not hasattr(method, "__wrapped__"):
+                wrapped = trace_method(method)
+                setattr(cls, name, wrapped)  # noqa: B010
+
+    # Also set up __init_subclass__ for future subclasses
     original_init_subclass = cast(Callable[..., Any] | None, getattr(cls, "__init_subclass__", None))

     def __init_subclass__(cls_child: type[Any], **kwargs: Any) -> None:  # noqa: N807
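Note (not part of the patch): the `__wrapped__` check added above relies on the wrapper being registered with `functools.wraps`, which records the original function on the wrapper. A small, self-contained sketch of that idempotency property, assuming that behavior; `demo_wrap`, `wrap_public_methods`, and `Service` are hypothetical names.

```python
# Hypothetical sketch -- not part of this patch.
import functools
import inspect


def demo_wrap(func):
    @functools.wraps(func)  # sets wrapper.__wrapped__ = func
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


class Service:
    def query(self) -> str:
        return "ok"


def wrap_public_methods(cls) -> int:
    """Wrap public functions defined on cls, skipping ones already wrapped."""
    count = 0
    for name, method in list(vars(cls).items()):
        if inspect.isfunction(method) and not name.startswith("_"):
            if not hasattr(method, "__wrapped__"):
                setattr(cls, name, demo_wrap(method))
                count += 1
    return count


assert wrap_public_methods(Service) == 1  # first pass wraps query()
assert wrap_public_methods(Service) == 0  # second pass is a no-op
```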