diff --git a/README.md b/README.md
index e13732000..bfdba2fa3 100644
--- a/README.md
+++ b/README.md
@@ -113,7 +113,7 @@ for part in response:
 
 ## Logging Observability ([Docs](https://docs.litellm.ai/docs/observability/callbacks))
 
-LiteLLM exposes pre defined callbacks to send data to Lunary, Langfuse, DynamoDB, s3 Buckets, Helicone, Promptlayer, Traceloop, Athina, Slack
+LiteLLM exposes pre-defined callbacks to send data to Lunary, Langfuse, DynamoDB, s3 Buckets, Helicone, Promptlayer, Traceloop, Athina, Slack, MLflow
 
 ```python
 from litellm import completion
diff --git a/docs/my-website/docs/observability/mlflow.md b/docs/my-website/docs/observability/mlflow.md
new file mode 100644
index 000000000..3b1e1d477
--- /dev/null
+++ b/docs/my-website/docs/observability/mlflow.md
@@ -0,0 +1,108 @@
+# MLflow
+
+## What is MLflow?
+
+**MLflow** is an end-to-end open source MLOps platform for [experiment tracking](https://www.mlflow.org/docs/latest/tracking.html), [model management](https://www.mlflow.org/docs/latest/models.html), [evaluation](https://www.mlflow.org/docs/latest/llms/llm-evaluate/index.html), [observability (tracing)](https://www.mlflow.org/docs/latest/llms/tracing/index.html), and [deployment](https://www.mlflow.org/docs/latest/deployment/index.html). MLflow empowers teams to collaboratively develop and refine LLM applications efficiently.
+
+MLflow’s integration with LiteLLM provides advanced observability that is compatible with OpenTelemetry.
+
+
+## Getting Started
+
+Install MLflow:
+
+```shell
+pip install mlflow
+```
+
+To enable LiteLLM tracing:
+
+```python
+import mlflow
+
+mlflow.litellm.autolog()
+
+# Alternatively, you can set the callback manually in LiteLLM
+# litellm.callbacks = ["mlflow"]
+```
+
+Since MLflow is open-source, no sign-up or API key is needed to log traces!
+
+```python
+import litellm
+import os
+
+# Set your LLM provider's API key
+os.environ["OPENAI_API_KEY"] = ""
+
+# Call LiteLLM as usual
+response = litellm.completion(
+    model="gpt-4o-mini",
+    messages=[
+        {"role": "user", "content": "Hi 👋 - i'm openai"}
+    ]
+)
+```
+
+Open the MLflow UI and go to the `Traces` tab to view logged traces:
+
+```bash
+mlflow ui
+```
+
+## Exporting Traces to OpenTelemetry collectors
+
+MLflow traces are compatible with OpenTelemetry. You can export traces to any OpenTelemetry collector (e.g., Jaeger, Zipkin, Datadog, New Relic) by setting the collector endpoint URL in the environment variables.
+
+```python
+import os
+
+# Set the endpoint of the OpenTelemetry Collector
+os.environ["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] = "http://localhost:4317/v1/traces"
+# Optionally, set the service name to group traces
+os.environ["OTEL_SERVICE_NAME"] = ""
+```
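+
+For example, here is a minimal end-to-end sketch (assuming an OpenTelemetry collector is already listening on `localhost:4317`; the service name is an arbitrary placeholder):
+
+```python
+import os
+
+import litellm
+import mlflow
+
+# Configure the exporter before any traces are generated
+os.environ["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] = "http://localhost:4317/v1/traces"
+os.environ["OTEL_SERVICE_NAME"] = "litellm-demo"
+
+# Enable LiteLLM tracing; spans are sent to the configured collector
+mlflow.litellm.autolog()
+
+response = litellm.completion(
+    model="gpt-4o-mini",
+    messages=[{"role": "user", "content": "Hi 👋"}],
+)
+```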
+
+See [MLflow documentation](https://mlflow.org/docs/latest/llms/tracing/index.html#using-opentelemetry-collector-for-exporting-traces) for more details.
+
+## Combine LiteLLM Trace with Your Application Trace
+
+LiteLLM is often used as part of a larger LLM application, such as an agentic workflow. MLflow Tracing allows you to instrument custom Python code, which can then be combined with LiteLLM traces.
+
+```python
+import litellm
+import mlflow
+from mlflow.entities import SpanType
+
+# Enable LiteLLM tracing
+mlflow.litellm.autolog()
+
+
+class CustomAgent:
+    # Use @mlflow.trace to instrument Python functions.
+    @mlflow.trace(span_type=SpanType.AGENT)
+    def run(self, query: str):
+        # Build the conversation and loop over agent turns
+        messages = [{"role": "user", "content": query}]
+
+        for _ in range(self.max_turns):
+            response = litellm.completion(
+                model="gpt-4o-mini",
+                messages=messages,
+            )
+
+            action = self.get_action(response)
+            ...
+
+    @mlflow.trace
+    def get_action(self, llm_response):
+        ...
+```
+
+This approach generates a unified trace, combining your custom Python code with LiteLLM calls.
+
+
+## Support
+
+* For advanced usage and integrations of tracing, visit the [MLflow Tracing documentation](https://mlflow.org/docs/latest/llms/tracing/index.html).
+* For any questions or issues with this integration, please [submit an issue](https://github.com/mlflow/mlflow/issues/new/choose) on our [GitHub](https://github.com/mlflow/mlflow) repository!
\ No newline at end of file
diff --git a/docs/my-website/img/mlflow_tracing.png b/docs/my-website/img/mlflow_tracing.png
new file mode 100644
index 000000000..aee1fb79e
Binary files /dev/null and b/docs/my-website/img/mlflow_tracing.png differ
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 9812de1d8..e54117e11 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -57,6 +57,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
     "gcs_bucket",
     "opik",
     "argilla",
+    "mlflow",
 ]
 logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
 _known_custom_logger_compatible_callbacks: List = list(
diff --git a/litellm/integrations/mlflow.py b/litellm/integrations/mlflow.py
new file mode 100644
index 000000000..baf33a86b
--- /dev/null
+++ b/litellm/integrations/mlflow.py
@@ -0,0 +1,246 @@
+import json
+import threading
+from typing import Optional
+
+from litellm.integrations.custom_logger import CustomLogger
+from litellm._logging import verbose_logger
+
+
+class MlflowLogger(CustomLogger):
+    def __init__(self):
+        from mlflow.tracking import MlflowClient
+
+        self._client = MlflowClient()
+
+        self._stream_id_to_span = {}
+        self._lock = threading.Lock()  # lock for _stream_id_to_span
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        self._handle_success(kwargs, response_obj, start_time, end_time)
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        self._handle_success(kwargs, response_obj, start_time, end_time)
+
+    def _handle_success(self, kwargs, response_obj, start_time, end_time):
+        """
+        Log the success event as an MLflow span.
+        Note that this method is called asynchronously in the background thread.
+        """
+        from mlflow.entities import SpanStatusCode
+
+        try:
+            verbose_logger.debug("MLflow logging start for success event")
+
+            if kwargs.get("stream"):
+                self._handle_stream_event(kwargs, response_obj, start_time, end_time)
+            else:
+                span = self._start_span_or_trace(kwargs, start_time)
+                end_time_ns = int(end_time.timestamp() * 1e9)
+                self._end_span_or_trace(
+                    span=span,
+                    outputs=response_obj,
+                    status=SpanStatusCode.OK,
+                    end_time_ns=end_time_ns,
+                )
+        except Exception:
+            verbose_logger.debug("MLflow Logging Error", stack_info=True)
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+    def _handle_failure(self, kwargs, response_obj, start_time, end_time):
+        """
+        Log the failure event as an MLflow span.
+        Note that this method is called *synchronously*, unlike the success handler.
+        """
+        from mlflow.entities import SpanEvent, SpanStatusCode
+
+        try:
+            span = self._start_span_or_trace(kwargs, start_time)
+
+            end_time_ns = int(end_time.timestamp() * 1e9)
+
+            # Record exception info as event
+            if exception := kwargs.get("exception"):
+                span.add_event(SpanEvent.from_exception(exception))
+
+            self._end_span_or_trace(
+                span=span,
+                outputs=response_obj,
+                status=SpanStatusCode.ERROR,
+                end_time_ns=end_time_ns,
+            )
+
+        except Exception:
+            verbose_logger.debug("MLflow Logging Error", stack_info=True)
+
+    def _handle_stream_event(self, kwargs, response_obj, start_time, end_time):
+        """
+        Handle the success event for a streaming response. For streaming calls,
+        the log_success_event handler is triggered for every chunk of the stream.
+        We create a single span for the entire stream request as follows:
+
+        1. For the first chunk, start a new span and store it in the map.
+        2. For subsequent chunks, add the chunk as an event to the span.
+        3. For the final chunk, end the span and remove it from the map.
+        """
+        from mlflow.entities import SpanStatusCode
+
+        litellm_call_id = kwargs.get("litellm_call_id")
+
+        if litellm_call_id not in self._stream_id_to_span:
+            with self._lock:
+                # Check again after acquiring lock
+                if litellm_call_id not in self._stream_id_to_span:
+                    # Start a new span for the first chunk of the stream
+                    span = self._start_span_or_trace(kwargs, start_time)
+                    self._stream_id_to_span[litellm_call_id] = span
+
+        # Add chunk as event to the span
+        span = self._stream_id_to_span[litellm_call_id]
+        self._add_chunk_events(span, response_obj)
+
+        # If this is the final chunk, end the span. The final chunk
+        # has complete_streaming_response that gathers the full response.
+        if final_response := kwargs.get("complete_streaming_response"):
+            end_time_ns = int(end_time.timestamp() * 1e9)
+            self._end_span_or_trace(
+                span=span,
+                outputs=final_response,
+                status=SpanStatusCode.OK,
+                end_time_ns=end_time_ns,
+            )
+
+            # Remove the stream_id from the map
+            with self._lock:
+                self._stream_id_to_span.pop(litellm_call_id)
+
+    def _add_chunk_events(self, span, response_obj):
+        from mlflow.entities import SpanEvent
+
+        try:
+            for choice in response_obj.choices:
+                span.add_event(
+                    SpanEvent(
+                        name="streaming_chunk",
+                        attributes={"delta": json.dumps(choice.delta.model_dump())},
+                    )
+                )
+        except Exception:
+            verbose_logger.debug("Error adding chunk events to span", stack_info=True)
+
+    def _construct_input(self, kwargs):
+        """Construct span inputs with optional parameters."""
+        inputs = {"messages": kwargs.get("messages")}
+        for key in ["functions", "tools", "stream", "tool_choice", "user"]:
+            if value := kwargs.get("optional_params", {}).pop(key, None):
+                inputs[key] = value
+        return inputs
+
+    def _extract_attributes(self, kwargs):
+        """
+        Extract span attributes from kwargs.
+
+        With the latest version of litellm, the standard_logging_object contains
+        canonical information for logging. If it is not present, we extract a
+        subset of attributes from other kwargs.
+        """
+        attributes = {
+            "litellm_call_id": kwargs.get("litellm_call_id"),
+            "call_type": kwargs.get("call_type"),
+            "model": kwargs.get("model"),
+        }
+        standard_obj = kwargs.get("standard_logging_object")
+        if standard_obj:
+            attributes.update(
+                {
+                    "api_base": standard_obj.get("api_base"),
+                    "cache_hit": standard_obj.get("cache_hit"),
+                    "usage": {
+                        "completion_tokens": standard_obj.get("completion_tokens"),
+                        "prompt_tokens": standard_obj.get("prompt_tokens"),
+                        "total_tokens": standard_obj.get("total_tokens"),
+                    },
+                    "raw_llm_response": standard_obj.get("response"),
+                    "response_cost": standard_obj.get("response_cost"),
+                    "saved_cache_cost": standard_obj.get("saved_cache_cost"),
+                }
+            )
+        else:
+            litellm_params = kwargs.get("litellm_params", {})
+            attributes.update(
+                {
+                    "model": kwargs.get("model"),
+                    "cache_hit": kwargs.get("cache_hit"),
+                    "custom_llm_provider": kwargs.get("custom_llm_provider"),
+                    "api_base": litellm_params.get("api_base"),
+                    "response_cost": kwargs.get("response_cost"),
+                }
+            )
+        return attributes
+
+    def _get_span_type(self, call_type: Optional[str]) -> str:
+        from mlflow.entities import SpanType
+
+        if call_type in ["completion", "acompletion"]:
+            return SpanType.LLM
+        elif call_type == "embeddings":
+            return SpanType.EMBEDDING
+        else:
+            return SpanType.LLM
+
+    def _start_span_or_trace(self, kwargs, start_time):
+        """
+        Start an MLflow span or a trace.
+
+        If there is an active span, we start a new span as a child of
+        that span. Otherwise, we start a new trace.
+        """
+        import mlflow
+
+        call_type = kwargs.get("call_type", "completion")
+        span_name = f"litellm-{call_type}"
+        span_type = self._get_span_type(call_type)
+        start_time_ns = int(start_time.timestamp() * 1e9)
+
+        inputs = self._construct_input(kwargs)
+        attributes = self._extract_attributes(kwargs)
+
+        if active_span := mlflow.get_current_active_span():
+            return self._client.start_span(
+                name=span_name,
+                request_id=active_span.request_id,
+                parent_id=active_span.span_id,
+                span_type=span_type,
+                inputs=inputs,
+                attributes=attributes,
+                start_time_ns=start_time_ns,
+            )
+        else:
+            return self._client.start_trace(
+                name=span_name,
+                span_type=span_type,
+                inputs=inputs,
+                attributes=attributes,
+                start_time_ns=start_time_ns,
+            )
+
+    def _end_span_or_trace(self, span, outputs, end_time_ns, status):
+        """End an MLflow span or a trace."""
+        if span.parent_id is None:
+            self._client.end_trace(
+                request_id=span.request_id,
+                outputs=outputs,
+                status=status,
+                end_time_ns=end_time_ns,
+            )
+        else:
+            self._client.end_span(
+                request_id=span.request_id,
+                span_id=span.span_id,
+                outputs=outputs,
+                status=status,
+                end_time_ns=end_time_ns,
+            )
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index d2e65742c..66f91abf1 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -28,6 +28,7 @@ from litellm.caching.caching_handler import LLMCachingHandler
 from litellm.cost_calculator import _select_model_name_for_cost_calc
 from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.integrations.mlflow import MlflowLogger
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_custom_logger,
     redact_message_input_output_from_logging,
@@ -554,6 +555,7 @@ class Logging:
                         message=f"Model Call Details pre-call: {details_to_log}",
                         level="info",
                     )
+
                 elif isinstance(callback, CustomLogger):  # custom logger class
                     callback.log_pre_api_call(
                         model=self.model,
@@ -1249,6 +1251,7 @@ class Logging:
                             end_time=end_time,
                             print_verbose=print_verbose,
                         )
+
                     if (
                         callback == "openmeter"
                         and self.model_call_details.get("litellm_params", {}).get(
@@ -2338,6 +2341,14 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
         _in_memory_loggers.append(_otel_logger)
         return _otel_logger  # type: ignore
+    elif logging_integration == "mlflow":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, MlflowLogger):
+                return callback  # type: ignore
+
+        _mlflow_logger = MlflowLogger()
+        _in_memory_loggers.append(_mlflow_logger)
+        return _mlflow_logger  # type: ignore
 
 
 def get_custom_logger_compatible_class(
     logging_integration: litellm._custom_logger_compatible_callbacks_literal,
@@ -2439,6 +2450,12 @@ def get_custom_logger_compatible_class(
                 and callback.callback_name == "langtrace"
             ):
                 return callback
+
+    elif logging_integration == "mlflow":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, MlflowLogger):
+                return callback
+
     return None
diff --git a/litellm/tests/test_mlflow.py b/litellm/tests/test_mlflow.py
new file mode 100644
index 000000000..ec23875ea
--- /dev/null
+++ b/litellm/tests/test_mlflow.py
@@ -0,0 +1,29 @@
+import pytest
+
+import litellm
+
+
+def test_mlflow_logging():
+    litellm.success_callback = ["mlflow"]
+    litellm.failure_callback = ["mlflow"]
+
+    litellm.completion(
+        model="gpt-4o-mini",
+        messages=[{"role": "user", "content": "what llm are u"}],
+        max_tokens=10,
+        temperature=0.2,
+        user="test-user",
+    )
+
+
+@pytest.mark.asyncio()
+async def test_async_mlflow_logging():
+    litellm.success_callback = ["mlflow"]
+    litellm.failure_callback = ["mlflow"]
+
+    await litellm.acompletion(
+        model="gpt-4o-mini",
+        messages=[{"role": "user", "content": "hi test from local arize"}],
+        mock_response="hello",
+        temperature=0.1,
+        user="OTEL_USER",
+    )
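
As a quick manual check of the new callback wiring outside the test suite, a minimal sketch (the model name and prompt are placeholders; `mock_response`, also used in the async test above, avoids a real provider call):

```python
import litellm

# The "mlflow" literal added in litellm/__init__.py resolves to MlflowLogger
litellm.callbacks = ["mlflow"]

response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hello"}],
    mock_response="hi there",
)

# Inspect the resulting trace in the MLflow UI: run `mlflow ui` and open the Traces tab
```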