diff --git a/llama_stack/apis/conversations/conversations.py b/llama_stack/apis/conversations/conversations.py index 3b6c50a03..d75683efa 100644 --- a/llama_stack/apis/conversations/conversations.py +++ b/llama_stack/apis/conversations/conversations.py @@ -21,7 +21,7 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseOutputMessageWebSearchToolCall, ) from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, register_schema, webmethod Metadata = dict[str, str] diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index 948ec615f..ae01c5dfc 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -117,8 +117,6 @@ class Api(Enum, metaclass=DynamicApiMeta): post_training = "post_training" tool_runtime = "tool_runtime" - telemetry = "telemetry" - models = "models" shields = "shields" vector_stores = "vector_stores" # only used for routing table diff --git a/llama_stack/apis/files/files.py b/llama_stack/apis/files/files.py index f1d3764db..6386f4eca 100644 --- a/llama_stack/apis/files/files.py +++ b/llama_stack/apis/files/files.py @@ -12,7 +12,7 @@ from pydantic import BaseModel, Field from llama_stack.apis.common.responses import Order from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 049482837..7dc565244 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -23,6 +23,7 @@ from llama_stack.apis.common.responses import Order from llama_stack.apis.models import Model from llama_stack.apis.telemetry import MetricResponseMixin from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA +from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.models.llama.datatypes import ( BuiltinTool, StopReason, @@ -30,7 +31,6 @@ from llama_stack.models.llama.datatypes import ( ToolDefinition, ToolPromptFormat, ) -from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, register_schema, webmethod register_schema(ToolCall) diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 5486e3bfd..903bd6510 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -11,7 +11,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod diff --git a/llama_stack/apis/prompts/prompts.py b/llama_stack/apis/prompts/prompts.py index b39c363c7..4651b9294 100644 --- a/llama_stack/apis/prompts/prompts.py +++ b/llama_stack/apis/prompts/prompts.py @@ -11,7 +11,7 @@ from typing import Protocol, runtime_checkable from pydantic import BaseModel, Field, field_validator, 
model_validator from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index f6d51871b..249473cae 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -12,7 +12,7 @@ from pydantic import BaseModel, Field from llama_stack.apis.inference import OpenAIMessageParam from llama_stack.apis.shields import Shield from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 5d967cf02..565e1db15 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -10,7 +10,7 @@ from pydantic import BaseModel from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index ed7847e23..c508721f1 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -12,7 +12,7 @@ from typing_extensions import runtime_checkable from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, register_schema, webmethod diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index b6a1a2543..b13ac2f19 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -13,7 +13,7 @@ from typing_extensions import runtime_checkable from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod from .rag_tool import RAGToolRuntime diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py index 49e4df039..6e855ab99 100644 --- a/llama_stack/apis/vector_io/vector_io.py +++ b/llama_stack/apis/vector_io/vector_io.py @@ -17,7 +17,7 @@ from pydantic import BaseModel, Field from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_stores import VectorStore from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.core.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id from llama_stack.schema_utils import 
json_schema_type, webmethod from llama_stack.strong_typing.schema import register_schema diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index 06dae7318..728d06ca6 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -15,10 +15,10 @@ import yaml from llama_stack.cli.stack.utils import ImageType from llama_stack.cli.subcommand import Subcommand -from llama_stack.core.datatypes import LoggingConfig, StackRunConfig +from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro -from llama_stack.log import get_logger +from llama_stack.log import LoggingConfig, get_logger REPO_ROOT = Path(__file__).parent.parent.parent.parent diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index 172bc17b8..d7175100e 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -31,6 +31,7 @@ from llama_stack.core.storage.datatypes import ( StorageBackendType, StorageConfig, ) +from llama_stack.log import LoggingConfig from llama_stack.providers.datatypes import Api, ProviderSpec LLAMA_STACK_BUILD_CONFIG_VERSION = 2 @@ -195,14 +196,6 @@ class TelemetryConfig(BaseModel): enabled: bool = Field(default=False, description="enable or disable telemetry") -class LoggingConfig(BaseModel): - category_levels: dict[str, str] = Field( - default_factory=dict, - description=""" - Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""", - ) - - class OAuth2JWKSConfig(BaseModel): # The JWKS URI for collecting public keys uri: str diff --git a/llama_stack/core/distribution.py b/llama_stack/core/distribution.py index 82cbcf984..9be5ffb49 100644 --- a/llama_stack/core/distribution.py +++ b/llama_stack/core/distribution.py @@ -25,7 +25,7 @@ from llama_stack.providers.datatypes import ( logger = get_logger(name=__name__, category="core") -INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations, Api.telemetry} +INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations} def stack_apis() -> list[Api]: diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index c64b9a391..6203b529e 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -32,7 +32,7 @@ from termcolor import cprint from llama_stack.core.build import print_pip_install_help from llama_stack.core.configure import parse_and_maybe_upgrade_config -from llama_stack.core.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec +from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec from llama_stack.core.request_headers import ( PROVIDER_DATA_VAR, request_provider_data_context, @@ -44,11 +44,12 @@ from llama_stack.core.stack import ( get_stack_run_config_from_distro, replace_env_vars, ) +from llama_stack.core.telemetry import Telemetry +from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace from llama_stack.core.utils.config import redact_sensitive_fields from llama_stack.core.utils.context import preserve_contexts_async_generator from llama_stack.core.utils.exec import in_notebook from llama_stack.log import get_logger, setup_logging -from llama_stack.providers.utils.telemetry.tracing import CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace from 
llama_stack.strong_typing.inspection import is_unwrapped_body_param logger = get_logger(name=__name__, category="core") @@ -293,8 +294,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): raise _e assert self.impls is not None - if Api.telemetry in self.impls: - setup_logger(self.impls[Api.telemetry]) + if self.config.telemetry.enabled: + setup_logger(Telemetry()) if not os.environ.get("PYTEST_CURRENT_TEST"): console = Console() diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py index 0b63815ea..805d260fc 100644 --- a/llama_stack/core/resolver.py +++ b/llama_stack/core/resolver.py @@ -27,7 +27,6 @@ from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields -from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_stores import VectorStore @@ -49,7 +48,6 @@ from llama_stack.providers.datatypes import ( Api, BenchmarksProtocolPrivate, DatasetsProtocolPrivate, - InlineProviderSpec, ModelsProtocolPrivate, ProviderSpec, RemoteProviderConfig, @@ -98,7 +96,6 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> Api.files: Files, Api.prompts: Prompts, Api.conversations: Conversations, - Api.telemetry: Telemetry, } if external_apis: @@ -241,24 +238,6 @@ def validate_and_prepare_providers( key = api_str if api not in router_apis else f"inner-{api_str}" providers_with_specs[key] = specs - # TODO: remove this logic, telemetry should not have providers. - # if telemetry has been enabled in the config initialize our internal impl - # telemetry is not an external API so it SHOULD NOT be auto-routed. 
- if run_config.telemetry.enabled: - specs = {} - p = InlineProviderSpec( - api=Api.telemetry, - provider_type="inline::meta-reference", - pip_packages=[], - optional_api_dependencies=[Api.datasetio], - module="llama_stack.providers.inline.telemetry.meta_reference", - config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig", - description="Meta's reference implementation of telemetry and observability using OpenTelemetry.", - ) - spec = ProviderWithSpec(spec=p, provider_type="inline::meta-reference", provider_id="meta-reference") - specs["meta-reference"] = spec - providers_with_specs["telemetry"] = specs - return providers_with_specs diff --git a/llama_stack/core/routers/__init__.py b/llama_stack/core/routers/__init__.py index 2f35fe04f..204cbb87f 100644 --- a/llama_stack/core/routers/__init__.py +++ b/llama_stack/core/routers/__init__.py @@ -72,14 +72,6 @@ async def get_auto_router_impl( raise ValueError(f"API {api.value} not found in router map") api_to_dep_impl = {} - if run_config.telemetry.enabled: - api_to_deps = { - "inference": {"telemetry": Api.telemetry}, - } - for dep_name, dep_api in api_to_deps.get(api.value, {}).items(): - if dep_api in deps: - api_to_dep_impl[dep_name] = deps[dep_api] - # TODO: move pass configs to routers instead if api == Api.inference: inference_ref = run_config.storage.stores.inference @@ -92,6 +84,7 @@ async def get_auto_router_impl( ) await inference_store.initialize() api_to_dep_impl["store"] = inference_store + api_to_dep_impl["telemetry_enabled"] = run_config.telemetry.enabled elif api == Api.vector_io: api_to_dep_impl["vector_stores_config"] = run_config.vector_stores diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 09241d836..d532bc622 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -53,13 +53,13 @@ from llama_stack.apis.inference.inference import ( OpenAIChatCompletionContentPartTextParam, ) from llama_stack.apis.models import Model, ModelType -from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry +from llama_stack.apis.telemetry import MetricEvent, MetricInResponse +from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span from llama_stack.log import get_logger from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable from llama_stack.providers.utils.inference.inference_store import InferenceStore -from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span logger = get_logger(name=__name__, category="core::routers") @@ -70,14 +70,14 @@ class InferenceRouter(Inference): def __init__( self, routing_table: RoutingTable, - telemetry: Telemetry | None = None, store: InferenceStore | None = None, + telemetry_enabled: bool = False, ) -> None: logger.debug("Initializing InferenceRouter") self.routing_table = routing_table - self.telemetry = telemetry + self.telemetry_enabled = telemetry_enabled self.store = store - if self.telemetry: + if self.telemetry_enabled: self.tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(self.tokenizer) @@ -159,7 +159,7 @@ class InferenceRouter(Inference): model: Model, ) -> list[MetricInResponse]: metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) - if self.telemetry: + if self.telemetry_enabled: for metric 
in metrics: enqueue_event(metric) return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] @@ -223,7 +223,7 @@ class InferenceRouter(Inference): # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently. response = await provider.openai_completion(params) - if self.telemetry: + if self.telemetry_enabled: metrics = self._construct_metrics( prompt_tokens=response.usage.prompt_tokens, completion_tokens=response.usage.completion_tokens, @@ -285,7 +285,7 @@ class InferenceRouter(Inference): if self.store: asyncio.create_task(self.store.store_chat_completion(response, params.messages)) - if self.telemetry: + if self.telemetry_enabled: metrics = self._construct_metrics( prompt_tokens=response.usage.prompt_tokens, completion_tokens=response.usage.completion_tokens, @@ -393,7 +393,7 @@ class InferenceRouter(Inference): else: if hasattr(chunk, "delta"): completion_text += chunk.delta - if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: + if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry_enabled: complete = True completion_tokens = await self._count_tokens(completion_text) # if we are done receiving tokens @@ -401,7 +401,7 @@ class InferenceRouter(Inference): total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) # Create a separate span for streaming completion metrics - if self.telemetry: + if self.telemetry_enabled: # Log metrics in the new span context completion_metrics = self._construct_metrics( prompt_tokens=prompt_tokens, @@ -450,7 +450,7 @@ class InferenceRouter(Inference): total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) # Create a separate span for completion metrics - if self.telemetry: + if self.telemetry_enabled: # Log metrics in the new span context completion_metrics = self._construct_metrics( prompt_tokens=prompt_tokens, @@ -548,7 +548,7 @@ class InferenceRouter(Inference): completion_text += "".join(choice_data["content_parts"]) # Add metrics to the chunk - if self.telemetry and hasattr(chunk, "usage") and chunk.usage: + if self.telemetry_enabled and hasattr(chunk, "usage") and chunk.usage: metrics = self._construct_metrics( prompt_tokens=chunk.usage.prompt_tokens, completion_tokens=chunk.usage.completion_tokens, diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index 845686f15..80505c3f9 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -36,7 +36,6 @@ from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.core.access_control.access_control import AccessDeniedError from llama_stack.core.datatypes import ( AuthenticationRequiredError, - LoggingConfig, StackRunConfig, process_cors_config, ) @@ -53,19 +52,13 @@ from llama_stack.core.stack import ( cast_image_name_to_string, replace_env_vars, ) +from llama_stack.core.telemetry import Telemetry +from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, setup_logger from llama_stack.core.utils.config import redact_sensitive_fields from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro from llama_stack.core.utils.context import preserve_contexts_async_generator -from llama_stack.log import get_logger, setup_logging +from llama_stack.log import LoggingConfig, get_logger, setup_logging from llama_stack.providers.datatypes import Api -from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig -from 
llama_stack.providers.inline.telemetry.meta_reference.telemetry import ( - TelemetryAdapter, -) -from llama_stack.providers.utils.telemetry.tracing import ( - CURRENT_TRACE_CONTEXT, - setup_logger, -) from .auth import AuthenticationMiddleware from .quota import QuotaMiddleware @@ -451,9 +444,7 @@ def create_app() -> StackApp: app.add_middleware(CORSMiddleware, **cors_config.model_dump()) if config.telemetry.enabled: - setup_logger(impls[Api.telemetry]) - else: - setup_logger(TelemetryAdapter(TelemetryConfig(), {})) + setup_logger(Telemetry()) # Load external APIs if configured external_apis = load_external_apis(config) @@ -511,7 +502,8 @@ def create_app() -> StackApp: app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler) - app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis) + if config.telemetry.enabled: + app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis) return app diff --git a/llama_stack/core/server/tracing.py b/llama_stack/core/server/tracing.py index 4c6df5b42..c4901d9b1 100644 --- a/llama_stack/core/server/tracing.py +++ b/llama_stack/core/server/tracing.py @@ -7,8 +7,8 @@ from aiohttp import hdrs from llama_stack.core.external import ExternalApiSpec from llama_stack.core.server.routes import find_matching_route, initialize_route_impls +from llama_stack.core.telemetry.tracing import end_trace, start_trace from llama_stack.log import get_logger -from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace logger = get_logger(name=__name__, category="core::server") diff --git a/llama_stack/core/telemetry/__init__.py b/llama_stack/core/telemetry/__init__.py new file mode 100644 index 000000000..bab612c0d --- /dev/null +++ b/llama_stack/core/telemetry/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from .telemetry import Telemetry +from .trace_protocol import serialize_value, trace_protocol +from .tracing import ( + CURRENT_TRACE_CONTEXT, + ROOT_SPAN_MARKERS, + end_trace, + enqueue_event, + get_current_span, + setup_logger, + span, + start_trace, +) + +__all__ = [ + "Telemetry", + "trace_protocol", + "serialize_value", + "CURRENT_TRACE_CONTEXT", + "ROOT_SPAN_MARKERS", + "end_trace", + "enqueue_event", + "get_current_span", + "setup_logger", + "span", + "start_trace", +] diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/core/telemetry/telemetry.py similarity index 96% rename from llama_stack/providers/inline/telemetry/meta_reference/telemetry.py rename to llama_stack/core/telemetry/telemetry.py index b15b1e490..f0cec08ec 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/core/telemetry/telemetry.py @@ -24,14 +24,13 @@ from llama_stack.apis.telemetry import ( SpanStartPayload, SpanStatus, StructuredLogEvent, - Telemetry, UnstructuredLogEvent, ) -from llama_stack.core.datatypes import Api +from llama_stack.apis.telemetry import ( + Telemetry as TelemetryBase, +) +from llama_stack.core.telemetry.tracing import ROOT_SPAN_MARKERS from llama_stack.log import get_logger -from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS - -from .config import TelemetryConfig _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = { "active_spans": {}, @@ -50,9 +49,8 @@ def is_tracing_enabled(tracer): return span.is_recording() -class TelemetryAdapter(Telemetry): - def __init__(self, _config: TelemetryConfig, deps: dict[Api, Any]) -> None: - self.datasetio_api = deps.get(Api.datasetio) +class Telemetry(TelemetryBase): + def __init__(self) -> None: self.meter = None global _TRACER_PROVIDER diff --git a/llama_stack/providers/utils/telemetry/trace_protocol.py b/llama_stack/core/telemetry/trace_protocol.py similarity index 78% rename from llama_stack/providers/utils/telemetry/trace_protocol.py rename to llama_stack/core/telemetry/trace_protocol.py index e9320b7a8..807b8e2a9 100644 --- a/llama_stack/providers/utils/telemetry/trace_protocol.py +++ b/llama_stack/core/telemetry/trace_protocol.py @@ -9,27 +9,29 @@ import inspect import json from collections.abc import AsyncGenerator, Callable from functools import wraps -from typing import Any +from typing import Any, cast from pydantic import BaseModel from llama_stack.models.llama.datatypes import Primitive +type JSONValue = Primitive | list["JSONValue"] | dict[str, "JSONValue"] -def serialize_value(value: Any) -> Primitive: + +def serialize_value(value: Any) -> str: return str(_prepare_for_json(value)) -def _prepare_for_json(value: Any) -> str: +def _prepare_for_json(value: Any) -> JSONValue: """Serialize a single value into JSON-compatible format.""" if value is None: return "" elif isinstance(value, str | int | float | bool): return value elif hasattr(value, "_name_"): - return value._name_ + return cast(str, value._name_) elif isinstance(value, BaseModel): - return json.loads(value.model_dump_json()) + return cast(JSONValue, json.loads(value.model_dump_json())) elif isinstance(value, list | tuple | set): return [_prepare_for_json(item) for item in value] elif isinstance(value, dict): @@ -37,35 +39,35 @@ def _prepare_for_json(value: Any) -> str: else: try: json.dumps(value) - return value + return cast(JSONValue, value) except Exception: return str(value) -def trace_protocol[T](cls: type[T]) -> type[T]: +def trace_protocol[T: type[Any]](cls: T) -> T: 
""" A class decorator that automatically traces all methods in a protocol/base class and its inheriting classes. """ - def trace_method(method: Callable) -> Callable: + def trace_method(method: Callable[..., Any]) -> Callable[..., Any]: is_async = asyncio.iscoroutinefunction(method) is_async_gen = inspect.isasyncgenfunction(method) - def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple: + def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple[str, str, dict[str, Primitive]]: class_name = self.__class__.__name__ method_name = method.__name__ span_type = "async_generator" if is_async_gen else "async" if is_async else "sync" sig = inspect.signature(method) param_names = list(sig.parameters.keys())[1:] # Skip 'self' - combined_args = {} + combined_args: dict[str, str] = {} for i, arg in enumerate(args): param_name = param_names[i] if i < len(param_names) else f"position_{i + 1}" combined_args[param_name] = serialize_value(arg) for k, v in kwargs.items(): combined_args[str(k)] = serialize_value(v) - span_attributes = { + span_attributes: dict[str, Primitive] = { "__autotraced__": True, "__class__": class_name, "__method__": method_name, @@ -76,8 +78,8 @@ def trace_protocol[T](cls: type[T]) -> type[T]: return class_name, method_name, span_attributes @wraps(method) - async def async_gen_wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncGenerator: - from llama_stack.providers.utils.telemetry import tracing + async def async_gen_wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]: + from llama_stack.core.telemetry import tracing class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs) @@ -92,7 +94,7 @@ def trace_protocol[T](cls: type[T]) -> type[T]: @wraps(method) async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - from llama_stack.providers.utils.telemetry import tracing + from llama_stack.core.telemetry import tracing class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs) @@ -107,7 +109,7 @@ def trace_protocol[T](cls: type[T]) -> type[T]: @wraps(method) def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - from llama_stack.providers.utils.telemetry import tracing + from llama_stack.core.telemetry import tracing class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs) @@ -127,16 +129,17 @@ def trace_protocol[T](cls: type[T]) -> type[T]: else: return sync_wrapper - original_init_subclass = getattr(cls, "__init_subclass__", None) + original_init_subclass = cast(Callable[..., Any] | None, getattr(cls, "__init_subclass__", None)) - def __init_subclass__(cls_child, **kwargs): # noqa: N807 + def __init_subclass__(cls_child: type[Any], **kwargs: Any) -> None: # noqa: N807 if original_init_subclass: - original_init_subclass(**kwargs) + cast(Callable[..., None], original_init_subclass)(**kwargs) for name, method in vars(cls_child).items(): if inspect.isfunction(method) and not name.startswith("_"): setattr(cls_child, name, trace_method(method)) # noqa: B010 - cls.__init_subclass__ = classmethod(__init_subclass__) + cls_any = cast(Any, cls) + cls_any.__init_subclass__ = classmethod(__init_subclass__) return cls diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/core/telemetry/tracing.py similarity index 87% rename from llama_stack/providers/utils/telemetry/tracing.py rename to llama_stack/core/telemetry/tracing.py index 62cceb13e..7742ea0f4 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ 
b/llama_stack/core/telemetry/tracing.py @@ -15,7 +15,7 @@ import time from collections.abc import Callable from datetime import UTC, datetime from functools import wraps -from typing import Any +from typing import Any, Self from llama_stack.apis.telemetry import ( Event, @@ -28,8 +28,8 @@ from llama_stack.apis.telemetry import ( Telemetry, UnstructuredLogEvent, ) +from llama_stack.core.telemetry.trace_protocol import serialize_value from llama_stack.log import get_logger -from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value logger = get_logger(__name__, category="core") @@ -89,9 +89,6 @@ def generate_trace_id() -> str: return trace_id_to_str(trace_id) -CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None) -BACKGROUND_LOGGER = None - LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS = 60.0 @@ -104,7 +101,7 @@ class BackgroundLogger: self._last_queue_full_log_time: float = 0.0 self._dropped_since_last_notice: int = 0 - def log_event(self, event): + def log_event(self, event: Event) -> None: try: self.log_queue.put_nowait(event) except queue.Full: @@ -137,10 +134,13 @@ class BackgroundLogger: finally: self.log_queue.task_done() - def __del__(self): + def __del__(self) -> None: self.log_queue.join() +BACKGROUND_LOGGER: BackgroundLogger | None = None + + def enqueue_event(event: Event) -> None: """Enqueue a telemetry event to the background logger if available. @@ -155,13 +155,12 @@ def enqueue_event(event: Event) -> None: class TraceContext: - spans: list[Span] = [] - def __init__(self, logger: BackgroundLogger, trace_id: str): self.logger = logger self.trace_id = trace_id + self.spans: list[Span] = [] - def push_span(self, name: str, attributes: dict[str, Any] = None) -> Span: + def push_span(self, name: str, attributes: dict[str, Any] | None = None) -> Span: current_span = self.get_current_span() span = Span( span_id=generate_span_id(), @@ -188,7 +187,7 @@ class TraceContext: self.spans.append(span) return span - def pop_span(self, status: SpanStatus = SpanStatus.OK): + def pop_span(self, status: SpanStatus = SpanStatus.OK) -> None: span = self.spans.pop() if span is not None: self.logger.log_event( @@ -203,10 +202,15 @@ class TraceContext: ) ) - def get_current_span(self): + def get_current_span(self) -> Span | None: return self.spans[-1] if self.spans else None +CURRENT_TRACE_CONTEXT: contextvars.ContextVar[TraceContext | None] = contextvars.ContextVar( + "trace_context", default=None +) + + def setup_logger(api: Telemetry, level: int = logging.INFO): global BACKGROUND_LOGGER @@ -217,12 +221,12 @@ def setup_logger(api: Telemetry, level: int = logging.INFO): root_logger.addHandler(TelemetryHandler()) -async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceContext: +async def start_trace(name: str, attributes: dict[str, Any] | None = None) -> TraceContext | None: global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER if BACKGROUND_LOGGER is None: logger.debug("No Telemetry implementation set. 
Skipping trace initialization...") - return + return None trace_id = generate_trace_id() context = TraceContext(BACKGROUND_LOGGER, trace_id) @@ -269,7 +273,7 @@ def severity(levelname: str) -> LogSeverity: # TODO: ideally, the actual emitting should be done inside a separate daemon # process completely isolated from the server class TelemetryHandler(logging.Handler): - def emit(self, record: logging.LogRecord): + def emit(self, record: logging.LogRecord) -> None: # horrendous hack to avoid logging from asyncio and getting into an infinite loop if record.module in ("asyncio", "selector_events"): return @@ -293,17 +297,17 @@ class TelemetryHandler(logging.Handler): ) ) - def close(self): + def close(self) -> None: pass class SpanContextManager: - def __init__(self, name: str, attributes: dict[str, Any] = None): + def __init__(self, name: str, attributes: dict[str, Any] | None = None): self.name = name self.attributes = attributes - self.span = None + self.span: Span | None = None - def __enter__(self): + def __enter__(self) -> Self: global CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT.get() if not context: @@ -313,7 +317,7 @@ class SpanContextManager: self.span = context.push_span(self.name, self.attributes) return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> None: global CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT.get() if not context: @@ -322,13 +326,13 @@ class SpanContextManager: context.pop_span() - def set_attribute(self, key: str, value: Any): + def set_attribute(self, key: str, value: Any) -> None: if self.span: if self.span.attributes is None: self.span.attributes = {} self.span.attributes[key] = serialize_value(value) - async def __aenter__(self): + async def __aenter__(self) -> Self: global CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT.get() if not context: @@ -338,7 +342,7 @@ class SpanContextManager: self.span = context.push_span(self.name, self.attributes) return self - async def __aexit__(self, exc_type, exc_value, traceback): + async def __aexit__(self, exc_type, exc_value, traceback) -> None: global CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT.get() if not context: @@ -347,19 +351,19 @@ class SpanContextManager: context.pop_span() - def __call__(self, func: Callable): + def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) - def sync_wrapper(*args, **kwargs): + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: with self: return func(*args, **kwargs) @wraps(func) - async def async_wrapper(*args, **kwargs): + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: async with self: return await func(*args, **kwargs) @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: if asyncio.iscoroutinefunction(func): return async_wrapper(*args, **kwargs) else: @@ -368,7 +372,7 @@ class SpanContextManager: return wrapper -def span(name: str, attributes: dict[str, Any] = None): +def span(name: str, attributes: dict[str, Any] | None = None) -> SpanContextManager: return SpanContextManager(name, attributes) diff --git a/llama_stack/log.py b/llama_stack/log.py index 15e628cc3..c11c2c06f 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -9,15 +9,23 @@ import os import re from logging.config import dictConfig # allow-direct-logging +from pydantic import BaseModel, Field from rich.console import Console from rich.errors import MarkupError from rich.logging import RichHandler -from llama_stack.core.datatypes 
import LoggingConfig - # Default log level DEFAULT_LOG_LEVEL = logging.INFO + +class LoggingConfig(BaseModel): + category_levels: dict[str, str] = Field( + default_factory=dict, + description=""" +Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""", + ) + + # Predefined categories CATEGORIES = [ "core", diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 96f271669..9fd3f7d0e 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -67,6 +67,7 @@ from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.core.datatypes import AccessRule +from llama_stack.core.telemetry import tracing from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( BuiltinTool, @@ -78,7 +79,6 @@ from llama_stack.providers.utils.inference.openai_compat import ( convert_tooldef_to_openai_tool, ) from llama_stack.providers.utils.kvstore import KVStore -from llama_stack.providers.utils.telemetry import tracing from .persistence import AgentPersistence from .safety import SafetyException, ShieldRunnerMixin diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index e80ffcdd1..f0bafff21 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -65,9 +65,9 @@ from llama_stack.apis.inference import ( OpenAIChoice, OpenAIMessageParam, ) +from llama_stack.core.telemetry import tracing from llama_stack.log import get_logger from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str -from llama_stack.providers.utils.telemetry import tracing from .types import ChatCompletionContext, ChatCompletionResult from .utils import ( diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py index 659dc599e..8e0dc9ecb 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py @@ -37,8 +37,8 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.vector_io import VectorIO +from llama_stack.core.telemetry import tracing from llama_stack.log import get_logger -from llama_stack.providers.utils.telemetry import tracing from .types import ChatCompletionContext, ToolExecutionResult diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index 8f3ecf5c9..9baf5a14d 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -8,8 +8,8 @@ import asyncio from llama_stack.apis.inference import Message from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel +from llama_stack.core.telemetry import tracing from llama_stack.log import get_logger -from llama_stack.providers.utils.telemetry 
import tracing log = get_logger(name=__name__, category="agents::meta_reference") diff --git a/llama_stack/providers/inline/telemetry/__init__.py b/llama_stack/providers/inline/telemetry/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/telemetry/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py deleted file mode 100644 index 21743b653..000000000 --- a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any - -from llama_stack.core.datatypes import Api - -from .config import TelemetryConfig, TelemetrySink - -__all__ = ["TelemetryConfig", "TelemetrySink"] - - -async def get_provider_impl(config: TelemetryConfig, deps: dict[Api, Any]): - from .telemetry import TelemetryAdapter - - impl = TelemetryAdapter(config, deps) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/telemetry/meta_reference/config.py b/llama_stack/providers/inline/telemetry/meta_reference/config.py deleted file mode 100644 index 088dd8439..000000000 --- a/llama_stack/providers/inline/telemetry/meta_reference/config.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from enum import StrEnum -from typing import Any - -from pydantic import BaseModel, Field, field_validator - - -class TelemetrySink(StrEnum): - OTEL_TRACE = "otel_trace" - OTEL_METRIC = "otel_metric" - CONSOLE = "console" - - -class TelemetryConfig(BaseModel): - otel_exporter_otlp_endpoint: str | None = Field( - default=None, - description="The OpenTelemetry collector endpoint URL (base URL for traces, metrics, and logs). 
If not set, the SDK will use OTEL_EXPORTER_OTLP_ENDPOINT environment variable.", - ) - service_name: str = Field( - # service name is always the same, use zero-width space to avoid clutter - default="\u200b", - description="The service name to use for telemetry", - ) - sinks: list[TelemetrySink] = Field( - default_factory=list, - description="List of telemetry sinks to enable (possible values: otel_trace, otel_metric, console)", - ) - - @field_validator("sinks", mode="before") - @classmethod - def validate_sinks(cls, v): - if isinstance(v, str): - return [TelemetrySink(sink.strip()) for sink in v.split(",")] - return v or [] - - @classmethod - def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: - return { - "service_name": "${env.OTEL_SERVICE_NAME:=\u200b}", - "sinks": "${env.TELEMETRY_SINKS:=}", - "otel_exporter_otlp_endpoint": "${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}", - } diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index 2c051719b..b31f1f5e8 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -22,11 +22,11 @@ from llama_stack.apis.inference.inference import ( ) from llama_stack.apis.models import Model from llama_stack.apis.models.models import ModelType +from llama_stack.core.telemetry.tracing import get_current_span from llama_stack.log import get_logger from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params -from llama_stack.providers.utils.telemetry.tracing import get_current_span logger = get_logger(name=__name__, category="providers::remote::watsonx") diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 42b89f897..3eef1f272 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -256,7 +256,7 @@ class LiteLLMOpenAIMixin( params: OpenAIChatCompletionRequestWithExtraBody, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: # Add usage tracking for streaming when telemetry is active - from llama_stack.providers.utils.telemetry.tracing import get_current_span + from llama_stack.core.telemetry.tracing import get_current_span stream_options = params.stream_options if params.stream and get_current_span() is not None: diff --git a/llama_stack/providers/utils/telemetry/__init__.py b/llama_stack/providers/utils/telemetry/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/utils/telemetry/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
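
With `llama_stack/providers/utils/telemetry/` deleted, every call site in this diff now imports its tracing helpers from `llama_stack.core.telemetry` instead. A minimal sketch of what a caller looks like against the moved module — `handle_request` and its attribute keys are hypothetical and not part of this PR, only the imported helpers are:

```python
from llama_stack.core.telemetry.tracing import end_trace, get_current_span, span, start_trace


async def handle_request(payload: dict) -> dict:
    # Until setup_logger() has installed a BackgroundLogger, start_trace() just
    # logs a debug message and returns None (see the tracing.py hunk above).
    await start_trace("handle_request", {"request_size": len(payload)})
    try:
        # span() returns a SpanContextManager usable as a sync or async context manager.
        async with span("process", {"stage": "main"}) as s:
            s.set_attribute("num_keys", len(payload))
            result = {"ok": True}
        # Providers such as the LiteLLM mixin only check get_current_span() is not None
        # before enabling usage tracking; the same guard works for caller code.
        if get_current_span() is not None:
            result["traced"] = True
        return result
    finally:
        await end_trace()
```
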
diff --git a/tests/integration/telemetry/conftest.py b/tests/integration/telemetry/conftest.py index d11f00c9f..b055e47ac 100644 --- a/tests/integration/telemetry/conftest.py +++ b/tests/integration/telemetry/conftest.py @@ -23,7 +23,7 @@ from opentelemetry.sdk.trace import ReadableSpan, TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -import llama_stack.providers.inline.telemetry.meta_reference.telemetry as telemetry_module +import llama_stack.core.telemetry.telemetry as telemetry_module from llama_stack.testing.api_recorder import patch_httpx_for_test_id from tests.integration.fixtures.common import instantiate_llama_stack_client diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index 3b0643a13..4161d7b84 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -196,8 +196,6 @@ class TestProviderRegistry: assert internal_api not in apis, f"Internal API {internal_api} should not be in providable_apis" for api in apis: - if api == Api.telemetry: - continue module_name = f"llama_stack.providers.registry.{api.name.lower()}" try: importlib.import_module(module_name)
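
Taken together, the server and router changes replace the old `Api.telemetry` provider lookup with a plain config flag: `setup_logger(Telemetry())` runs only when `config.telemetry.enabled` is set, and `InferenceRouter` receives `telemetry_enabled` instead of a `Telemetry` implementation. A condensed sketch of the resulting wiring — the `wire_telemetry` helper and its arguments are illustrative stand-ins, not code from this PR:

```python
from llama_stack.core.routers.inference import InferenceRouter
from llama_stack.core.telemetry import Telemetry
from llama_stack.core.telemetry.tracing import setup_logger


def wire_telemetry(config, routing_table, inference_store):
    """Hypothetical helper mirroring the wiring in server.py / routers/__init__.py."""
    if config.telemetry.enabled:
        # Telemetry() no longer takes a provider config or an Api deps dict.
        setup_logger(Telemetry())

    # The router now gets a boolean flag rather than an Api.telemetry impl.
    return InferenceRouter(
        routing_table,
        store=inference_store,
        telemetry_enabled=config.telemetry.enabled,
    )
```

Note that `TracingMiddleware` is likewise only added when telemetry is enabled, so a deployment with `telemetry.enabled: false` skips trace setup entirely rather than falling back to a no-op `TelemetryAdapter` as before.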