use _get_headers_dictionary

This commit is contained in:
Ishaan Jaff 2025-03-18 14:55:39 -07:00
parent 57d08531a1
commit b940c969fd

View file

@ -10,6 +10,7 @@ from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import (
ChatCompletionMessageToolCall,
Function,
StandardCallbackDynamicParams,
StandardLoggingPayload,
)
@ -311,6 +312,8 @@ class OpenTelemetry(CustomLogger):
)
_parent_context, parent_otel_span = self._get_span_context(kwargs)
self._add_dynamic_span_processor_if_needed(kwargs)
# Span 1: Requst sent to litellm SDK
span = self.tracer.start_span(
name=self._get_span_name(kwargs),
@ -341,6 +344,43 @@ class OpenTelemetry(CustomLogger):
if parent_otel_span is not None:
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
def _add_dynamic_span_processor_if_needed(self, kwargs):
"""
Helper method to add a span processor with dynamic headers if needed.
This allows for per-request configuration of telemetry exporters by
extracting headers from standard_callback_dynamic_params.
"""
from opentelemetry import trace
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
kwargs.get("standard_callback_dynamic_params")
)
if not standard_callback_dynamic_params:
return
# Extract headers from dynamic params
dynamic_headers = {}
# Handle Arize headers
if standard_callback_dynamic_params.get("arize_space_key"):
dynamic_headers["space_key"] = standard_callback_dynamic_params.get(
"arize_space_key"
)
if standard_callback_dynamic_params.get("arize_api_key"):
dynamic_headers["api_key"] = standard_callback_dynamic_params.get(
"arize_api_key"
)
# Only create a span processor if we have headers to use
if len(dynamic_headers) > 0:
from opentelemetry.sdk.trace import TracerProvider
provider = trace.get_tracer_provider()
if isinstance(provider, TracerProvider):
span_processor = self._get_span_processor(dynamic_headers)
provider.add_span_processor(span_processor)
def _handle_failure(self, kwargs, response_obj, start_time, end_time):
from opentelemetry.trace import Status, StatusCode
@ -445,12 +485,15 @@ class OpenTelemetry(CustomLogger):
try:
if self.callback_name == "arize":
from litellm.integrations.arize.arize import ArizeLogger
ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
return
elif self.callback_name == "arize_phoenix":
from litellm.integrations.arize.arize_phoenix import ArizePhoenixLogger
ArizePhoenixLogger.set_arize_phoenix_attributes(span, kwargs, response_obj)
ArizePhoenixLogger.set_arize_phoenix_attributes(
span, kwargs, response_obj
)
return
elif self.callback_name == "langtrace":
from litellm.integrations.langtrace import LangtraceAttributes
@ -779,7 +822,7 @@ class OpenTelemetry(CustomLogger):
carrier = {"traceparent": traceparent}
return TraceContextTextMapPropagator().extract(carrier=carrier), None
def _get_span_processor(self):
def _get_span_processor(self, dynamic_headers: Optional[dict] = None):
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterGRPC,
)
@ -799,10 +842,9 @@ class OpenTelemetry(CustomLogger):
self.OTEL_ENDPOINT,
self.OTEL_HEADERS,
)
_split_otel_headers = {}
if self.OTEL_HEADERS is not None and isinstance(self.OTEL_HEADERS, str):
_split_otel_headers = self.OTEL_HEADERS.split("=")
_split_otel_headers = {_split_otel_headers[0]: _split_otel_headers[1]}
_split_otel_headers = OpenTelemetry._get_headers_dictionary(
headers=dynamic_headers or self.OTEL_HEADERS
)
if isinstance(self.OTEL_EXPORTER, SpanExporter):
verbose_logger.debug(
@ -844,6 +886,20 @@ class OpenTelemetry(CustomLogger):
)
return BatchSpanProcessor(ConsoleSpanExporter())
@staticmethod
def _get_headers_dictionary(headers: Optional[Union[str, dict]]) -> dict:
"""
Convert a string or dictionary of headers into a dictionary of headers.
"""
_split_otel_headers = {}
if headers:
if isinstance(headers, str):
_split_otel_headers = headers.split("=")
_split_otel_headers = {_split_otel_headers[0]: _split_otel_headers[1]}
elif isinstance(headers, dict):
_split_otel_headers = headers
return _split_otel_headers
async def async_management_endpoint_success_hook(
self,
logging_payload: ManagementEndpointLoggingPayload,