diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index 0ec7358037..cfc2847de8 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -10,6 +10,7 @@ from litellm.types.services import ServiceLoggerPayload from litellm.types.utils import ( ChatCompletionMessageToolCall, Function, + StandardCallbackDynamicParams, StandardLoggingPayload, ) @@ -311,6 +312,8 @@ class OpenTelemetry(CustomLogger): ) _parent_context, parent_otel_span = self._get_span_context(kwargs) + self._add_dynamic_span_processor_if_needed(kwargs) + # Span 1: Requst sent to litellm SDK span = self.tracer.start_span( name=self._get_span_name(kwargs), @@ -341,6 +344,43 @@ class OpenTelemetry(CustomLogger): if parent_otel_span is not None: parent_otel_span.end(end_time=self._to_ns(datetime.now())) + def _add_dynamic_span_processor_if_needed(self, kwargs): + """ + Helper method to add a span processor with dynamic headers if needed. + + This allows for per-request configuration of telemetry exporters by + extracting headers from standard_callback_dynamic_params. + """ + from opentelemetry import trace + + standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( + kwargs.get("standard_callback_dynamic_params") + ) + if not standard_callback_dynamic_params: + return + + # Extract headers from dynamic params + dynamic_headers = {} + + # Handle Arize headers + if standard_callback_dynamic_params.get("arize_space_key"): + dynamic_headers["space_key"] = standard_callback_dynamic_params.get( + "arize_space_key" + ) + if standard_callback_dynamic_params.get("arize_api_key"): + dynamic_headers["api_key"] = standard_callback_dynamic_params.get( + "arize_api_key" + ) + + # Only create a span processor if we have headers to use + if len(dynamic_headers) > 0: + from opentelemetry.sdk.trace import TracerProvider + + provider = trace.get_tracer_provider() + if isinstance(provider, TracerProvider): + span_processor = self._get_span_processor(dynamic_headers) + provider.add_span_processor(span_processor) + def _handle_failure(self, kwargs, response_obj, start_time, end_time): from opentelemetry.trace import Status, StatusCode @@ -445,12 +485,15 @@ class OpenTelemetry(CustomLogger): try: if self.callback_name == "arize": from litellm.integrations.arize.arize import ArizeLogger + ArizeLogger.set_arize_attributes(span, kwargs, response_obj) return elif self.callback_name == "arize_phoenix": from litellm.integrations.arize.arize_phoenix import ArizePhoenixLogger - ArizePhoenixLogger.set_arize_phoenix_attributes(span, kwargs, response_obj) + ArizePhoenixLogger.set_arize_phoenix_attributes( + span, kwargs, response_obj + ) return elif self.callback_name == "langtrace": from litellm.integrations.langtrace import LangtraceAttributes @@ -779,7 +822,7 @@ class OpenTelemetry(CustomLogger): carrier = {"traceparent": traceparent} return TraceContextTextMapPropagator().extract(carrier=carrier), None - def _get_span_processor(self): + def _get_span_processor(self, dynamic_headers: Optional[dict] = None): from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( OTLPSpanExporter as OTLPSpanExporterGRPC, ) @@ -799,10 +842,9 @@ class OpenTelemetry(CustomLogger): self.OTEL_ENDPOINT, self.OTEL_HEADERS, ) - _split_otel_headers = {} - if self.OTEL_HEADERS is not None and isinstance(self.OTEL_HEADERS, str): - _split_otel_headers = self.OTEL_HEADERS.split("=") - _split_otel_headers = {_split_otel_headers[0]: _split_otel_headers[1]} + _split_otel_headers = OpenTelemetry._get_headers_dictionary( + headers=dynamic_headers or self.OTEL_HEADERS + ) if isinstance(self.OTEL_EXPORTER, SpanExporter): verbose_logger.debug( @@ -844,6 +886,20 @@ class OpenTelemetry(CustomLogger): ) return BatchSpanProcessor(ConsoleSpanExporter()) + @staticmethod + def _get_headers_dictionary(headers: Optional[Union[str, dict]]) -> dict: + """ + Convert a string or dictionary of headers into a dictionary of headers. + """ + _split_otel_headers = {} + if headers: + if isinstance(headers, str): + _split_otel_headers = headers.split("=") + _split_otel_headers = {_split_otel_headers[0]: _split_otel_headers[1]} + elif isinstance(headers, dict): + _split_otel_headers = headers + return _split_otel_headers + async def async_management_endpoint_success_hook( self, logging_payload: ManagementEndpointLoggingPayload,