diff --git a/src/llama_stack/apis/common/responses.py b/src/llama_stack/apis/common/responses.py index 616bee73a..53a290eea 100644 --- a/src/llama_stack/apis/common/responses.py +++ b/src/llama_stack/apis/common/responses.py @@ -34,3 +34,44 @@ class PaginatedResponse(BaseModel): data: list[dict[str, Any]] has_more: bool url: str | None = None + + +# This is a short term solution to allow inference API to return metrics +# The ideal way to do this is to have a way for all response types to include metrics +# and all metric events logged to the telemetry API to be included with the response +# To do this, we will need to augment all response types with a metrics field. +# We have hit a blocker from stainless SDK that prevents us from doing this. +# The blocker is that if we were to augment the response types that have a data field +# in them like so +# class ListModelsResponse(BaseModel): +# metrics: Optional[List[MetricEvent]] = None +# data: List[Models] +# ... +# The client SDK will need to access the data by using a .data field, which is not +# ergonomic. Stainless SDK does support unwrapping the response type, but it +# requires that the response type to only have a single field. + +# We will need a way in the client SDK to signal that the metrics are needed +# and if they are needed, the client SDK has to return the full response type +# without unwrapping it. + + +@json_schema_type +class MetricInResponse(BaseModel): + """A metric value included in API responses. + :param metric: The name of the metric + :param value: The numeric value of the metric + :param unit: (Optional) The unit of measurement for the metric value + """ + + metric: str + value: int | float + unit: str | None = None + + +class MetricResponseMixin(BaseModel): + """Mixin class for API responses that can include metrics. + :param metrics: (Optional) List of metrics associated with the API response + """ + + metrics: list[MetricInResponse] | None = None diff --git a/src/llama_stack/apis/common/tracing.py b/src/llama_stack/apis/common/tracing.py new file mode 100644 index 000000000..f82eadeb0 --- /dev/null +++ b/src/llama_stack/apis/common/tracing.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +def trace_protocol(cls): + """ + Mark a protocol for automatic tracing when telemetry is enabled. + + This is a metadata-only decorator with no dependencies on core. + Actual tracing is applied by core routers at runtime if telemetry is enabled. + + Usage: + @runtime_checkable + @trace_protocol + class MyProtocol(Protocol): + ... + """ + cls.__trace_protocol__ = True + return cls diff --git a/src/llama_stack/apis/conversations/conversations.py b/src/llama_stack/apis/conversations/conversations.py index d75683efa..85119177e 100644 --- a/src/llama_stack/apis/conversations/conversations.py +++ b/src/llama_stack/apis/conversations/conversations.py @@ -20,8 +20,8 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseOutputMessageMCPListTools, OpenAIResponseOutputMessageWebSearchToolCall, ) +from llama_stack.apis.common.tracing import trace_protocol from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, register_schema, webmethod Metadata = dict[str, str] diff --git a/src/llama_stack/apis/files/files.py b/src/llama_stack/apis/files/files.py index 657e9f500..fe9de7ad4 100644 --- a/src/llama_stack/apis/files/files.py +++ b/src/llama_stack/apis/files/files.py @@ -11,8 +11,8 @@ from fastapi import File, Form, Response, UploadFile from pydantic import BaseModel, Field from llama_stack.apis.common.responses import Order +from llama_stack.apis.common.tracing import trace_protocol from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod diff --git a/src/llama_stack/apis/inference/inference.py b/src/llama_stack/apis/inference/inference.py index f39957190..21dc3efc1 100644 --- a/src/llama_stack/apis/inference/inference.py +++ b/src/llama_stack/apis/inference/inference.py @@ -19,11 +19,10 @@ from pydantic import BaseModel, Field, field_validator from typing_extensions import TypedDict from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent -from llama_stack.apis.common.responses import Order +from llama_stack.apis.common.responses import MetricResponseMixin, Order +from llama_stack.apis.common.tracing import trace_protocol from llama_stack.apis.models import Model from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA -from llama_stack.core.telemetry.telemetry import MetricResponseMixin -from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.models.llama.datatypes import ( BuiltinTool, StopReason, diff --git a/src/llama_stack/apis/models/models.py b/src/llama_stack/apis/models/models.py index 552f47c30..bb451dd55 100644 --- a/src/llama_stack/apis/models/models.py +++ b/src/llama_stack/apis/models/models.py @@ -9,9 +9,9 @@ from typing import Any, Literal, Protocol, runtime_checkable from pydantic import BaseModel, ConfigDict, Field, field_validator +from llama_stack.apis.common.tracing import trace_protocol from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod diff --git a/src/llama_stack/apis/prompts/prompts.py b/src/llama_stack/apis/prompts/prompts.py index 4651b9294..3c5adb81f 100644 --- a/src/llama_stack/apis/prompts/prompts.py +++ b/src/llama_stack/apis/prompts/prompts.py @@ -10,8 +10,8 @@ from typing import Protocol, runtime_checkable from pydantic import BaseModel, Field, field_validator, model_validator +from llama_stack.apis.common.tracing import trace_protocol from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod diff --git a/src/llama_stack/apis/safety/safety.py b/src/llama_stack/apis/safety/safety.py index 97fffcff1..c800fb22b 100644 --- a/src/llama_stack/apis/safety/safety.py +++ b/src/llama_stack/apis/safety/safety.py @@ -9,10 +9,10 @@ from typing import Any, Protocol, runtime_checkable from pydantic import BaseModel, Field +from llama_stack.apis.common.tracing import trace_protocol from llama_stack.apis.inference import OpenAIMessageParam from llama_stack.apis.shields import Shield from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod diff --git a/src/llama_stack/apis/shields/shields.py b/src/llama_stack/apis/shields/shields.py index 565e1db15..66cd297f5 100644 --- a/src/llama_stack/apis/shields/shields.py +++ b/src/llama_stack/apis/shields/shields.py @@ -8,9 +8,9 @@ from typing import Any, Literal, Protocol, runtime_checkable from pydantic import BaseModel +from llama_stack.apis.common.tracing import trace_protocol from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod diff --git a/src/llama_stack/apis/tools/rag_tool.py b/src/llama_stack/apis/tools/rag_tool.py index 4e43bb284..feac92878 100644 --- a/src/llama_stack/apis/tools/rag_tool.py +++ b/src/llama_stack/apis/tools/rag_tool.py @@ -11,8 +11,8 @@ from pydantic import BaseModel, Field, field_validator from typing_extensions import runtime_checkable from llama_stack.apis.common.content_types import URL, InterleavedContent +from llama_stack.apis.common.tracing import trace_protocol from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, register_schema, webmethod diff --git a/src/llama_stack/apis/tools/tools.py b/src/llama_stack/apis/tools/tools.py index b13ac2f19..e0e59c49a 100644 --- a/src/llama_stack/apis/tools/tools.py +++ b/src/llama_stack/apis/tools/tools.py @@ -11,9 +11,9 @@ from pydantic import BaseModel from typing_extensions import runtime_checkable from llama_stack.apis.common.content_types import URL, InterleavedContent +from llama_stack.apis.common.tracing import trace_protocol from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod from .rag_tool import RAGToolRuntime diff --git a/src/llama_stack/apis/vector_io/vector_io.py b/src/llama_stack/apis/vector_io/vector_io.py index cbb16287b..50dd8a785 100644 --- a/src/llama_stack/apis/vector_io/vector_io.py +++ b/src/llama_stack/apis/vector_io/vector_io.py @@ -13,10 +13,10 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable from fastapi import Body from pydantic import BaseModel, Field +from llama_stack.apis.common.tracing import trace_protocol from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_stores import VectorStore from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod from llama_stack.strong_typing.schema import register_schema diff --git a/src/llama_stack/core/routers/__init__.py b/src/llama_stack/core/routers/__init__.py index 204cbb87f..ccc27a963 100644 --- a/src/llama_stack/core/routers/__init__.py +++ b/src/llama_stack/core/routers/__init__.py @@ -93,4 +93,11 @@ async def get_auto_router_impl( impl = api_to_routers[api.value](routing_table, **api_to_dep_impl) await impl.initialize() + + # Apply tracing to router implementation if telemetry is enabled and protocol wants tracing + if run_config.telemetry.enabled and getattr(impl.__class__, "__trace_protocol__", False): + from llama_stack.core.telemetry.trace_protocol import trace_protocol + + trace_protocol(impl.__class__) + return impl diff --git a/src/llama_stack/core/telemetry/telemetry.py b/src/llama_stack/core/telemetry/telemetry.py index 1ba43724d..3a4fbe2b3 100644 --- a/src/llama_stack/core/telemetry/telemetry.py +++ b/src/llama_stack/core/telemetry/telemetry.py @@ -163,47 +163,6 @@ class MetricEvent(EventCommon): unit: str -@json_schema_type -class MetricInResponse(BaseModel): - """A metric value included in API responses. - :param metric: The name of the metric - :param value: The numeric value of the metric - :param unit: (Optional) The unit of measurement for the metric value - """ - - metric: str - value: int | float - unit: str | None = None - - -# This is a short term solution to allow inference API to return metrics -# The ideal way to do this is to have a way for all response types to include metrics -# and all metric events logged to the telemetry API to be included with the response -# To do this, we will need to augment all response types with a metrics field. -# We have hit a blocker from stainless SDK that prevents us from doing this. -# The blocker is that if we were to augment the response types that have a data field -# in them like so -# class ListModelsResponse(BaseModel): -# metrics: Optional[List[MetricEvent]] = None -# data: List[Models] -# ... -# The client SDK will need to access the data by using a .data field, which is not -# ergonomic. Stainless SDK does support unwrapping the response type, but it -# requires that the response type to only have a single field. - -# We will need a way in the client SDK to signal that the metrics are needed -# and if they are needed, the client SDK has to return the full response type -# without unwrapping it. - - -class MetricResponseMixin(BaseModel): - """Mixin class for API responses that can include metrics. - :param metrics: (Optional) List of metrics associated with the API response - """ - - metrics: list[MetricInResponse] | None = None - - @json_schema_type class StructuredLogType(Enum): """The type of structured log event payload.