From e9d5a7735b605348ac98d7ded5090fefc616cae3 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Fri, 22 Nov 2024 15:04:37 -0800
Subject: [PATCH] use global state in open telemetry provider

---
Reviewer notes (free text after the three-dash line; not part of the commit
message and ignored when the patch is applied):

  What the change does
  * Replaces the per-instance caches (_counters, _gauges, _up_down_counters,
    _active_spans) with one module-level _GLOBAL_STORAGE dict, and points
    every adapter's self._lock at the shared module-level _global_lock.
  * A span's parent is now looked up through payload.parent_span_id instead
    of taking the first active span that is still recording.
  * Drops the trace_id / SpanContext construction from _log_structured.

  Points to confirm before merging
  * The three _get_or_create_* helpers do a check-then-insert on the shared
    dicts and do not take the lock themselves; _log_metric's visible lines
    do not take it either. Confirm the callers serialize access.
  * Instruments are created from self.meter but cached process-wide by name
    only, so a second adapter instance would reuse instruments created from
    the first instance's meter.
  * A SpanStartPayload whose parent_span_id is not in active_spans silently
    starts a root span; entries are removed only on SpanEndPayload, so a
    span that never ends stays in the global dict.
  * `_ = trace.set_span_in_context(span)` discards its result, and
    `trace.use_span(...)` is never entered with a `with` statement. This
    looks like neither line makes the span current -- verify against the
    OpenTelemetry API.
  * The hunk adds an empty line between `def _log_structured(...)` and
    `with self._lock:`; presumably unintentional.

 .../telemetry/opentelemetry/opentelemetry.py  | 78 +++++++++----------
 1 file changed, 38 insertions(+), 40 deletions(-)

diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
index b520f078d..01edd1692 100644
--- a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
+++ b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
@@ -23,6 +23,15 @@ from llama_stack.apis.telemetry import *  # noqa: F403
 
 from .config import OpenTelemetryConfig
 
+# Add global storage
+_GLOBAL_STORAGE = {
+    "active_spans": {},
+    "counters": {},
+    "gauges": {},
+    "up_down_counters": {},
+}
+_global_lock = threading.Lock()
+
 
 def string_to_trace_id(s: str) -> int:
     # Convert the string to bytes and then to an integer
@@ -67,12 +76,7 @@ class OpenTelemetryAdapter(Telemetry):
         )
         metrics.set_meter_provider(metric_provider)
         self.meter = metrics.get_meter(__name__)
-        # Initialize metric storage
-        self._counters = {}
-        self._gauges = {}
-        self._up_down_counters = {}
-        self._active_spans = {}
-        self._lock = threading.Lock()
+        self._lock = _global_lock
 
     async def initialize(self) -> None:
         pass
@@ -92,12 +96,11 @@ class OpenTelemetryAdapter(Telemetry):
 
     def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
         with self._lock:
-            # Check if there's an existing span in the cache
+            # Use global storage instead of instance storage
             span_id = string_to_span_id(event.span_id)
-            span = self._active_spans.get(span_id)
+            span = _GLOBAL_STORAGE["active_spans"].get(span_id)
 
             if span:
-                # Use existing span
                 timestamp_ns = int(event.timestamp.timestamp() * 1e9)
                 span.add_event(
                     name=event.message,
@@ -110,22 +113,22 @@ class OpenTelemetryAdapter(Telemetry):
                 )
 
     def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
-        if name not in self._counters:
-            self._counters[name] = self.meter.create_counter(
+        if name not in _GLOBAL_STORAGE["counters"]:
+            _GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
                 name=name,
                 unit=unit,
                 description=f"Counter for {name}",
             )
-        return self._counters[name]
+        return _GLOBAL_STORAGE["counters"][name]
 
     def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
-        if name not in self._gauges:
-            self._gauges[name] = self.meter.create_gauge(
+        if name not in _GLOBAL_STORAGE["gauges"]:
+            _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
                 name=name,
                 unit=unit,
                 description=f"Gauge for {name}",
             )
-        return self._gauges[name]
+        return _GLOBAL_STORAGE["gauges"][name]
 
     def _log_metric(self, event: MetricEvent) -> None:
         if isinstance(event.value, int):
@@ -140,56 +143,51 @@ class OpenTelemetryAdapter(Telemetry):
     def _get_or_create_up_down_counter(
         self, name: str, unit: str
     ) -> metrics.UpDownCounter:
-        if name not in self._up_down_counters:
-            self._up_down_counters[name] = self.meter.create_up_down_counter(
-                name=name,
-                unit=unit,
-                description=f"UpDownCounter for {name}",
+        if name not in _GLOBAL_STORAGE["up_down_counters"]:
+            _GLOBAL_STORAGE["up_down_counters"][name] = (
+                self.meter.create_up_down_counter(
+                    name=name,
+                    unit=unit,
+                    description=f"UpDownCounter for {name}",
+                )
             )
-        return self._up_down_counters[name]
+        return _GLOBAL_STORAGE["up_down_counters"][name]
 
     def _log_structured(self, event: StructuredLogEvent) -> None:
+
         with self._lock:
-            trace_id = string_to_trace_id(event.trace_id)
             span_id = string_to_span_id(event.span_id)
 
             tracer = trace.get_tracer(__name__)
-            span_context = trace.SpanContext(
-                trace_id=trace_id,
-                span_id=span_id,
-                is_remote=True,
-                trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED),
-                trace_state=trace.TraceState(),
-            )
 
             if isinstance(event.payload, SpanStartPayload):
-                # Get parent span if it exists
+                # Find parent span from active spans
                 parent_span = None
-                for active_span in self._active_spans.values():
-                    if active_span.is_recording():
-                        parent_span = active_span
-                        break
+                if event.payload.parent_span_id:
+                    parent_span_id = string_to_span_id(event.payload.parent_span_id)
+                    parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
 
-                # Create the context properly
+                # Create context with parent span if it exists
                 context = trace.Context()
                 if parent_span:
                     context = trace.set_span_in_context(parent_span)
 
+                # Create new span
                 span = tracer.start_span(
                     name=event.payload.name,
                     context=context,
                     attributes=event.attributes or {},
                     start_time=int(event.timestamp.timestamp() * 1e9),
                 )
-                self._active_spans[span_id] = span
+                _GLOBAL_STORAGE["active_spans"][span_id] = span
 
-                # Set the span as current
+                # Set as current span
                 _ = trace.set_span_in_context(span)
                 trace.use_span(span, end_on_exit=False)
 
             elif isinstance(event.payload, SpanEndPayload):
-                # Retrieve and end the existing span
-                span = self._active_spans.get(span_id)
+                # End existing span
+                span = _GLOBAL_STORAGE["active_spans"].get(span_id)
                 if span:
                     if event.attributes:
                         span.set_attributes(event.attributes)
@@ -203,7 +201,7 @@ class OpenTelemetryAdapter(Telemetry):
                     span.end(end_time=int(event.timestamp.timestamp() * 1e9))
 
                     # Remove from active spans
-                    del self._active_spans[span_id]
+                    del _GLOBAL_STORAGE["active_spans"][span_id]
 
     async def get_trace(self, trace_id: str) -> Trace:
         raise NotImplementedError("Trace retrieval not implemented yet")