Fix telemetry when re-instantiating the library CLI (#761)

# What does this PR do?

Because our telemetry pipeline maintains global state, re-instantiating
the library CLI adds duplicate span processors. This causes SQLite to
lock up because of constraint violations, since two span processors are
then writing to SQLite. This PR changes the OTel telemetry adapter to
instantiate the tracer provider only once and to add the span processors
only once.

Also fixes an issue in `llama stack build`.


## Test Plan

Tested with the notebook at
https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d#scrollTo=9496f75c
This commit is contained in:
Dinesh Yeduguru 2025-01-14 11:31:50 -08:00 committed by GitHub
parent 194d12b304
commit a174938fbd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 30 additions and 31 deletions

View file

@ -4,9 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import argparse import argparse
import importlib.resources import importlib.resources
import os import os
import shutil import shutil
from functools import lru_cache from functools import lru_cache
@ -14,14 +12,12 @@ from pathlib import Path
from typing import List, Optional from typing import List, Optional
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
BuildConfig, BuildConfig,
DistributionSpec, DistributionSpec,
Provider, Provider,
StackRunConfig, StackRunConfig,
) )
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -296,6 +292,7 @@ class StackBuild(Subcommand):
/ f"templates/{template_name}/run.yaml" / f"templates/{template_name}/run.yaml"
) )
with importlib.resources.as_file(template_path) as path: with importlib.resources.as_file(template_path) as path:
run_config_file = build_dir / f"{build_config.name}-run.yaml"
shutil.copy(path, run_config_file) shutil.copy(path, run_config_file)
# Find all ${env.VARIABLE} patterns # Find all ${env.VARIABLE} patterns
cprint("Build Successful!", color="green") cprint("Build Successful!", color="green")

View file

@ -30,13 +30,10 @@ from llama_stack.apis.telemetry import (
Trace, Trace,
UnstructuredLogEvent, UnstructuredLogEvent,
) )
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
ConsoleSpanProcessor, ConsoleSpanProcessor,
) )
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import ( from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
SQLiteSpanProcessor, SQLiteSpanProcessor,
) )
@ -52,6 +49,7 @@ _GLOBAL_STORAGE = {
"up_down_counters": {}, "up_down_counters": {},
} }
_global_lock = threading.Lock() _global_lock = threading.Lock()
_TRACER_PROVIDER = None
def string_to_trace_id(s: str) -> int: def string_to_trace_id(s: str) -> int:
@ -80,31 +78,34 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
} }
) )
provider = TracerProvider(resource=resource) global _TRACER_PROVIDER
trace.set_tracer_provider(provider) if _TRACER_PROVIDER is None:
if TelemetrySink.OTEL in self.config.sinks: provider = TracerProvider(resource=resource)
otlp_exporter = OTLPSpanExporter( trace.set_tracer_provider(provider)
endpoint=self.config.otel_endpoint, _TRACER_PROVIDER = provider
) if TelemetrySink.OTEL in self.config.sinks:
span_processor = BatchSpanProcessor(otlp_exporter) otlp_exporter = OTLPSpanExporter(
trace.get_tracer_provider().add_span_processor(span_processor)
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_endpoint, endpoint=self.config.otel_endpoint,
) )
) span_processor = BatchSpanProcessor(otlp_exporter)
metric_provider = MeterProvider( trace.get_tracer_provider().add_span_processor(span_processor)
resource=resource, metric_readers=[metric_reader] metric_reader = PeriodicExportingMetricReader(
) OTLPMetricExporter(
metrics.set_meter_provider(metric_provider) endpoint=self.config.otel_endpoint,
self.meter = metrics.get_meter(__name__) )
if TelemetrySink.SQLITE in self.config.sinks: )
trace.get_tracer_provider().add_span_processor( metric_provider = MeterProvider(
SQLiteSpanProcessor(self.config.sqlite_db_path) resource=resource, metric_readers=[metric_reader]
) )
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path) metrics.set_meter_provider(metric_provider)
if TelemetrySink.CONSOLE in self.config.sinks: self.meter = metrics.get_meter(__name__)
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor()) if TelemetrySink.SQLITE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(
SQLiteSpanProcessor(self.config.sqlite_db_path)
)
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
if TelemetrySink.CONSOLE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
self._lock = _global_lock self._lock = _global_lock
async def initialize(self) -> None: async def initialize(self) -> None:

View file

@ -127,7 +127,8 @@ class TraceContext:
def setup_logger(api: Telemetry, level: int = logging.INFO): def setup_logger(api: Telemetry, level: int = logging.INFO):
global BACKGROUND_LOGGER global BACKGROUND_LOGGER
BACKGROUND_LOGGER = BackgroundLogger(api) if BACKGROUND_LOGGER is None:
BACKGROUND_LOGGER = BackgroundLogger(api)
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(level) logger.setLevel(level)
logger.addHandler(TelemetryHandler()) logger.addHandler(TelemetryHandler())