diff --git a/.circleci/config.yml b/.circleci/config.yml
index 7961cfddb..d95a8c214 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -690,6 +690,7 @@ jobs:
pip install "respx==0.21.1"
pip install "google-generativeai==0.3.2"
pip install "google-cloud-aiplatform==1.43.0"
+ pip install "mlflow==2.17.2"
# Run pytest and generate JUnit XML report
- run:
name: Run tests
diff --git a/README.md b/README.md
index 153d5ab3a..5d3efe355 100644
--- a/README.md
+++ b/README.md
@@ -113,7 +113,7 @@ for part in response:
## Logging Observability ([Docs](https://docs.litellm.ai/docs/observability/callbacks))
-LiteLLM exposes pre defined callbacks to send data to Lunary, Langfuse, DynamoDB, s3 Buckets, Helicone, Promptlayer, Traceloop, Athina, Slack
+LiteLLM exposes pre-defined callbacks to send data to Lunary, Langfuse, DynamoDB, s3 Buckets, Helicone, Promptlayer, Traceloop, Athina, Slack, MLflow
```python
from litellm import completion
diff --git a/docs/my-website/docs/anthropic_completion.md b/docs/my-website/docs/anthropic_completion.md
deleted file mode 100644
index ca65f3f6f..000000000
--- a/docs/my-website/docs/anthropic_completion.md
+++ /dev/null
@@ -1,54 +0,0 @@
-# [BETA] Anthropic `/v1/messages`
-
-Call 100+ LLMs in the Anthropic format.
-
-
-1. Setup config.yaml
-
-```yaml
-model_list:
- - model_name: my-test-model
- litellm_params:
- model: gpt-3.5-turbo
-```
-
-2. Start proxy
-
-```bash
-litellm --config /path/to/config.yaml
-```
-
-3. Test it!
-
-```bash
-curl -X POST 'http://0.0.0.0:4000/v1/messages' \
--H 'x-api-key: sk-1234' \
--H 'content-type: application/json' \
--D '{
- "model": "my-test-model",
- "max_tokens": 1024,
- "messages": [
- {"role": "user", "content": "Hello, world"}
- ]
-}'
-```
-
-## Test with Anthropic SDK
-
-```python
-import os
-from anthropic import Anthropic
-
-client = Anthropic(api_key="sk-1234", base_url="http://0.0.0.0:4000") # 👈 CONNECT TO PROXY
-
-message = client.messages.create(
- messages=[
- {
- "role": "user",
- "content": "Hello, Claude",
- }
- ],
- model="my-test-model", # 👈 set 'model_name'
-)
-print(message.content)
-```
\ No newline at end of file
diff --git a/docs/my-website/docs/observability/mlflow.md b/docs/my-website/docs/observability/mlflow.md
new file mode 100644
index 000000000..3b1e1d477
--- /dev/null
+++ b/docs/my-website/docs/observability/mlflow.md
@@ -0,0 +1,108 @@
+# MLflow
+
+## What is MLflow?
+
+**MLflow** is an end-to-end open source MLOps platform for [experiment tracking](https://www.mlflow.org/docs/latest/tracking.html), [model management](https://www.mlflow.org/docs/latest/models.html), [evaluation](https://www.mlflow.org/docs/latest/llms/llm-evaluate/index.html), [observability (tracing)](https://www.mlflow.org/docs/latest/llms/tracing/index.html), and [deployment](https://www.mlflow.org/docs/latest/deployment/index.html). MLflow empowers teams to collaboratively develop and refine LLM applications efficiently.
+
+MLflow’s integration with LiteLLM provides advanced observability that is compatible with OpenTelemetry.
+
+
+
+
+
+## Getting Started
+
+Install MLflow:
+
+```shell
+pip install mlflow
+```
+
+To enable LiteLLM tracing:
+
+```python
+import mlflow
+
+mlflow.litellm.autolog()
+
+# Alternatively, you can set the callback manually in LiteLLM
+# litellm.callbacks = ["mlflow"]
+```
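+
+Optionally, you can point MLflow at a specific tracking server and experiment before making any calls. This is a minimal sketch; the URI and experiment name below are placeholders for your own setup:
+
+```python
+import mlflow
+
+# Assumption: a tracking server is running at this address.
+# By default (no tracking URI set), MLflow logs traces to the local ./mlruns directory.
+mlflow.set_tracking_uri("http://localhost:5000")
+mlflow.set_experiment("litellm-tracing")
+```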
+
+Since MLflow is open-source, no sign-up or API key is needed to log traces!
+
+```python
+import litellm
+import os
+
+# Set your LLM provider's API key
+os.environ["OPENAI_API_KEY"] = ""
+
+# Call LiteLLM as usual
+response = litellm.completion(
+ model="gpt-4o-mini",
+ messages=[
+ {"role": "user", "content": "Hi 👋 - i'm openai"}
+ ]
+)
+```
+
+Start the MLflow UI and go to the `Traces` tab to view the logged traces:
+
+```bash
+mlflow ui
+```
+
+## Exporting Traces to OpenTelemetry Collectors
+
+MLflow traces are compatible with OpenTelemetry. You can export traces to any OpenTelemetry collector (e.g., Jaeger, Zipkin, Datadog, New Relic) by setting the endpoint URL in the environment variables.
+
+```python
+# Set the endpoint of the OpenTelemetry Collector
+os.environ["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] = "http://localhost:4317/v1/traces"
+# Optionally, set the service name to group traces
+os.environ["OTEL_SERVICE_NAME"] = ""
+```
+
+See [MLflow documentation](https://mlflow.org/docs/latest/llms/tracing/index.html#using-opentelemetry-collector-for-exporting-traces) for more details.
+
+## Combine LiteLLM Trace with Your Application Trace
+
+LiteLLM is often part of larger LLM applications, such as agentic systems. MLflow Tracing allows you to instrument custom Python code, which is then combined with LiteLLM traces.
+
+```python
+import litellm
+import mlflow
+from mlflow.entities import SpanType
+
+# Enable LiteLLM tracing
+mlflow.litellm.autolog()
+
+
+class CustomAgent:
+    # Use @mlflow.trace to instrument Python functions.
+    @mlflow.trace(span_type=SpanType.AGENT)
+    def run(self, query: str):
+        # Build the initial message list from the user query
+        messages = [{"role": "user", "content": query}]
+
+        for _ in range(self.max_turns):
+            response = litellm.completion(
+                model="gpt-4o-mini",
+                messages=messages,
+            )
+
+            action = self.get_action(response)
+            ...
+
+    @mlflow.trace
+    def get_action(self, llm_response):
+        ...
+```
+
+This approach generates a unified trace, combining your custom Python code with LiteLLM calls.
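+
+For example, invoking the agent (a sketch; `CustomAgent` assumes a `max_turns` attribute and a real action loop) produces one trace with the agent span as the parent and the LiteLLM completion spans nested underneath:
+
+```python
+agent = CustomAgent()
+agent.run("What is MLflow Tracing?")
+```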
+
+
+## Support
+
+* For advanced usage and integrations, visit the [MLflow Tracing documentation](https://mlflow.org/docs/latest/llms/tracing/index.html).
+* For any questions or issues with this integration, please [submit an issue](https://github.com/mlflow/mlflow/issues/new/choose) on our [GitHub](https://github.com/mlflow/mlflow) repository!
\ No newline at end of file
diff --git a/docs/my-website/docs/pass_through/anthropic_completion.md b/docs/my-website/docs/pass_through/anthropic_completion.md
new file mode 100644
index 000000000..0c6a5f1b6
--- /dev/null
+++ b/docs/my-website/docs/pass_through/anthropic_completion.md
@@ -0,0 +1,282 @@
+# Anthropic `/v1/messages`
+
+Pass-through endpoints for Anthropic - call the provider-specific endpoint in its native format (no translation).
+
+Just replace `https://api.anthropic.com` with `LITELLM_PROXY_BASE_URL/anthropic` 🚀
+
+#### **Example Usage**
+```bash
+curl --request POST \
+ --url http://0.0.0.0:4000/anthropic/v1/messages \
+ --header 'accept: application/json' \
+ --header 'content-type: application/json' \
+ --header "Authorization: bearer sk-anything" \
+ --data '{
+ "model": "claude-3-5-sonnet-20241022",
+ "max_tokens": 1024,
+ "messages": [
+ {"role": "user", "content": "Hello, world"}
+ ]
+ }'
+```
+
+Supports **ALL** Anthropic Endpoints (including streaming).
+
+[**See All Anthropic Endpoints**](https://docs.anthropic.com/en/api/messages)
+
+## Quick Start
+
+Let's call the Anthropic [`/messages` endpoint](https://docs.anthropic.com/en/api/messages)
+
+1. Add Anthropic API Key to your environment
+
+```bash
+export ANTHROPIC_API_KEY=""
+```
+
+2. Start LiteLLM Proxy
+
+```bash
+litellm
+
+# RUNNING on http://0.0.0.0:4000
+```
+
+3. Test it!
+
+Let's call the Anthropic /messages endpoint
+
+```bash
+curl http://0.0.0.0:4000/anthropic/v1/messages \
+ --header "x-api-key: $LITELLM_API_KEY" \
+ --header "anthropic-version: 2023-06-01" \
+ --header "content-type: application/json" \
+ --data \
+ '{
+ "model": "claude-3-5-sonnet-20241022",
+ "max_tokens": 1024,
+ "messages": [
+ {"role": "user", "content": "Hello, world"}
+ ]
+ }'
+```
+
+
+## Examples
+
+Anything after `http://0.0.0.0:4000/anthropic` is treated as a provider-specific route, and handled accordingly.
+
+Key Changes:
+
+| **Original Endpoint** | **Replace With** |
+|------------------------------------------------------|-----------------------------------|
+| `https://api.anthropic.com` | `http://0.0.0.0:4000/anthropic` (LITELLM_PROXY_BASE_URL="http://0.0.0.0:4000") |
+| `bearer $ANTHROPIC_API_KEY` | `bearer anything` (use `bearer LITELLM_VIRTUAL_KEY` if Virtual Keys are set up on the proxy) |
+
+
+### **Example 1: Messages endpoint**
+
+#### LiteLLM Proxy Call
+
+```bash
+curl --request POST \
+ --url http://0.0.0.0:4000/anthropic/v1/messages \
+ --header "x-api-key: $LITELLM_API_KEY" \
+ --header "anthropic-version: 2023-06-01" \
+ --header "content-type: application/json" \
+ --data '{
+ "model": "claude-3-5-sonnet-20241022",
+ "max_tokens": 1024,
+ "messages": [
+ {"role": "user", "content": "Hello, world"}
+ ]
+ }'
+```
+
+#### Direct Anthropic API Call
+
+```bash
+curl https://api.anthropic.com/v1/messages \
+ --header "x-api-key: $ANTHROPIC_API_KEY" \
+ --header "anthropic-version: 2023-06-01" \
+ --header "content-type: application/json" \
+ --data \
+ '{
+ "model": "claude-3-5-sonnet-20241022",
+ "max_tokens": 1024,
+ "messages": [
+ {"role": "user", "content": "Hello, world"}
+ ]
+ }'
+```
+
+### **Example 2: Token Counting API**
+
+#### LiteLLM Proxy Call
+
+```bash
+curl --request POST \
+ --url http://0.0.0.0:4000/anthropic/v1/messages/count_tokens \
+ --header "x-api-key: $LITELLM_API_KEY" \
+ --header "anthropic-version: 2023-06-01" \
+ --header "anthropic-beta: token-counting-2024-11-01" \
+ --header "content-type: application/json" \
+ --data \
+ '{
+ "model": "claude-3-5-sonnet-20241022",
+ "messages": [
+ {"role": "user", "content": "Hello, world"}
+ ]
+ }'
+```
+
+#### Direct Anthropic API Call
+
+```bash
+curl https://api.anthropic.com/v1/messages/count_tokens \
+ --header "x-api-key: $ANTHROPIC_API_KEY" \
+ --header "anthropic-version: 2023-06-01" \
+ --header "anthropic-beta: token-counting-2024-11-01" \
+ --header "content-type: application/json" \
+ --data \
+'{
+ "model": "claude-3-5-sonnet-20241022",
+ "messages": [
+ {"role": "user", "content": "Hello, world"}
+ ]
+}'
+```
+
+### **Example 3: Batch Messages**
+
+
+#### LiteLLM Proxy Call
+
+```bash
+curl --request POST \
+ --url http://0.0.0.0:4000/anthropic/v1/messages/batches \
+ --header "x-api-key: $LITELLM_API_KEY" \
+ --header "anthropic-version: 2023-06-01" \
+ --header "anthropic-beta: message-batches-2024-09-24" \
+ --header "content-type: application/json" \
+ --data \
+'{
+ "requests": [
+ {
+ "custom_id": "my-first-request",
+ "params": {
+ "model": "claude-3-5-sonnet-20241022",
+ "max_tokens": 1024,
+ "messages": [
+ {"role": "user", "content": "Hello, world"}
+ ]
+ }
+ },
+ {
+ "custom_id": "my-second-request",
+ "params": {
+ "model": "claude-3-5-sonnet-20241022",
+ "max_tokens": 1024,
+ "messages": [
+ {"role": "user", "content": "Hi again, friend"}
+ ]
+ }
+ }
+ ]
+}'
+```
+
+#### Direct Anthropic API Call
+
+```bash
+curl https://api.anthropic.com/v1/messages/batches \
+ --header "x-api-key: $ANTHROPIC_API_KEY" \
+ --header "anthropic-version: 2023-06-01" \
+ --header "anthropic-beta: message-batches-2024-09-24" \
+ --header "content-type: application/json" \
+ --data \
+'{
+ "requests": [
+ {
+ "custom_id": "my-first-request",
+ "params": {
+ "model": "claude-3-5-sonnet-20241022",
+ "max_tokens": 1024,
+ "messages": [
+ {"role": "user", "content": "Hello, world"}
+ ]
+ }
+ },
+ {
+ "custom_id": "my-second-request",
+ "params": {
+ "model": "claude-3-5-sonnet-20241022",
+ "max_tokens": 1024,
+ "messages": [
+ {"role": "user", "content": "Hi again, friend"}
+ ]
+ }
+ }
+ ]
+}'
+```
+
+
+## Advanced - Use with Virtual Keys
+
+Pre-requisites
+- [Setup proxy with DB](../proxy/virtual_keys.md#setup)
+
+Use this to avoid giving developers the raw Anthropic API key while still letting them use Anthropic endpoints.
+
+### Usage
+
+1. Setup environment
+
+```bash
+export DATABASE_URL=""
+export LITELLM_MASTER_KEY=""
+export ANTHROPIC_API_KEY=""
+```
+
+```bash
+litellm
+
+# RUNNING on http://0.0.0.0:4000
+```
+
+2. Generate virtual key
+
+```bash
+curl -X POST 'http://0.0.0.0:4000/key/generate' \
+-H 'Authorization: Bearer sk-1234' \
+-H 'Content-Type: application/json' \
+-d '{}'
+```
+
+Expected Response
+
+```bash
+{
+ ...
+ "key": "sk-1234ewknldferwedojwojw"
+}
+```
+
+3. Test it!
+
+
+```bash
+curl --request POST \
+ --url http://0.0.0.0:4000/anthropic/v1/messages \
+ --header 'accept: application/json' \
+ --header 'content-type: application/json' \
+ --header "Authorization: bearer sk-1234ewknldferwedojwojw" \
+ --data '{
+ "model": "claude-3-5-sonnet-20241022",
+ "max_tokens": 1024,
+ "messages": [
+ {"role": "user", "content": "Hello, world"}
+ ]
+ }'
+```
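+
+You can also call the pass-through route from the Anthropic SDK by pointing `base_url` at the proxy. This is a sketch; it assumes the proxy accepts the virtual key via the SDK's `x-api-key` header, as in the quick start above:
+
+```python
+from anthropic import Anthropic
+
+# Point the SDK at the LiteLLM proxy's Anthropic pass-through base URL
+client = Anthropic(
+    api_key="sk-1234ewknldferwedojwojw",  # the virtual key generated in step 2
+    base_url="http://0.0.0.0:4000/anthropic",
+)
+
+message = client.messages.create(
+    model="claude-3-5-sonnet-20241022",
+    max_tokens=1024,
+    messages=[{"role": "user", "content": "Hello, world"}],
+)
+print(message.content)
+```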
\ No newline at end of file
diff --git a/docs/my-website/img/mlflow_tracing.png b/docs/my-website/img/mlflow_tracing.png
new file mode 100644
index 000000000..aee1fb79e
Binary files /dev/null and b/docs/my-website/img/mlflow_tracing.png differ
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index 1dc33f554..dd8443a28 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -65,12 +65,12 @@ const sidebars = {
},
{
type: "category",
- label: "Use with Provider SDKs",
+ label: "Pass-through Endpoints (Provider-specific)",
items: [
"pass_through/vertex_ai",
"pass_through/google_ai_studio",
"pass_through/cohere",
- "anthropic_completion",
+ "pass_through/anthropic_completion",
"pass_through/bedrock",
"pass_through/langfuse"
],
diff --git a/litellm/__init__.py b/litellm/__init__.py
index e8c3d6a64..edfe1a336 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -57,6 +57,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
"gcs_bucket",
"opik",
"argilla",
+ "mlflow",
]
logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
_known_custom_logger_compatible_callbacks: List = list(
diff --git a/litellm/integrations/mlflow.py b/litellm/integrations/mlflow.py
new file mode 100644
index 000000000..7268350d1
--- /dev/null
+++ b/litellm/integrations/mlflow.py
@@ -0,0 +1,247 @@
+import json
+import threading
+from typing import Optional
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
+
+
+class MlflowLogger(CustomLogger):
+ def __init__(self):
+ from mlflow.tracking import MlflowClient
+
+ self._client = MlflowClient()
+
+ self._stream_id_to_span = {}
+ self._lock = threading.Lock() # lock for _stream_id_to_span
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_success(kwargs, response_obj, start_time, end_time)
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_success(kwargs, response_obj, start_time, end_time)
+
+ def _handle_success(self, kwargs, response_obj, start_time, end_time):
+ """
+ Log the success event as an MLflow span.
+ Note that this method is called asynchronously in the background thread.
+ """
+ from mlflow.entities import SpanStatusCode
+
+ try:
+ verbose_logger.debug("MLflow logging start for success event")
+
+ if kwargs.get("stream"):
+ self._handle_stream_event(kwargs, response_obj, start_time, end_time)
+ else:
+ span = self._start_span_or_trace(kwargs, start_time)
+ end_time_ns = int(end_time.timestamp() * 1e9)
+ self._end_span_or_trace(
+ span=span,
+ outputs=response_obj,
+ status=SpanStatusCode.OK,
+ end_time_ns=end_time_ns,
+ )
+ except Exception:
+ verbose_logger.debug("MLflow Logging Error", stack_info=True)
+
+ def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+ def _handle_failure(self, kwargs, response_obj, start_time, end_time):
+ """
+ Log the failure event as an MLflow span.
+ Note that this method is called *synchronously* unlike the success handler.
+ """
+ from mlflow.entities import SpanEvent, SpanStatusCode
+
+ try:
+ span = self._start_span_or_trace(kwargs, start_time)
+
+ end_time_ns = int(end_time.timestamp() * 1e9)
+
+ # Record exception info as event
+ if exception := kwargs.get("exception"):
+ span.add_event(SpanEvent.from_exception(exception))
+
+ self._end_span_or_trace(
+ span=span,
+ outputs=response_obj,
+ status=SpanStatusCode.ERROR,
+ end_time_ns=end_time_ns,
+ )
+
+ except Exception as e:
+ verbose_logger.debug(f"MLflow Logging Error - {e}", stack_info=True)
+
+ def _handle_stream_event(self, kwargs, response_obj, start_time, end_time):
+ """
+        Handle the success event for a streaming response. For streaming calls,
+        the log_success_event handler is triggered for every chunk of the stream.
+ We create a single span for the entire stream request as follows:
+
+ 1. For the first chunk, start a new span and store it in the map.
+ 2. For subsequent chunks, add the chunk as an event to the span.
+ 3. For the final chunk, end the span and remove the span from the map.
+ """
+ from mlflow.entities import SpanStatusCode
+
+ litellm_call_id = kwargs.get("litellm_call_id")
+
+ if litellm_call_id not in self._stream_id_to_span:
+ with self._lock:
+ # Check again after acquiring lock
+ if litellm_call_id not in self._stream_id_to_span:
+ # Start a new span for the first chunk of the stream
+ span = self._start_span_or_trace(kwargs, start_time)
+ self._stream_id_to_span[litellm_call_id] = span
+
+ # Add chunk as event to the span
+ span = self._stream_id_to_span[litellm_call_id]
+ self._add_chunk_events(span, response_obj)
+
+ # If this is the final chunk, end the span. The final chunk
+ # has complete_streaming_response that gathers the full response.
+ if final_response := kwargs.get("complete_streaming_response"):
+ end_time_ns = int(end_time.timestamp() * 1e9)
+ self._end_span_or_trace(
+ span=span,
+ outputs=final_response,
+ status=SpanStatusCode.OK,
+ end_time_ns=end_time_ns,
+ )
+
+ # Remove the stream_id from the map
+ with self._lock:
+ self._stream_id_to_span.pop(litellm_call_id)
+
+ def _add_chunk_events(self, span, response_obj):
+ from mlflow.entities import SpanEvent
+
+ try:
+ for choice in response_obj.choices:
+ span.add_event(
+ SpanEvent(
+ name="streaming_chunk",
+ attributes={"delta": json.dumps(choice.delta.model_dump())},
+ )
+ )
+ except Exception:
+ verbose_logger.debug("Error adding chunk events to span", stack_info=True)
+
+ def _construct_input(self, kwargs):
+ """Construct span inputs with optional parameters"""
+ inputs = {"messages": kwargs.get("messages")}
+ for key in ["functions", "tools", "stream", "tool_choice", "user"]:
+ if value := kwargs.get("optional_params", {}).pop(key, None):
+ inputs[key] = value
+ return inputs
+
+ def _extract_attributes(self, kwargs):
+ """
+ Extract span attributes from kwargs.
+
+ With the latest version of litellm, the standard_logging_object contains
+ canonical information for logging. If it is not present, we extract
+        a subset of attributes from other kwargs.
+ """
+ attributes = {
+ "litellm_call_id": kwargs.get("litellm_call_id"),
+ "call_type": kwargs.get("call_type"),
+ "model": kwargs.get("model"),
+ }
+ standard_obj = kwargs.get("standard_logging_object")
+ if standard_obj:
+ attributes.update(
+ {
+ "api_base": standard_obj.get("api_base"),
+ "cache_hit": standard_obj.get("cache_hit"),
+ "usage": {
+ "completion_tokens": standard_obj.get("completion_tokens"),
+ "prompt_tokens": standard_obj.get("prompt_tokens"),
+ "total_tokens": standard_obj.get("total_tokens"),
+ },
+ "raw_llm_response": standard_obj.get("response"),
+ "response_cost": standard_obj.get("response_cost"),
+ "saved_cache_cost": standard_obj.get("saved_cache_cost"),
+ }
+ )
+ else:
+ litellm_params = kwargs.get("litellm_params", {})
+ attributes.update(
+ {
+ "model": kwargs.get("model"),
+ "cache_hit": kwargs.get("cache_hit"),
+ "custom_llm_provider": kwargs.get("custom_llm_provider"),
+ "api_base": litellm_params.get("api_base"),
+ "response_cost": kwargs.get("response_cost"),
+ }
+ )
+ return attributes
+
+ def _get_span_type(self, call_type: Optional[str]) -> str:
+ from mlflow.entities import SpanType
+
+ if call_type in ["completion", "acompletion"]:
+ return SpanType.LLM
+ elif call_type == "embeddings":
+ return SpanType.EMBEDDING
+ else:
+ return SpanType.LLM
+
+ def _start_span_or_trace(self, kwargs, start_time):
+ """
+ Start an MLflow span or a trace.
+
+ If there is an active span, we start a new span as a child of
+ that span. Otherwise, we start a new trace.
+ """
+ import mlflow
+
+ call_type = kwargs.get("call_type", "completion")
+ span_name = f"litellm-{call_type}"
+ span_type = self._get_span_type(call_type)
+ start_time_ns = int(start_time.timestamp() * 1e9)
+
+ inputs = self._construct_input(kwargs)
+ attributes = self._extract_attributes(kwargs)
+
+ if active_span := mlflow.get_current_active_span(): # type: ignore
+ return self._client.start_span(
+ name=span_name,
+ request_id=active_span.request_id,
+ parent_id=active_span.span_id,
+ span_type=span_type,
+ inputs=inputs,
+ attributes=attributes,
+ start_time_ns=start_time_ns,
+ )
+ else:
+ return self._client.start_trace(
+ name=span_name,
+ span_type=span_type,
+ inputs=inputs,
+ attributes=attributes,
+ start_time_ns=start_time_ns,
+ )
+
+ def _end_span_or_trace(self, span, outputs, end_time_ns, status):
+ """End an MLflow span or a trace."""
+ if span.parent_id is None:
+ self._client.end_trace(
+ request_id=span.request_id,
+ outputs=outputs,
+ status=status,
+ end_time_ns=end_time_ns,
+ )
+ else:
+ self._client.end_span(
+ request_id=span.request_id,
+ span_id=span.span_id,
+ outputs=outputs,
+ status=status,
+ end_time_ns=end_time_ns,
+ )
diff --git a/litellm/litellm_core_utils/get_supported_openai_params.py b/litellm/litellm_core_utils/get_supported_openai_params.py
index bb94d54d5..05b4b9c48 100644
--- a/litellm/litellm_core_utils/get_supported_openai_params.py
+++ b/litellm/litellm_core_utils/get_supported_openai_params.py
@@ -161,17 +161,7 @@ def get_supported_openai_params( # noqa: PLR0915
elif custom_llm_provider == "huggingface":
return litellm.HuggingfaceConfig().get_supported_openai_params()
elif custom_llm_provider == "together_ai":
- return [
- "stream",
- "temperature",
- "max_tokens",
- "top_p",
- "stop",
- "frequency_penalty",
- "tools",
- "tool_choice",
- "response_format",
- ]
+ return litellm.TogetherAIConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "ai21":
return [
"stream",
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 15f7f59fa..69d6adca4 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -28,6 +28,7 @@ from litellm.caching.caching_handler import LLMCachingHandler
from litellm.cost_calculator import _select_model_name_for_cost_calc
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
+from litellm.integrations.mlflow import MlflowLogger
from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_custom_logger,
redact_message_input_output_from_logging,
@@ -563,6 +564,7 @@ class Logging:
message=f"Model Call Details pre-call: {details_to_log}",
level="info",
)
+
elif isinstance(callback, CustomLogger): # custom logger class
callback.log_pre_api_call(
model=self.model,
@@ -1258,6 +1260,7 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
+
if (
callback == "openmeter"
and self.model_call_details.get("litellm_params", {}).get(
@@ -2347,6 +2350,14 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
_in_memory_loggers.append(_otel_logger)
return _otel_logger # type: ignore
+ elif logging_integration == "mlflow":
+ for callback in _in_memory_loggers:
+ if isinstance(callback, MlflowLogger):
+ return callback # type: ignore
+
+ _mlflow_logger = MlflowLogger()
+ _in_memory_loggers.append(_mlflow_logger)
+ return _mlflow_logger # type: ignore
def get_custom_logger_compatible_class(
logging_integration: litellm._custom_logger_compatible_callbacks_literal,
@@ -2448,6 +2459,12 @@ def get_custom_logger_compatible_class(
and callback.callback_name == "langtrace"
):
return callback
+
+ elif logging_integration == "mlflow":
+ for callback in _in_memory_loggers:
+ if isinstance(callback, MlflowLogger):
+ return callback
+
return None
diff --git a/litellm/llms/together_ai/chat.py b/litellm/llms/together_ai/chat.py
index 398bc489c..cb12d6147 100644
--- a/litellm/llms/together_ai/chat.py
+++ b/litellm/llms/together_ai/chat.py
@@ -6,8 +6,8 @@ Calls done in OpenAI/openai.py as TogetherAI is openai-compatible.
Docs: https://docs.together.ai/reference/completions-1
"""
-from ..OpenAI.openai import OpenAIConfig
+from ..OpenAI.chat.gpt_transformation import OpenAIGPTConfig
-class TogetherAIConfig(OpenAIConfig):
+class TogetherAIConfig(OpenAIGPTConfig):
pass
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 34ac51481..92ca32e52 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1069,7 +1069,7 @@ async def update_cache( # noqa: PLR0915
end_user_id: Optional[str],
team_id: Optional[str],
response_cost: Optional[float],
- parent_otel_span: Optional[Span],
+ parent_otel_span: Optional[Span], # type: ignore
):
"""
Use this to update the cache with new user spend.
@@ -5657,6 +5657,13 @@ async def anthropic_response( # noqa: PLR0915
request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
+ """
+    This is a BETA endpoint that calls 100+ LLMs in the Anthropic format.
+
+    To do a simple pass-through for Anthropic, use `{PROXY_BASE_URL}/anthropic/v1/messages`
+
+ Docs - https://docs.litellm.ai/docs/anthropic_completion
+ """
from litellm import adapter_completion
from litellm.adapters.anthropic_adapter import anthropic_adapter
diff --git a/litellm/proxy/spend_tracking/spend_management_endpoints.py b/litellm/proxy/spend_tracking/spend_management_endpoints.py
index f6d36daaf..e0fa1e092 100644
--- a/litellm/proxy/spend_tracking/spend_management_endpoints.py
+++ b/litellm/proxy/spend_tracking/spend_management_endpoints.py
@@ -9,6 +9,9 @@ import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.spend_tracking.spend_tracking_utils import (
+ get_spend_by_team_and_customer,
+)
router = APIRouter()
@@ -932,6 +935,14 @@ async def get_global_spend_report(
default=None,
description="View spend for a specific internal_user_id. Example internal_user_id='1234",
),
+ team_id: Optional[str] = fastapi.Query(
+ default=None,
+ description="View spend for a specific team_id. Example team_id='1234",
+ ),
+ customer_id: Optional[str] = fastapi.Query(
+ default=None,
+ description="View spend for a specific customer_id. Example customer_id='1234. Can be used in conjunction with team_id as well.",
+ ),
):
"""
Get Daily Spend per Team, based on specific startTime and endTime. Per team, view usage by each key, model
@@ -1074,8 +1085,12 @@ async def get_global_spend_report(
return []
return db_response
-
+ elif team_id is not None and customer_id is not None:
+ return await get_spend_by_team_and_customer(
+ start_date_obj, end_date_obj, team_id, customer_id, prisma_client
+ )
if group_by == "team":
+
# first get data from spend logs -> SpendByModelApiKey
# then read data from "SpendByModelApiKey" to format the response obj
sql_query = """
@@ -1305,7 +1320,6 @@ async def global_get_all_tag_names():
"/global/spend/tags",
tags=["Budget & Spend Tracking"],
dependencies=[Depends(user_api_key_auth)],
- include_in_schema=False,
responses={
200: {"model": List[LiteLLM_SpendLogs]},
},
diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py
index 30e3ae5cd..48924d521 100644
--- a/litellm/proxy/spend_tracking/spend_tracking_utils.py
+++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py
@@ -1,7 +1,9 @@
+import datetime
import json
import os
import secrets
import traceback
+from datetime import datetime as dt
from typing import Optional
from pydantic import BaseModel
@@ -9,7 +11,7 @@ from pydantic import BaseModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
-from litellm.proxy.utils import hash_token
+from litellm.proxy.utils import PrismaClient, hash_token
def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
@@ -163,3 +165,79 @@ def get_logging_payload(
"Error creating spendlogs object - {}".format(str(e))
)
raise e
+
+
+async def get_spend_by_team_and_customer(
+ start_date: dt,
+ end_date: dt,
+ team_id: str,
+ customer_id: str,
+ prisma_client: PrismaClient,
+):
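+    """
+    Get daily spend for a specific team_id and customer_id (end_user) from
+    LiteLLM_SpendLogs, grouped by day, with a per-model / per-API-key
+    breakdown returned in the "metadata" field of each row.
+    """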
+ sql_query = """
+ WITH SpendByModelApiKey AS (
+ SELECT
+ date_trunc('day', sl."startTime") AS group_by_day,
+ COALESCE(tt.team_alias, 'Unassigned Team') AS team_name,
+ sl.end_user AS customer,
+ sl.model,
+ sl.api_key,
+ SUM(sl.spend) AS model_api_spend,
+ SUM(sl.total_tokens) AS model_api_tokens
+ FROM
+ "LiteLLM_SpendLogs" sl
+ LEFT JOIN
+ "LiteLLM_TeamTable" tt
+ ON
+ sl.team_id = tt.team_id
+ WHERE
+ sl."startTime" BETWEEN $1::date AND $2::date
+ AND sl.team_id = $3
+ AND sl.end_user = $4
+ GROUP BY
+ date_trunc('day', sl."startTime"),
+ tt.team_alias,
+ sl.end_user,
+ sl.model,
+ sl.api_key
+ )
+ SELECT
+ group_by_day,
+ jsonb_agg(jsonb_build_object(
+ 'team_name', team_name,
+ 'customer', customer,
+ 'total_spend', total_spend,
+ 'metadata', metadata
+ )) AS teams_customers
+ FROM (
+ SELECT
+ group_by_day,
+ team_name,
+ customer,
+ SUM(model_api_spend) AS total_spend,
+ jsonb_agg(jsonb_build_object(
+ 'model', model,
+ 'api_key', api_key,
+ 'spend', model_api_spend,
+ 'total_tokens', model_api_tokens
+ )) AS metadata
+ FROM
+ SpendByModelApiKey
+ GROUP BY
+ group_by_day,
+ team_name,
+ customer
+ ) AS aggregated
+ GROUP BY
+ group_by_day
+ ORDER BY
+ group_by_day;
+ """
+
+ db_response = await prisma_client.db.query_raw(
+ sql_query, start_date, end_date, team_id, customer_id
+ )
+ if db_response is None:
+ return []
+
+ return db_response
diff --git a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py
index 667a21a3c..c4a64fa21 100644
--- a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py
+++ b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py
@@ -155,6 +155,51 @@ async def cohere_proxy_route(
return received_value
+@router.api_route(
+ "/anthropic/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
+)
+async def anthropic_proxy_route(
+ endpoint: str,
+ request: Request,
+ fastapi_response: Response,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ base_target_url = "https://api.anthropic.com"
+ encoded_endpoint = httpx.URL(endpoint).path
+
+ # Ensure endpoint starts with '/' for proper URL construction
+ if not encoded_endpoint.startswith("/"):
+ encoded_endpoint = "/" + encoded_endpoint
+
+ # Construct the full target URL using httpx
+ base_url = httpx.URL(base_target_url)
+ updated_url = base_url.copy_with(path=encoded_endpoint)
+
+ # Add or update query parameters
+ anthropic_api_key = litellm.utils.get_secret(secret_name="ANTHROPIC_API_KEY")
+
+ ## check for streaming
+ is_streaming_request = False
+ if "stream" in str(updated_url):
+ is_streaming_request = True
+
+ ## CREATE PASS-THROUGH
+ endpoint_func = create_pass_through_route(
+ endpoint=endpoint,
+ target=str(updated_url),
+ custom_headers={"x-api-key": "{}".format(anthropic_api_key)},
+ _forward_headers=True,
+ ) # dynamically construct pass-through endpoint based on incoming path
+ received_value = await endpoint_func(
+ request,
+ fastapi_response,
+ user_api_key_dict,
+ stream=is_streaming_request, # type: ignore
+ )
+
+ return received_value
+
+
@router.api_route("/bedrock/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def bedrock_proxy_route(
endpoint: str,
diff --git a/litellm/tests/test_mlflow.py b/litellm/tests/test_mlflow.py
new file mode 100644
index 000000000..ec23875ea
--- /dev/null
+++ b/litellm/tests/test_mlflow.py
@@ -0,0 +1,29 @@
+import pytest
+
+import litellm
+
+
+def test_mlflow_logging():
+ litellm.success_callback = ["mlflow"]
+ litellm.failure_callback = ["mlflow"]
+
+ litellm.completion(
+ model="gpt-4o-mini",
+ messages=[{"role": "user", "content": "what llm are u"}],
+ max_tokens=10,
+ temperature=0.2,
+ user="test-user",
+ )
+
+@pytest.mark.asyncio()
+async def test_async_mlflow_logging():
+ litellm.success_callback = ["mlflow"]
+ litellm.failure_callback = ["mlflow"]
+
+ await litellm.acompletion(
+ model="gpt-4o-mini",
+ messages=[{"role": "user", "content": "hi test from local arize"}],
+ mock_response="hello",
+ temperature=0.1,
+ user="OTEL_USER",
+ )
diff --git a/litellm/utils.py b/litellm/utils.py
index fdb533e4e..f4f31e6cf 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2903,24 +2903,16 @@ def get_optional_params( # noqa: PLR0915
)
_check_valid_arg(supported_params=supported_params)
- if stream:
- optional_params["stream"] = stream
- if temperature is not None:
- optional_params["temperature"] = temperature
- if top_p is not None:
- optional_params["top_p"] = top_p
- if max_tokens is not None:
- optional_params["max_tokens"] = max_tokens
- if frequency_penalty is not None:
- optional_params["frequency_penalty"] = frequency_penalty
- if stop is not None:
- optional_params["stop"] = stop
- if tools is not None:
- optional_params["tools"] = tools
- if tool_choice is not None:
- optional_params["tool_choice"] = tool_choice
- if response_format is not None:
- optional_params["response_format"] = response_format
+ optional_params = litellm.TogetherAIConfig().map_openai_params(
+ non_default_params=non_default_params,
+ optional_params=optional_params,
+ model=model,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
+ )
elif custom_llm_provider == "ai21":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py
index bea066865..c9527c830 100644
--- a/tests/llm_translation/test_optional_params.py
+++ b/tests/llm_translation/test_optional_params.py
@@ -923,6 +923,14 @@ def test_watsonx_text_top_k():
assert optional_params["top_k"] == 10
+
+def test_together_ai_model_params():
+ optional_params = get_optional_params(
+ model="together_ai", custom_llm_provider="together_ai", logprobs=1
+ )
+ print(optional_params)
+ assert optional_params["logprobs"] == 1
+
def test_forward_user_param():
from litellm.utils import get_supported_openai_params, get_optional_params
diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py
index 881185b74..3ce4cb7d7 100644
--- a/tests/local_testing/test_completion.py
+++ b/tests/local_testing/test_completion.py
@@ -406,8 +406,13 @@ def test_completion_claude_3_empty_response():
"content": "I was hoping we could chat a bit",
},
]
- response = litellm.completion(model="claude-3-opus-20240229", messages=messages)
- print(response)
+ try:
+ response = litellm.completion(model="claude-3-opus-20240229", messages=messages)
+ print(response)
+ except litellm.InternalServerError as e:
+ pytest.skip(f"InternalServerError - {str(e)}")
+ except Exception as e:
+ pytest.fail(f"Error occurred: {e}")
def test_completion_claude_3():
@@ -434,6 +439,8 @@ def test_completion_claude_3():
)
# Add any assertions, here to check response args
print(response)
+ except litellm.InternalServerError as e:
+ pytest.skip(f"InternalServerError - {str(e)}")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@@ -917,6 +924,9 @@ def test_completion_base64(model):
except litellm.ServiceUnavailableError as e:
print("got service unavailable error: ", e)
pass
+ except litellm.InternalServerError as e:
+ print("got internal server error: ", e)
+ pass
except Exception as e:
if "500 Internal error encountered.'" in str(e):
pass
@@ -1055,7 +1065,6 @@ def test_completion_mistral_api():
cost = litellm.completion_cost(completion_response=response)
print("cost to make mistral completion=", cost)
assert cost > 0.0
- assert response.model == "mistral/mistral-tiny"
except Exception as e:
pytest.fail(f"Error occurred: {e}")
diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py
index 209b38423..0bc6953f9 100644
--- a/tests/local_testing/test_streaming.py
+++ b/tests/local_testing/test_streaming.py
@@ -3333,8 +3333,8 @@ async def test_acompletion_function_call_with_streaming(model):
validate_final_streaming_function_calling_chunk(chunk=chunk)
idx += 1
# raise Exception("it worked! ")
- except litellm.InternalServerError:
- pass
+ except litellm.InternalServerError as e:
+ pytest.skip(f"InternalServerError - {str(e)}")
except litellm.ServiceUnavailableError:
pass
except Exception as e:
diff --git a/tests/logging_callback_tests/test_otel_logging.py b/tests/logging_callback_tests/test_otel_logging.py
index ffc58416d..ecfc305f9 100644
--- a/tests/logging_callback_tests/test_otel_logging.py
+++ b/tests/logging_callback_tests/test_otel_logging.py
@@ -144,6 +144,7 @@ def validate_raw_gen_ai_request_openai_streaming(span):
"model",
["anthropic/claude-3-opus-20240229"],
)
+@pytest.mark.flaky(retries=6, delay=2)
def test_completion_claude_3_function_call_with_otel(model):
litellm.set_verbose = True
diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py
index ebc7dd33c..38883fa38 100644
--- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py
+++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py
@@ -31,6 +31,7 @@ from litellm.integrations.datadog.datadog_llm_obs import DataDogLLMObsLogger
from litellm.integrations.gcs_bucket.gcs_bucket import GCSBucketLogger
from litellm.integrations.opik.opik import OpikLogger
from litellm.integrations.opentelemetry import OpenTelemetry
+from litellm.integrations.mlflow import MlflowLogger
from litellm.integrations.argilla import ArgillaLogger
from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler
from unittest.mock import patch
@@ -59,6 +60,7 @@ callback_class_str_to_classType = {
"logfire": OpenTelemetry,
"arize": OpenTelemetry,
"langtrace": OpenTelemetry,
+ "mlflow": MlflowLogger,
}
expected_env_vars = {