diff --git a/README.md b/README.md
index e13732000..bfdba2fa3 100644
--- a/README.md
+++ b/README.md
@@ -113,7 +113,7 @@ for part in response:
## Logging Observability ([Docs](https://docs.litellm.ai/docs/observability/callbacks))
-LiteLLM exposes pre defined callbacks to send data to Lunary, Langfuse, DynamoDB, s3 Buckets, Helicone, Promptlayer, Traceloop, Athina, Slack
+LiteLLM exposes pre defined callbacks to send data to Lunary, Langfuse, DynamoDB, s3 Buckets, Helicone, Promptlayer, Traceloop, Athina, Slack, MLflow
```python
from litellm import completion
diff --git a/docs/my-website/docs/observability/mlflow.md b/docs/my-website/docs/observability/mlflow.md
new file mode 100644
index 000000000..3b1e1d477
--- /dev/null
+++ b/docs/my-website/docs/observability/mlflow.md
@@ -0,0 +1,108 @@
+# MLflow
+
+## What is MLflow?
+
+**MLflow** is an end-to-end open source MLOps platform for [experiment tracking](https://www.mlflow.org/docs/latest/tracking.html), [model management](https://www.mlflow.org/docs/latest/models.html), [evaluation](https://www.mlflow.org/docs/latest/llms/llm-evaluate/index.html), [observability (tracing)](https://www.mlflow.org/docs/latest/llms/tracing/index.html), and [deployment](https://www.mlflow.org/docs/latest/deployment/index.html). MLflow empowers teams to collaboratively develop and refine LLM applications efficiently.
+
+MLflow’s integration with LiteLLM supports advanced observability compatible with OpenTelemetry.
+
+
+
+
+
+## Getting Started
+
+Install MLflow:
+
+```shell
+pip install mlflow
+```
+
+To enable LiteLLM tracing:
+
+```python
+import mlflow
+
+mlflow.litellm.autolog()
+
+# Alternative, you can set the callback manually in LiteLLM
+# litellm.callbacks = ["mlflow"]
+```
+
+Since MLflow is open-source, no sign-up or API key is needed to log traces!
+
+```
+import litellm
+import os
+
+# Set your LLM provider's API key
+os.environ["OPENAI_API_KEY"] = ""
+
+# Call LiteLLM as usual
+response = litellm.completion(
+ model="gpt-4o-mini",
+ messages=[
+ {"role": "user", "content": "Hi 👋 - i'm openai"}
+ ]
+)
+```
+
+Open the MLflow UI and go to the `Traces` tab to view logged traces:
+
+```bash
+mlflow ui
+```
+
+## Exporting Traces to OpenTelemetry collectors
+
+MLflow traces are compatible with OpenTelemetry. You can export traces to any OpenTelemetry collector (e.g., Jaeger, Zipkin, Datadog, New Relic) by setting the endpoint URL in the environment variables.
+
+```
+# Set the endpoint of the OpenTelemetry Collector
+os.environ["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] = "http://localhost:4317/v1/traces"
+# Optionally, set the service name to group traces
+os.environ["OTEL_SERVICE_NAME"] = ""
+```
+
+See [MLflow documentation](https://mlflow.org/docs/latest/llms/tracing/index.html#using-opentelemetry-collector-for-exporting-traces) for more details.
+
+## Combine LiteLLM Trace with Your Application Trace
+
+LiteLLM is often part of larger LLM applications, such as agentic models. MLflow Tracing allows you to instrument custom Python code, which can then be combined with LiteLLM traces.
+
+```python
+import litellm
+import mlflow
+from mlflow.entities import SpanType
+
+# Enable LiteLLM tracing
+mlflow.litellm.autolog()
+
+
+class CustomAgent:
+ # Use @mlflow.trace to instrument Python functions.
+ @mlflow.trace(span_type=SpanType.AGENT)
+ def run(self, query: str):
+ # do something
+
+ while i < self.max_turns:
+ response = litellm.completion(
+ model="gpt-4o-mini",
+ messages=messages,
+ )
+
+ action = self.get_action(response)
+ ...
+
+ @mlflow.trace
+ def get_action(llm_response):
+ ...
+```
+
+This approach generates a unified trace, combining your custom Python code with LiteLLM calls.
+
+
+## Support
+
+* For advanced usage and integrations of tracing, visit the [MLflow Tracing documentation](https://mlflow.org/docs/latest/llms/tracing/index.html).
+* For any question or issue with this integration, please [submit an issue](https://github.com/mlflow/mlflow/issues/new/choose) on our [Github](https://github.com/mlflow/mlflow) repository!
\ No newline at end of file
diff --git a/docs/my-website/img/mlflow_tracing.png b/docs/my-website/img/mlflow_tracing.png
new file mode 100644
index 000000000..aee1fb79e
Binary files /dev/null and b/docs/my-website/img/mlflow_tracing.png differ
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 9812de1d8..e54117e11 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -57,6 +57,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
"gcs_bucket",
"opik",
"argilla",
+ "mlflow",
]
logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
_known_custom_logger_compatible_callbacks: List = list(
diff --git a/litellm/integrations/mlflow.py b/litellm/integrations/mlflow.py
new file mode 100644
index 000000000..baf33a86b
--- /dev/null
+++ b/litellm/integrations/mlflow.py
@@ -0,0 +1,246 @@
+import json
+import threading
+from typing import Optional
+
+from litellm.integrations.custom_logger import CustomLogger
+from litellm._logging import verbose_logger
+
+class MlflowLogger(CustomLogger):
+ def __init__(self):
+ from mlflow.tracking import MlflowClient
+
+ self._client = MlflowClient()
+
+ self._stream_id_to_span = {}
+ self._lock = threading.Lock() # lock for _stream_id_to_span
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_success(kwargs, response_obj, start_time, end_time)
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_success(kwargs, response_obj, start_time, end_time)
+
+ def _handle_success(self, kwargs, response_obj, start_time, end_time):
+ """
+ Log the success event as an MLflow span.
+ Note that this method is called asynchronously in the background thread.
+ """
+ from mlflow.entities import SpanStatusCode
+
+ try:
+ verbose_logger.debug(f"MLflow logging start for success event")
+
+ if kwargs.get("stream"):
+ self._handle_stream_event(kwargs, response_obj, start_time, end_time)
+ else:
+ span = self._start_span_or_trace(kwargs, start_time)
+ end_time_ns = int(end_time.timestamp() * 1e9)
+ self._end_span_or_trace(
+ span=span,
+ outputs=response_obj,
+ status=SpanStatusCode.OK,
+ end_time_ns=end_time_ns,
+ )
+ except Exception:
+ verbose_logger.debug(f"MLflow Logging Error", stack_info=True)
+
+ def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+ def _handle_failure(self, kwargs, response_obj, start_time, end_time):
+ """
+ Log the failure event as an MLflow span.
+ Note that this method is called *synchronously* unlike the success handler.
+ """
+ from mlflow.entities import SpanEvent, SpanStatusCode
+
+ try:
+ span = self._start_span_or_trace(kwargs, start_time)
+
+ end_time_ns = int(end_time.timestamp() * 1e9)
+
+ # Record exception info as event
+ if exception := kwargs.get("exception"):
+ span.add_event(SpanEvent.from_exception(exception))
+
+ self._end_span_or_trace(
+ span=span,
+ outputs=response_obj,
+ status=SpanStatusCode.ERROR,
+ end_time_ns=end_time_ns,
+ )
+
+ except Exception as e:
+ verbose_logger.debug(f"MLflow Logging Error", stack_info=True)
+
+ def _handle_stream_event(self, kwargs, response_obj, start_time, end_time):
+ """
+ Handle the success event for a streaming response. For streaming calls,
+ log_success_event handle is triggered for every chunk of the stream.
+ We create a single span for the entire stream request as follows:
+
+ 1. For the first chunk, start a new span and store it in the map.
+ 2. For subsequent chunks, add the chunk as an event to the span.
+ 3. For the final chunk, end the span and remove the span from the map.
+ """
+ from mlflow.entities import SpanStatusCode
+
+ litellm_call_id = kwargs.get("litellm_call_id")
+
+ if litellm_call_id not in self._stream_id_to_span:
+ with self._lock:
+ # Check again after acquiring lock
+ if litellm_call_id not in self._stream_id_to_span:
+ # Start a new span for the first chunk of the stream
+ span = self._start_span_or_trace(kwargs, start_time)
+ self._stream_id_to_span[litellm_call_id] = span
+
+ # Add chunk as event to the span
+ span = self._stream_id_to_span[litellm_call_id]
+ self._add_chunk_events(span, response_obj)
+
+ # If this is the final chunk, end the span. The final chunk
+ # has complete_streaming_response that gathers the full response.
+ if final_response := kwargs.get("complete_streaming_response"):
+ end_time_ns = int(end_time.timestamp() * 1e9)
+ self._end_span_or_trace(
+ span=span,
+ outputs=final_response,
+ status=SpanStatusCode.OK,
+ end_time_ns=end_time_ns,
+ )
+
+ # Remove the stream_id from the map
+ with self._lock:
+ self._stream_id_to_span.pop(litellm_call_id)
+
+ def _add_chunk_events(self, span, response_obj):
+ from mlflow.entities import SpanEvent
+
+ try:
+ for choice in response_obj.choices:
+ span.add_event(
+ SpanEvent(
+ name="streaming_chunk",
+ attributes={"delta": json.dumps(choice.delta.model_dump())},
+ )
+ )
+ except Exception:
+ verbose_logger.debug("Error adding chunk events to span", stack_info=True)
+
+ def _construct_input(self, kwargs):
+ """Construct span inputs with optional parameters"""
+ inputs = {"messages": kwargs.get("messages")}
+ for key in ["functions", "tools", "stream", "tool_choice", "user"]:
+ if value := kwargs.get("optional_params", {}).pop(key, None):
+ inputs[key] = value
+ return inputs
+
+ def _extract_attributes(self, kwargs):
+ """
+ Extract span attributes from kwargs.
+
+ With the latest version of litellm, the standard_logging_object contains
+ canonical information for logging. If it is not present, we extract
+ subset of attributes from other kwargs.
+ """
+ attributes = {
+ "litellm_call_id": kwargs.get("litellm_call_id"),
+ "call_type": kwargs.get("call_type"),
+ "model": kwargs.get("model"),
+ }
+ standard_obj = kwargs.get("standard_logging_object")
+ if standard_obj:
+ attributes.update(
+ {
+ "api_base": standard_obj.get("api_base"),
+ "cache_hit": standard_obj.get("cache_hit"),
+ "usage": {
+ "completion_tokens": standard_obj.get("completion_tokens"),
+ "prompt_tokens": standard_obj.get("prompt_tokens"),
+ "total_tokens": standard_obj.get("total_tokens"),
+ },
+ "raw_llm_response": standard_obj.get("response"),
+ "response_cost": standard_obj.get("response_cost"),
+ "saved_cache_cost": standard_obj.get("saved_cache_cost"),
+ }
+ )
+ else:
+ litellm_params = kwargs.get("litellm_params", {})
+ attributes.update(
+ {
+ "model": kwargs.get("model"),
+ "cache_hit": kwargs.get("cache_hit"),
+ "custom_llm_provider": kwargs.get("custom_llm_provider"),
+ "api_base": litellm_params.get("api_base"),
+ "response_cost": kwargs.get("response_cost"),
+ }
+ )
+ return attributes
+
+ def _get_span_type(self, call_type: Optional[str]) -> str:
+ from mlflow.entities import SpanType
+
+ if call_type in ["completion", "acompletion"]:
+ return SpanType.LLM
+ elif call_type == "embeddings":
+ return SpanType.EMBEDDING
+ else:
+ return SpanType.LLM
+
+ def _start_span_or_trace(self, kwargs, start_time):
+ """
+ Start an MLflow span or a trace.
+
+ If there is an active span, we start a new span as a child of
+ that span. Otherwise, we start a new trace.
+ """
+ import mlflow
+
+ call_type = kwargs.get("call_type", "completion")
+ span_name = f"litellm-{call_type}"
+ span_type = self._get_span_type(call_type)
+ start_time_ns = int(start_time.timestamp() * 1e9)
+
+ inputs = self._construct_input(kwargs)
+ attributes = self._extract_attributes(kwargs)
+
+ if active_span := mlflow.get_current_active_span():
+ return self._client.start_span(
+ name=span_name,
+ request_id=active_span.request_id,
+ parent_id=active_span.span_id,
+ span_type=span_type,
+ inputs=inputs,
+ attributes=attributes,
+ start_time_ns=start_time_ns,
+ )
+ else:
+ return self._client.start_trace(
+ name=span_name,
+ span_type=span_type,
+ inputs=inputs,
+ attributes=attributes,
+ start_time_ns=start_time_ns,
+ )
+
+ def _end_span_or_trace(self, span, outputs, end_time_ns, status):
+ """End an MLflow span or a trace."""
+ if span.parent_id is None:
+ self._client.end_trace(
+ request_id=span.request_id,
+ outputs=outputs,
+ status=status,
+ end_time_ns=end_time_ns,
+ )
+ else:
+ self._client.end_span(
+ request_id=span.request_id,
+ span_id=span.span_id,
+ outputs=outputs,
+ status=status,
+ end_time_ns=end_time_ns,
+ )
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index d2e65742c..66f91abf1 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -28,6 +28,7 @@ from litellm.caching.caching_handler import LLMCachingHandler
from litellm.cost_calculator import _select_model_name_for_cost_calc
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
+from litellm.integrations.mlflow import MlflowLogger
from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_custom_logger,
redact_message_input_output_from_logging,
@@ -554,6 +555,7 @@ class Logging:
message=f"Model Call Details pre-call: {details_to_log}",
level="info",
)
+
elif isinstance(callback, CustomLogger): # custom logger class
callback.log_pre_api_call(
model=self.model,
@@ -1249,6 +1251,7 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
+
if (
callback == "openmeter"
and self.model_call_details.get("litellm_params", {}).get(
@@ -2338,6 +2341,14 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
_in_memory_loggers.append(_otel_logger)
return _otel_logger # type: ignore
+ elif logging_integration == "mlflow":
+ for callback in _in_memory_loggers:
+ if isinstance(callback, MlflowLogger):
+ return callback # type: ignore
+
+ _mlflow_logger = MlflowLogger()
+ _in_memory_loggers.append(_mlflow_logger)
+ return _mlflow_logger # type: ignore
def get_custom_logger_compatible_class(
logging_integration: litellm._custom_logger_compatible_callbacks_literal,
@@ -2439,6 +2450,12 @@ def get_custom_logger_compatible_class(
and callback.callback_name == "langtrace"
):
return callback
+
+ elif logging_integration == "mlflow":
+ for callback in _in_memory_loggers:
+ if isinstance(callback, MlflowLogger):
+ return callback
+
return None
diff --git a/litellm/tests/test_mlflow.py b/litellm/tests/test_mlflow.py
new file mode 100644
index 000000000..ec23875ea
--- /dev/null
+++ b/litellm/tests/test_mlflow.py
@@ -0,0 +1,29 @@
+import pytest
+
+import litellm
+
+
+def test_mlflow_logging():
+ litellm.success_callback = ["mlflow"]
+ litellm.failure_callback = ["mlflow"]
+
+ litellm.completion(
+ model="gpt-4o-mini",
+ messages=[{"role": "user", "content": "what llm are u"}],
+ max_tokens=10,
+ temperature=0.2,
+ user="test-user",
+ )
+
+@pytest.mark.asyncio()
+async def test_async_mlflow_logging():
+ litellm.success_callback = ["mlflow"]
+ litellm.failure_callback = ["mlflow"]
+
+ await litellm.acompletion(
+ model="gpt-4o-mini",
+ messages=[{"role": "user", "content": "hi test from local arize"}],
+ mock_response="hello",
+ temperature=0.1,
+ user="OTEL_USER",
+ )