Add integration with MLflow Tracing (#6147)

* Add MLflow logger

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* Streaming handling

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* lint

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* address comments and fix issues

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* address comments and fix issues

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* Move logger construction code

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* Add docs

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* async handlers

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

* new picture

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>

---------

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
This commit is contained in:
Yuki Watanabe 2024-11-14 00:00:41 +09:00 committed by GitHub
parent 1e097bbfbe
commit 82f405adcb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 402 additions and 1 deletions

View file

@ -113,7 +113,7 @@ for part in response:
## Logging Observability ([Docs](https://docs.litellm.ai/docs/observability/callbacks))
LiteLLM exposes pre defined callbacks to send data to Lunary, Langfuse, DynamoDB, s3 Buckets, Helicone, Promptlayer, Traceloop, Athina, Slack
LiteLLM exposes pre defined callbacks to send data to Lunary, Langfuse, DynamoDB, s3 Buckets, Helicone, Promptlayer, Traceloop, Athina, Slack, MLflow
```python
from litellm import completion

View file

@ -0,0 +1,108 @@
# MLflow
## What is MLflow?
**MLflow** is an end-to-end open source MLOps platform for [experiment tracking](https://www.mlflow.org/docs/latest/tracking.html), [model management](https://www.mlflow.org/docs/latest/models.html), [evaluation](https://www.mlflow.org/docs/latest/llms/llm-evaluate/index.html), [observability (tracing)](https://www.mlflow.org/docs/latest/llms/tracing/index.html), and [deployment](https://www.mlflow.org/docs/latest/deployment/index.html). MLflow empowers teams to collaboratively develop and refine LLM applications efficiently.
MLflows integration with LiteLLM supports advanced observability compatible with OpenTelemetry.
<Image img={require('../../img/mlflow_tracing.png')} />
## Getting Started
Install MLflow:
```shell
pip install mlflow
```
To enable LiteLLM tracing:
```python
import mlflow
mlflow.litellm.autolog()
# Alternative, you can set the callback manually in LiteLLM
# litellm.callbacks = ["mlflow"]
```
Since MLflow is open-source, no sign-up or API key is needed to log traces!
```
import litellm
import os
# Set your LLM provider's API key
os.environ["OPENAI_API_KEY"] = ""
# Call LiteLLM as usual
response = litellm.completion(
model="gpt-4o-mini",
messages=[
{"role": "user", "content": "Hi 👋 - i'm openai"}
]
)
```
Open the MLflow UI and go to the `Traces` tab to view logged traces:
```bash
mlflow ui
```
## Exporting Traces to OpenTelemetry collectors
MLflow traces are compatible with OpenTelemetry. You can export traces to any OpenTelemetry collector (e.g., Jaeger, Zipkin, Datadog, New Relic) by setting the endpoint URL in the environment variables.
```
# Set the endpoint of the OpenTelemetry Collector
os.environ["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] = "http://localhost:4317/v1/traces"
# Optionally, set the service name to group traces
os.environ["OTEL_SERVICE_NAME"] = "<your-service-name>"
```
See [MLflow documentation](https://mlflow.org/docs/latest/llms/tracing/index.html#using-opentelemetry-collector-for-exporting-traces) for more details.
## Combine LiteLLM Trace with Your Application Trace
LiteLLM is often part of larger LLM applications, such as agentic models. MLflow Tracing allows you to instrument custom Python code, which can then be combined with LiteLLM traces.
```python
import litellm
import mlflow
from mlflow.entities import SpanType
# Enable LiteLLM tracing
mlflow.litellm.autolog()
class CustomAgent:
# Use @mlflow.trace to instrument Python functions.
@mlflow.trace(span_type=SpanType.AGENT)
def run(self, query: str):
# do something
while i < self.max_turns:
response = litellm.completion(
model="gpt-4o-mini",
messages=messages,
)
action = self.get_action(response)
...
@mlflow.trace
def get_action(llm_response):
...
```
This approach generates a unified trace, combining your custom Python code with LiteLLM calls.
## Support
* For advanced usage and integrations of tracing, visit the [MLflow Tracing documentation](https://mlflow.org/docs/latest/llms/tracing/index.html).
* For any question or issue with this integration, please [submit an issue](https://github.com/mlflow/mlflow/issues/new/choose) on our [Github](https://github.com/mlflow/mlflow) repository!

Binary file not shown.

After

Width:  |  Height:  |  Size: 361 KiB

View file

@ -57,6 +57,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
"gcs_bucket",
"opik",
"argilla",
"mlflow",
]
logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
_known_custom_logger_compatible_callbacks: List = list(

View file

@ -0,0 +1,246 @@
import json
import threading
from typing import Optional
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_logger
class MlflowLogger(CustomLogger):
def __init__(self):
from mlflow.tracking import MlflowClient
self._client = MlflowClient()
self._stream_id_to_span = {}
self._lock = threading.Lock() # lock for _stream_id_to_span
def log_success_event(self, kwargs, response_obj, start_time, end_time):
self._handle_success(kwargs, response_obj, start_time, end_time)
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
self._handle_success(kwargs, response_obj, start_time, end_time)
def _handle_success(self, kwargs, response_obj, start_time, end_time):
"""
Log the success event as an MLflow span.
Note that this method is called asynchronously in the background thread.
"""
from mlflow.entities import SpanStatusCode
try:
verbose_logger.debug(f"MLflow logging start for success event")
if kwargs.get("stream"):
self._handle_stream_event(kwargs, response_obj, start_time, end_time)
else:
span = self._start_span_or_trace(kwargs, start_time)
end_time_ns = int(end_time.timestamp() * 1e9)
self._end_span_or_trace(
span=span,
outputs=response_obj,
status=SpanStatusCode.OK,
end_time_ns=end_time_ns,
)
except Exception:
verbose_logger.debug(f"MLflow Logging Error", stack_info=True)
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
self._handle_failure(kwargs, response_obj, start_time, end_time)
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
self._handle_failure(kwargs, response_obj, start_time, end_time)
def _handle_failure(self, kwargs, response_obj, start_time, end_time):
"""
Log the failure event as an MLflow span.
Note that this method is called *synchronously* unlike the success handler.
"""
from mlflow.entities import SpanEvent, SpanStatusCode
try:
span = self._start_span_or_trace(kwargs, start_time)
end_time_ns = int(end_time.timestamp() * 1e9)
# Record exception info as event
if exception := kwargs.get("exception"):
span.add_event(SpanEvent.from_exception(exception))
self._end_span_or_trace(
span=span,
outputs=response_obj,
status=SpanStatusCode.ERROR,
end_time_ns=end_time_ns,
)
except Exception as e:
verbose_logger.debug(f"MLflow Logging Error", stack_info=True)
def _handle_stream_event(self, kwargs, response_obj, start_time, end_time):
"""
Handle the success event for a streaming response. For streaming calls,
log_success_event handle is triggered for every chunk of the stream.
We create a single span for the entire stream request as follows:
1. For the first chunk, start a new span and store it in the map.
2. For subsequent chunks, add the chunk as an event to the span.
3. For the final chunk, end the span and remove the span from the map.
"""
from mlflow.entities import SpanStatusCode
litellm_call_id = kwargs.get("litellm_call_id")
if litellm_call_id not in self._stream_id_to_span:
with self._lock:
# Check again after acquiring lock
if litellm_call_id not in self._stream_id_to_span:
# Start a new span for the first chunk of the stream
span = self._start_span_or_trace(kwargs, start_time)
self._stream_id_to_span[litellm_call_id] = span
# Add chunk as event to the span
span = self._stream_id_to_span[litellm_call_id]
self._add_chunk_events(span, response_obj)
# If this is the final chunk, end the span. The final chunk
# has complete_streaming_response that gathers the full response.
if final_response := kwargs.get("complete_streaming_response"):
end_time_ns = int(end_time.timestamp() * 1e9)
self._end_span_or_trace(
span=span,
outputs=final_response,
status=SpanStatusCode.OK,
end_time_ns=end_time_ns,
)
# Remove the stream_id from the map
with self._lock:
self._stream_id_to_span.pop(litellm_call_id)
def _add_chunk_events(self, span, response_obj):
from mlflow.entities import SpanEvent
try:
for choice in response_obj.choices:
span.add_event(
SpanEvent(
name="streaming_chunk",
attributes={"delta": json.dumps(choice.delta.model_dump())},
)
)
except Exception:
verbose_logger.debug("Error adding chunk events to span", stack_info=True)
def _construct_input(self, kwargs):
"""Construct span inputs with optional parameters"""
inputs = {"messages": kwargs.get("messages")}
for key in ["functions", "tools", "stream", "tool_choice", "user"]:
if value := kwargs.get("optional_params", {}).pop(key, None):
inputs[key] = value
return inputs
def _extract_attributes(self, kwargs):
"""
Extract span attributes from kwargs.
With the latest version of litellm, the standard_logging_object contains
canonical information for logging. If it is not present, we extract
subset of attributes from other kwargs.
"""
attributes = {
"litellm_call_id": kwargs.get("litellm_call_id"),
"call_type": kwargs.get("call_type"),
"model": kwargs.get("model"),
}
standard_obj = kwargs.get("standard_logging_object")
if standard_obj:
attributes.update(
{
"api_base": standard_obj.get("api_base"),
"cache_hit": standard_obj.get("cache_hit"),
"usage": {
"completion_tokens": standard_obj.get("completion_tokens"),
"prompt_tokens": standard_obj.get("prompt_tokens"),
"total_tokens": standard_obj.get("total_tokens"),
},
"raw_llm_response": standard_obj.get("response"),
"response_cost": standard_obj.get("response_cost"),
"saved_cache_cost": standard_obj.get("saved_cache_cost"),
}
)
else:
litellm_params = kwargs.get("litellm_params", {})
attributes.update(
{
"model": kwargs.get("model"),
"cache_hit": kwargs.get("cache_hit"),
"custom_llm_provider": kwargs.get("custom_llm_provider"),
"api_base": litellm_params.get("api_base"),
"response_cost": kwargs.get("response_cost"),
}
)
return attributes
def _get_span_type(self, call_type: Optional[str]) -> str:
from mlflow.entities import SpanType
if call_type in ["completion", "acompletion"]:
return SpanType.LLM
elif call_type == "embeddings":
return SpanType.EMBEDDING
else:
return SpanType.LLM
def _start_span_or_trace(self, kwargs, start_time):
"""
Start an MLflow span or a trace.
If there is an active span, we start a new span as a child of
that span. Otherwise, we start a new trace.
"""
import mlflow
call_type = kwargs.get("call_type", "completion")
span_name = f"litellm-{call_type}"
span_type = self._get_span_type(call_type)
start_time_ns = int(start_time.timestamp() * 1e9)
inputs = self._construct_input(kwargs)
attributes = self._extract_attributes(kwargs)
if active_span := mlflow.get_current_active_span():
return self._client.start_span(
name=span_name,
request_id=active_span.request_id,
parent_id=active_span.span_id,
span_type=span_type,
inputs=inputs,
attributes=attributes,
start_time_ns=start_time_ns,
)
else:
return self._client.start_trace(
name=span_name,
span_type=span_type,
inputs=inputs,
attributes=attributes,
start_time_ns=start_time_ns,
)
def _end_span_or_trace(self, span, outputs, end_time_ns, status):
"""End an MLflow span or a trace."""
if span.parent_id is None:
self._client.end_trace(
request_id=span.request_id,
outputs=outputs,
status=status,
end_time_ns=end_time_ns,
)
else:
self._client.end_span(
request_id=span.request_id,
span_id=span.span_id,
outputs=outputs,
status=status,
end_time_ns=end_time_ns,
)

View file

@ -28,6 +28,7 @@ from litellm.caching.caching_handler import LLMCachingHandler
from litellm.cost_calculator import _select_model_name_for_cost_calc
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.mlflow import MlflowLogger
from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_custom_logger,
redact_message_input_output_from_logging,
@ -554,6 +555,7 @@ class Logging:
message=f"Model Call Details pre-call: {details_to_log}",
level="info",
)
elif isinstance(callback, CustomLogger): # custom logger class
callback.log_pre_api_call(
model=self.model,
@ -1249,6 +1251,7 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
if (
callback == "openmeter"
and self.model_call_details.get("litellm_params", {}).get(
@ -2338,6 +2341,14 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
_in_memory_loggers.append(_otel_logger)
return _otel_logger # type: ignore
elif logging_integration == "mlflow":
for callback in _in_memory_loggers:
if isinstance(callback, MlflowLogger):
return callback # type: ignore
_mlflow_logger = MlflowLogger()
_in_memory_loggers.append(_mlflow_logger)
return _mlflow_logger # type: ignore
def get_custom_logger_compatible_class(
logging_integration: litellm._custom_logger_compatible_callbacks_literal,
@ -2439,6 +2450,12 @@ def get_custom_logger_compatible_class(
and callback.callback_name == "langtrace"
):
return callback
elif logging_integration == "mlflow":
for callback in _in_memory_loggers:
if isinstance(callback, MlflowLogger):
return callback
return None

View file

@ -0,0 +1,29 @@
import pytest
import litellm
def test_mlflow_logging():
litellm.success_callback = ["mlflow"]
litellm.failure_callback = ["mlflow"]
litellm.completion(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "what llm are u"}],
max_tokens=10,
temperature=0.2,
user="test-user",
)
@pytest.mark.asyncio()
async def test_async_mlflow_logging():
litellm.success_callback = ["mlflow"]
litellm.failure_callback = ["mlflow"]
await litellm.acompletion(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "hi test from local arize"}],
mock_response="hello",
temperature=0.1,
user="OTEL_USER",
)