diff --git a/.circleci/config.yml b/.circleci/config.yml
index 7961cfddb..d95a8c214 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -690,6 +690,7 @@ jobs:
             pip install "respx==0.21.1"
             pip install "google-generativeai==0.3.2"
             pip install "google-cloud-aiplatform==1.43.0"
+            pip install "mlflow==2.17.2"
       # Run pytest and generate JUnit XML report
       - run:
           name: Run tests
diff --git a/README.md b/README.md
index 153d5ab3a..5d3efe355 100644
--- a/README.md
+++ b/README.md
@@ -113,7 +113,7 @@ for part in response:
 
 ## Logging Observability ([Docs](https://docs.litellm.ai/docs/observability/callbacks))
 
-LiteLLM exposes pre defined callbacks to send data to Lunary, Langfuse, DynamoDB, s3 Buckets, Helicone, Promptlayer, Traceloop, Athina, Slack
+LiteLLM exposes pre defined callbacks to send data to Lunary, Langfuse, DynamoDB, s3 Buckets, Helicone, Promptlayer, Traceloop, Athina, Slack, MLflow
 
 ```python
 from litellm import completion
diff --git a/docs/my-website/docs/anthropic_completion.md b/docs/my-website/docs/anthropic_completion.md
deleted file mode 100644
index ca65f3f6f..000000000
--- a/docs/my-website/docs/anthropic_completion.md
+++ /dev/null
@@ -1,54 +0,0 @@
-# [BETA] Anthropic `/v1/messages`
-
-Call 100+ LLMs in the Anthropic format.
-
-
-1. Setup config.yaml
-
-```yaml
-model_list:
-  - model_name: my-test-model
-    litellm_params:
-      model: gpt-3.5-turbo
-```
-
-2. Start proxy
-
-```bash
-litellm --config /path/to/config.yaml
-```
-
-3. Test it!
-
-```bash
-curl -X POST 'http://0.0.0.0:4000/v1/messages' \
--H 'x-api-key: sk-1234' \
--H 'content-type: application/json' \
--D '{
-  "model": "my-test-model",
-  "max_tokens": 1024,
-  "messages": [
-    {"role": "user", "content": "Hello, world"}
-  ]
-}'
-```
-
-## Test with Anthropic SDK
-
-```python
-import os
-from anthropic import Anthropic
-
-client = Anthropic(api_key="sk-1234", base_url="http://0.0.0.0:4000") # 👈 CONNECT TO PROXY
-
-message = client.messages.create(
-    messages=[
-        {
-            "role": "user",
-            "content": "Hello, Claude",
-        }
-    ],
-    model="my-test-model", # 👈 set 'model_name'
-)
-print(message.content)
-```
\ No newline at end of file
diff --git a/docs/my-website/docs/observability/mlflow.md b/docs/my-website/docs/observability/mlflow.md
new file mode 100644
index 000000000..3b1e1d477
--- /dev/null
+++ b/docs/my-website/docs/observability/mlflow.md
@@ -0,0 +1,108 @@
+# MLflow
+
+## What is MLflow?
+
+**MLflow** is an end-to-end open source MLOps platform for [experiment tracking](https://www.mlflow.org/docs/latest/tracking.html), [model management](https://www.mlflow.org/docs/latest/models.html), [evaluation](https://www.mlflow.org/docs/latest/llms/llm-evaluate/index.html), [observability (tracing)](https://www.mlflow.org/docs/latest/llms/tracing/index.html), and [deployment](https://www.mlflow.org/docs/latest/deployment/index.html). MLflow empowers teams to collaboratively develop and refine LLM applications efficiently.
+
+MLflow’s integration with LiteLLM supports advanced observability compatible with OpenTelemetry.
+
+
+
+
+
+## Getting Started
+
+Install MLflow:
+
+```shell
+pip install mlflow
+```
+
+To enable LiteLLM tracing:
+
+```python
+import mlflow
+
+mlflow.litellm.autolog()
+
+# Alternatively, you can set the callback manually in LiteLLM
+# litellm.callbacks = ["mlflow"]
+```
+
+Since MLflow is open-source, no sign-up or API key is needed to log traces!
+
+```python
+import litellm
+import os
+
+# Set your LLM provider's API key
+os.environ["OPENAI_API_KEY"] = ""
+
+# Call LiteLLM as usual
+response = litellm.completion(
+    model="gpt-4o-mini",
+    messages=[
+        {"role": "user", "content": "Hi 👋 - i'm openai"}
+    ]
+)
+```
+
+Open the MLflow UI and go to the `Traces` tab to view logged traces:
+
+```bash
+mlflow ui
+```
+
+## Exporting Traces to OpenTelemetry collectors
+
+MLflow traces are compatible with OpenTelemetry. You can export traces to any OpenTelemetry collector (e.g., Jaeger, Zipkin, Datadog, New Relic) by setting the endpoint URL in the environment variables.
+
+```python
+# Set the endpoint of the OpenTelemetry Collector
+os.environ["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] = "http://localhost:4317/v1/traces"
+# Optionally, set the service name to group traces
+os.environ["OTEL_SERVICE_NAME"] = ""
+```
+
+See the [MLflow documentation](https://mlflow.org/docs/latest/llms/tracing/index.html#using-opentelemetry-collector-for-exporting-traces) for more details.
+
+## Combine LiteLLM Trace with Your Application Trace
+
+LiteLLM is often part of larger LLM applications, such as agentic applications. MLflow Tracing allows you to instrument custom Python code, which can then be combined with LiteLLM traces.
+
+```python
+import litellm
+import mlflow
+from mlflow.entities import SpanType
+
+# Enable LiteLLM tracing
+mlflow.litellm.autolog()
+
+
+class CustomAgent:
+    # Use @mlflow.trace to instrument Python functions.
+    @mlflow.trace(span_type=SpanType.AGENT)
+    def run(self, query: str):
+        # do something
+
+        while i < self.max_turns:
+            response = litellm.completion(
+                model="gpt-4o-mini",
+                messages=messages,
+            )
+
+            action = self.get_action(response)
+            ...
+
+    @mlflow.trace
+    def get_action(llm_response):
+        ...
+```
+
+This approach generates a unified trace, combining your custom Python code with LiteLLM calls.
+
+
+## Support
+
+* For advanced usage and integrations of tracing, visit the [MLflow Tracing documentation](https://mlflow.org/docs/latest/llms/tracing/index.html).
+* For any questions or issues with this integration, please [submit an issue](https://github.com/mlflow/mlflow/issues/new/choose) on our [GitHub](https://github.com/mlflow/mlflow) repository!
\ No newline at end of file
diff --git a/docs/my-website/docs/pass_through/anthropic_completion.md b/docs/my-website/docs/pass_through/anthropic_completion.md
new file mode 100644
index 000000000..0c6a5f1b6
--- /dev/null
+++ b/docs/my-website/docs/pass_through/anthropic_completion.md
@@ -0,0 +1,282 @@
+# Anthropic `/v1/messages`
+
+Pass-through endpoints for Anthropic - call the provider-specific endpoint in its native format (no translation).
+
+Just replace `https://api.anthropic.com` with `LITELLM_PROXY_BASE_URL/anthropic` 🚀
+
+#### **Example Usage**
+```bash
+curl --request POST \
+  --url http://0.0.0.0:4000/anthropic/v1/messages \
+  --header 'accept: application/json' \
+  --header 'content-type: application/json' \
+  --header "Authorization: bearer sk-anything" \
+  --data '{
+    "model": "claude-3-5-sonnet-20241022",
+    "max_tokens": 1024,
+    "messages": [
+      {"role": "user", "content": "Hello, world"}
+    ]
+  }'
+```
+
+Supports **ALL** Anthropic Endpoints (including streaming).
+
+[**See All Anthropic Endpoints**](https://docs.anthropic.com/en/api/messages)
+
+## Quick Start
+
+Let's call the Anthropic [`/messages` endpoint](https://docs.anthropic.com/en/api/messages).
+
+1. Add your Anthropic API Key to your environment
+
+```bash
+export ANTHROPIC_API_KEY=""
+```
+
+2. 
Start LiteLLM Proxy + +```bash +litellm + +# RUNNING on http://0.0.0.0:4000 +``` + +3. Test it! + +Let's call the Anthropic /messages endpoint + +```bash +curl http://0.0.0.0:4000/anthropic/v1/messages \ + --header "x-api-key: $LITELLM_API_KEY" \ + --header "anthropic-version: 2023-06-01" \ + --header "content-type: application/json" \ + --data \ + '{ + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, world"} + ] + }' +``` + + +## Examples + +Anything after `http://0.0.0.0:4000/anthropic` is treated as a provider-specific route, and handled accordingly. + +Key Changes: + +| **Original Endpoint** | **Replace With** | +|------------------------------------------------------|-----------------------------------| +| `https://api.anthropic.com` | `http://0.0.0.0:4000/anthropic` (LITELLM_PROXY_BASE_URL="http://0.0.0.0:4000") | +| `bearer $ANTHROPIC_API_KEY` | `bearer anything` (use `bearer LITELLM_VIRTUAL_KEY` if Virtual Keys are setup on proxy) | + + +### **Example 1: Messages endpoint** + +#### LiteLLM Proxy Call + +```bash +curl --request POST \ + --url http://0.0.0.0:4000/anthropic/v1/messages \ + --header "x-api-key: $LITELLM_API_KEY" \ + --header "anthropic-version: 2023-06-01" \ + --header "content-type: application/json" \ + --data '{ + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, world"} + ] + }' +``` + +#### Direct Anthropic API Call + +```bash +curl https://api.anthropic.com/v1/messages \ + --header "x-api-key: $ANTHROPIC_API_KEY" \ + --header "anthropic-version: 2023-06-01" \ + --header "content-type: application/json" \ + --data \ + '{ + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, world"} + ] + }' +``` + +### **Example 2: Token Counting API** + +#### LiteLLM Proxy Call + +```bash +curl --request POST \ + --url http://0.0.0.0:4000/anthropic/v1/messages/count_tokens \ + --header "x-api-key: $LITELLM_API_KEY" \ + --header "anthropic-version: 2023-06-01" \ + --header "anthropic-beta: token-counting-2024-11-01" \ + --header "content-type: application/json" \ + --data \ + '{ + "model": "claude-3-5-sonnet-20241022", + "messages": [ + {"role": "user", "content": "Hello, world"} + ] + }' +``` + +#### Direct Anthropic API Call + +```bash +curl https://api.anthropic.com/v1/messages/count_tokens \ + --header "x-api-key: $ANTHROPIC_API_KEY" \ + --header "anthropic-version: 2023-06-01" \ + --header "anthropic-beta: token-counting-2024-11-01" \ + --header "content-type: application/json" \ + --data \ +'{ + "model": "claude-3-5-sonnet-20241022", + "messages": [ + {"role": "user", "content": "Hello, world"} + ] +}' +``` + +### **Example 3: Batch Messages** + + +#### LiteLLM Proxy Call + +```bash +curl --request POST \ + --url http://0.0.0.0:4000/anthropic/v1/messages/batches \ + --header "x-api-key: $LITELLM_API_KEY" \ + --header "anthropic-version: 2023-06-01" \ + --header "anthropic-beta: message-batches-2024-09-24" \ + --header "content-type: application/json" \ + --data \ +'{ + "requests": [ + { + "custom_id": "my-first-request", + "params": { + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, world"} + ] + } + }, + { + "custom_id": "my-second-request", + "params": { + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hi again, friend"} + ] + } + } + ] +}' +``` + +#### 
Direct Anthropic API Call + +```bash +curl https://api.anthropic.com/v1/messages/batches \ + --header "x-api-key: $ANTHROPIC_API_KEY" \ + --header "anthropic-version: 2023-06-01" \ + --header "anthropic-beta: message-batches-2024-09-24" \ + --header "content-type: application/json" \ + --data \ +'{ + "requests": [ + { + "custom_id": "my-first-request", + "params": { + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, world"} + ] + } + }, + { + "custom_id": "my-second-request", + "params": { + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hi again, friend"} + ] + } + } + ] +}' +``` + + +## Advanced - Use with Virtual Keys + +Pre-requisites +- [Setup proxy with DB](../proxy/virtual_keys.md#setup) + +Use this, to avoid giving developers the raw Anthropic API key, but still letting them use Anthropic endpoints. + +### Usage + +1. Setup environment + +```bash +export DATABASE_URL="" +export LITELLM_MASTER_KEY="" +export COHERE_API_KEY="" +``` + +```bash +litellm + +# RUNNING on http://0.0.0.0:4000 +``` + +2. Generate virtual key + +```bash +curl -X POST 'http://0.0.0.0:4000/key/generate' \ +-H 'Authorization: Bearer sk-1234' \ +-H 'Content-Type: application/json' \ +-d '{}' +``` + +Expected Response + +```bash +{ + ... + "key": "sk-1234ewknldferwedojwojw" +} +``` + +3. Test it! + + +```bash +curl --request POST \ + --url http://0.0.0.0:4000/anthropic/v1/messages \ + --header 'accept: application/json' \ + --header 'content-type: application/json' \ + --header "Authorization: bearer sk-1234ewknldferwedojwojw" \ + --data '{ + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, world"} + ] + }' +``` \ No newline at end of file diff --git a/docs/my-website/img/mlflow_tracing.png b/docs/my-website/img/mlflow_tracing.png new file mode 100644 index 000000000..aee1fb79e Binary files /dev/null and b/docs/my-website/img/mlflow_tracing.png differ diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 1dc33f554..dd8443a28 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -65,12 +65,12 @@ const sidebars = { }, { type: "category", - label: "Use with Provider SDKs", + label: "Pass-through Endpoints (Provider-specific)", items: [ "pass_through/vertex_ai", "pass_through/google_ai_studio", "pass_through/cohere", - "anthropic_completion", + "pass_through/anthropic_completion", "pass_through/bedrock", "pass_through/langfuse" ], diff --git a/litellm/__init__.py b/litellm/__init__.py index e8c3d6a64..edfe1a336 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -57,6 +57,7 @@ _custom_logger_compatible_callbacks_literal = Literal[ "gcs_bucket", "opik", "argilla", + "mlflow", ] logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None _known_custom_logger_compatible_callbacks: List = list( diff --git a/litellm/integrations/mlflow.py b/litellm/integrations/mlflow.py new file mode 100644 index 000000000..7268350d1 --- /dev/null +++ b/litellm/integrations/mlflow.py @@ -0,0 +1,247 @@ +import json +import threading +from typing import Optional + +from litellm._logging import verbose_logger +from litellm.integrations.custom_logger import CustomLogger + + +class MlflowLogger(CustomLogger): + def __init__(self): + from mlflow.tracking import MlflowClient + + self._client = MlflowClient() + + self._stream_id_to_span = {} + self._lock = threading.Lock() # lock 
for _stream_id_to_span + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + self._handle_success(kwargs, response_obj, start_time, end_time) + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + self._handle_success(kwargs, response_obj, start_time, end_time) + + def _handle_success(self, kwargs, response_obj, start_time, end_time): + """ + Log the success event as an MLflow span. + Note that this method is called asynchronously in the background thread. + """ + from mlflow.entities import SpanStatusCode + + try: + verbose_logger.debug("MLflow logging start for success event") + + if kwargs.get("stream"): + self._handle_stream_event(kwargs, response_obj, start_time, end_time) + else: + span = self._start_span_or_trace(kwargs, start_time) + end_time_ns = int(end_time.timestamp() * 1e9) + self._end_span_or_trace( + span=span, + outputs=response_obj, + status=SpanStatusCode.OK, + end_time_ns=end_time_ns, + ) + except Exception: + verbose_logger.debug("MLflow Logging Error", stack_info=True) + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): + self._handle_failure(kwargs, response_obj, start_time, end_time) + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + self._handle_failure(kwargs, response_obj, start_time, end_time) + + def _handle_failure(self, kwargs, response_obj, start_time, end_time): + """ + Log the failure event as an MLflow span. + Note that this method is called *synchronously* unlike the success handler. + """ + from mlflow.entities import SpanEvent, SpanStatusCode + + try: + span = self._start_span_or_trace(kwargs, start_time) + + end_time_ns = int(end_time.timestamp() * 1e9) + + # Record exception info as event + if exception := kwargs.get("exception"): + span.add_event(SpanEvent.from_exception(exception)) + + self._end_span_or_trace( + span=span, + outputs=response_obj, + status=SpanStatusCode.ERROR, + end_time_ns=end_time_ns, + ) + + except Exception as e: + verbose_logger.debug(f"MLflow Logging Error - {e}", stack_info=True) + + def _handle_stream_event(self, kwargs, response_obj, start_time, end_time): + """ + Handle the success event for a streaming response. For streaming calls, + log_success_event handle is triggered for every chunk of the stream. + We create a single span for the entire stream request as follows: + + 1. For the first chunk, start a new span and store it in the map. + 2. For subsequent chunks, add the chunk as an event to the span. + 3. For the final chunk, end the span and remove the span from the map. + """ + from mlflow.entities import SpanStatusCode + + litellm_call_id = kwargs.get("litellm_call_id") + + if litellm_call_id not in self._stream_id_to_span: + with self._lock: + # Check again after acquiring lock + if litellm_call_id not in self._stream_id_to_span: + # Start a new span for the first chunk of the stream + span = self._start_span_or_trace(kwargs, start_time) + self._stream_id_to_span[litellm_call_id] = span + + # Add chunk as event to the span + span = self._stream_id_to_span[litellm_call_id] + self._add_chunk_events(span, response_obj) + + # If this is the final chunk, end the span. The final chunk + # has complete_streaming_response that gathers the full response. 
+ if final_response := kwargs.get("complete_streaming_response"): + end_time_ns = int(end_time.timestamp() * 1e9) + self._end_span_or_trace( + span=span, + outputs=final_response, + status=SpanStatusCode.OK, + end_time_ns=end_time_ns, + ) + + # Remove the stream_id from the map + with self._lock: + self._stream_id_to_span.pop(litellm_call_id) + + def _add_chunk_events(self, span, response_obj): + from mlflow.entities import SpanEvent + + try: + for choice in response_obj.choices: + span.add_event( + SpanEvent( + name="streaming_chunk", + attributes={"delta": json.dumps(choice.delta.model_dump())}, + ) + ) + except Exception: + verbose_logger.debug("Error adding chunk events to span", stack_info=True) + + def _construct_input(self, kwargs): + """Construct span inputs with optional parameters""" + inputs = {"messages": kwargs.get("messages")} + for key in ["functions", "tools", "stream", "tool_choice", "user"]: + if value := kwargs.get("optional_params", {}).pop(key, None): + inputs[key] = value + return inputs + + def _extract_attributes(self, kwargs): + """ + Extract span attributes from kwargs. + + With the latest version of litellm, the standard_logging_object contains + canonical information for logging. If it is not present, we extract + subset of attributes from other kwargs. + """ + attributes = { + "litellm_call_id": kwargs.get("litellm_call_id"), + "call_type": kwargs.get("call_type"), + "model": kwargs.get("model"), + } + standard_obj = kwargs.get("standard_logging_object") + if standard_obj: + attributes.update( + { + "api_base": standard_obj.get("api_base"), + "cache_hit": standard_obj.get("cache_hit"), + "usage": { + "completion_tokens": standard_obj.get("completion_tokens"), + "prompt_tokens": standard_obj.get("prompt_tokens"), + "total_tokens": standard_obj.get("total_tokens"), + }, + "raw_llm_response": standard_obj.get("response"), + "response_cost": standard_obj.get("response_cost"), + "saved_cache_cost": standard_obj.get("saved_cache_cost"), + } + ) + else: + litellm_params = kwargs.get("litellm_params", {}) + attributes.update( + { + "model": kwargs.get("model"), + "cache_hit": kwargs.get("cache_hit"), + "custom_llm_provider": kwargs.get("custom_llm_provider"), + "api_base": litellm_params.get("api_base"), + "response_cost": kwargs.get("response_cost"), + } + ) + return attributes + + def _get_span_type(self, call_type: Optional[str]) -> str: + from mlflow.entities import SpanType + + if call_type in ["completion", "acompletion"]: + return SpanType.LLM + elif call_type == "embeddings": + return SpanType.EMBEDDING + else: + return SpanType.LLM + + def _start_span_or_trace(self, kwargs, start_time): + """ + Start an MLflow span or a trace. + + If there is an active span, we start a new span as a child of + that span. Otherwise, we start a new trace. 
+ """ + import mlflow + + call_type = kwargs.get("call_type", "completion") + span_name = f"litellm-{call_type}" + span_type = self._get_span_type(call_type) + start_time_ns = int(start_time.timestamp() * 1e9) + + inputs = self._construct_input(kwargs) + attributes = self._extract_attributes(kwargs) + + if active_span := mlflow.get_current_active_span(): # type: ignore + return self._client.start_span( + name=span_name, + request_id=active_span.request_id, + parent_id=active_span.span_id, + span_type=span_type, + inputs=inputs, + attributes=attributes, + start_time_ns=start_time_ns, + ) + else: + return self._client.start_trace( + name=span_name, + span_type=span_type, + inputs=inputs, + attributes=attributes, + start_time_ns=start_time_ns, + ) + + def _end_span_or_trace(self, span, outputs, end_time_ns, status): + """End an MLflow span or a trace.""" + if span.parent_id is None: + self._client.end_trace( + request_id=span.request_id, + outputs=outputs, + status=status, + end_time_ns=end_time_ns, + ) + else: + self._client.end_span( + request_id=span.request_id, + span_id=span.span_id, + outputs=outputs, + status=status, + end_time_ns=end_time_ns, + ) diff --git a/litellm/litellm_core_utils/get_supported_openai_params.py b/litellm/litellm_core_utils/get_supported_openai_params.py index bb94d54d5..05b4b9c48 100644 --- a/litellm/litellm_core_utils/get_supported_openai_params.py +++ b/litellm/litellm_core_utils/get_supported_openai_params.py @@ -161,17 +161,7 @@ def get_supported_openai_params( # noqa: PLR0915 elif custom_llm_provider == "huggingface": return litellm.HuggingfaceConfig().get_supported_openai_params() elif custom_llm_provider == "together_ai": - return [ - "stream", - "temperature", - "max_tokens", - "top_p", - "stop", - "frequency_penalty", - "tools", - "tool_choice", - "response_format", - ] + return litellm.TogetherAIConfig().get_supported_openai_params(model=model) elif custom_llm_provider == "ai21": return [ "stream", diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 15f7f59fa..69d6adca4 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -28,6 +28,7 @@ from litellm.caching.caching_handler import LLMCachingHandler from litellm.cost_calculator import _select_model_name_for_cost_calc from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_logger import CustomLogger +from litellm.integrations.mlflow import MlflowLogger from litellm.litellm_core_utils.redact_messages import ( redact_message_input_output_from_custom_logger, redact_message_input_output_from_logging, @@ -563,6 +564,7 @@ class Logging: message=f"Model Call Details pre-call: {details_to_log}", level="info", ) + elif isinstance(callback, CustomLogger): # custom logger class callback.log_pre_api_call( model=self.model, @@ -1258,6 +1260,7 @@ class Logging: end_time=end_time, print_verbose=print_verbose, ) + if ( callback == "openmeter" and self.model_call_details.get("litellm_params", {}).get( @@ -2347,6 +2350,14 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(_otel_logger) return _otel_logger # type: ignore + elif logging_integration == "mlflow": + for callback in _in_memory_loggers: + if isinstance(callback, MlflowLogger): + return callback # type: ignore + + _mlflow_logger = MlflowLogger() + _in_memory_loggers.append(_mlflow_logger) + return _mlflow_logger # type: ignore def get_custom_logger_compatible_class( 
logging_integration: litellm._custom_logger_compatible_callbacks_literal, @@ -2448,6 +2459,12 @@ def get_custom_logger_compatible_class( and callback.callback_name == "langtrace" ): return callback + + elif logging_integration == "mlflow": + for callback in _in_memory_loggers: + if isinstance(callback, MlflowLogger): + return callback + return None diff --git a/litellm/llms/together_ai/chat.py b/litellm/llms/together_ai/chat.py index 398bc489c..cb12d6147 100644 --- a/litellm/llms/together_ai/chat.py +++ b/litellm/llms/together_ai/chat.py @@ -6,8 +6,8 @@ Calls done in OpenAI/openai.py as TogetherAI is openai-compatible. Docs: https://docs.together.ai/reference/completions-1 """ -from ..OpenAI.openai import OpenAIConfig +from ..OpenAI.chat.gpt_transformation import OpenAIGPTConfig -class TogetherAIConfig(OpenAIConfig): +class TogetherAIConfig(OpenAIGPTConfig): pass diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 34ac51481..92ca32e52 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1069,7 +1069,7 @@ async def update_cache( # noqa: PLR0915 end_user_id: Optional[str], team_id: Optional[str], response_cost: Optional[float], - parent_otel_span: Optional[Span], + parent_otel_span: Optional[Span], # type: ignore ): """ Use this to update the cache with new user spend. @@ -5657,6 +5657,13 @@ async def anthropic_response( # noqa: PLR0915 request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): + """ + This is a BETA endpoint that calls 100+ LLMs in the anthropic format. + + To do a simple pass-through for anthropic, do `{PROXY_BASE_URL}/anthropic/v1/messages` + + Docs - https://docs.litellm.ai/docs/anthropic_completion + """ from litellm import adapter_completion from litellm.adapters.anthropic_adapter import anthropic_adapter diff --git a/litellm/proxy/spend_tracking/spend_management_endpoints.py b/litellm/proxy/spend_tracking/spend_management_endpoints.py index f6d36daaf..e0fa1e092 100644 --- a/litellm/proxy/spend_tracking/spend_management_endpoints.py +++ b/litellm/proxy/spend_tracking/spend_management_endpoints.py @@ -9,6 +9,9 @@ import litellm from litellm._logging import verbose_proxy_logger from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.spend_tracking.spend_tracking_utils import ( + get_spend_by_team_and_customer, +) router = APIRouter() @@ -932,6 +935,14 @@ async def get_global_spend_report( default=None, description="View spend for a specific internal_user_id. Example internal_user_id='1234", ), + team_id: Optional[str] = fastapi.Query( + default=None, + description="View spend for a specific team_id. Example team_id='1234", + ), + customer_id: Optional[str] = fastapi.Query( + default=None, + description="View spend for a specific customer_id. Example customer_id='1234. Can be used in conjunction with team_id as well.", + ), ): """ Get Daily Spend per Team, based on specific startTime and endTime. 
Per team, view usage by each key, model @@ -1074,8 +1085,12 @@ async def get_global_spend_report( return [] return db_response - + elif team_id is not None and customer_id is not None: + return await get_spend_by_team_and_customer( + start_date_obj, end_date_obj, team_id, customer_id, prisma_client + ) if group_by == "team": + # first get data from spend logs -> SpendByModelApiKey # then read data from "SpendByModelApiKey" to format the response obj sql_query = """ @@ -1305,7 +1320,6 @@ async def global_get_all_tag_names(): "/global/spend/tags", tags=["Budget & Spend Tracking"], dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, responses={ 200: {"model": List[LiteLLM_SpendLogs]}, }, diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py index 30e3ae5cd..48924d521 100644 --- a/litellm/proxy/spend_tracking/spend_tracking_utils.py +++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py @@ -1,7 +1,9 @@ +import datetime import json import os import secrets import traceback +from datetime import datetime as dt from typing import Optional from pydantic import BaseModel @@ -9,7 +11,7 @@ from pydantic import BaseModel import litellm from litellm._logging import verbose_proxy_logger from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload -from litellm.proxy.utils import hash_token +from litellm.proxy.utils import PrismaClient, hash_token def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool: @@ -163,3 +165,79 @@ def get_logging_payload( "Error creating spendlogs object - {}".format(str(e)) ) raise e + + +async def get_spend_by_team_and_customer( + start_date: dt, + end_date: dt, + team_id: str, + customer_id: str, + prisma_client: PrismaClient, +): + sql_query = """ + WITH SpendByModelApiKey AS ( + SELECT + date_trunc('day', sl."startTime") AS group_by_day, + COALESCE(tt.team_alias, 'Unassigned Team') AS team_name, + sl.end_user AS customer, + sl.model, + sl.api_key, + SUM(sl.spend) AS model_api_spend, + SUM(sl.total_tokens) AS model_api_tokens + FROM + "LiteLLM_SpendLogs" sl + LEFT JOIN + "LiteLLM_TeamTable" tt + ON + sl.team_id = tt.team_id + WHERE + sl."startTime" BETWEEN $1::date AND $2::date + AND sl.team_id = $3 + AND sl.end_user = $4 + GROUP BY + date_trunc('day', sl."startTime"), + tt.team_alias, + sl.end_user, + sl.model, + sl.api_key + ) + SELECT + group_by_day, + jsonb_agg(jsonb_build_object( + 'team_name', team_name, + 'customer', customer, + 'total_spend', total_spend, + 'metadata', metadata + )) AS teams_customers + FROM ( + SELECT + group_by_day, + team_name, + customer, + SUM(model_api_spend) AS total_spend, + jsonb_agg(jsonb_build_object( + 'model', model, + 'api_key', api_key, + 'spend', model_api_spend, + 'total_tokens', model_api_tokens + )) AS metadata + FROM + SpendByModelApiKey + GROUP BY + group_by_day, + team_name, + customer + ) AS aggregated + GROUP BY + group_by_day + ORDER BY + group_by_day; + """ + + db_response = await prisma_client.db.query_raw( + sql_query, start_date, end_date, team_id, customer_id + ) + if db_response is None: + return [] + + return db_response diff --git a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py index 667a21a3c..c4a64fa21 100644 --- a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py @@ -155,6 +155,51 @@ async def cohere_proxy_route( return received_value 
+@router.api_route( + "/anthropic/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"] +) +async def anthropic_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + base_target_url = "https://api.anthropic.com" + encoded_endpoint = httpx.URL(endpoint).path + + # Ensure endpoint starts with '/' for proper URL construction + if not encoded_endpoint.startswith("/"): + encoded_endpoint = "/" + encoded_endpoint + + # Construct the full target URL using httpx + base_url = httpx.URL(base_target_url) + updated_url = base_url.copy_with(path=encoded_endpoint) + + # Add or update query parameters + anthropic_api_key = litellm.utils.get_secret(secret_name="ANTHROPIC_API_KEY") + + ## check for streaming + is_streaming_request = False + if "stream" in str(updated_url): + is_streaming_request = True + + ## CREATE PASS-THROUGH + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=str(updated_url), + custom_headers={"x-api-key": "{}".format(anthropic_api_key)}, + _forward_headers=True, + ) # dynamically construct pass-through endpoint based on incoming path + received_value = await endpoint_func( + request, + fastapi_response, + user_api_key_dict, + stream=is_streaming_request, # type: ignore + ) + + return received_value + + @router.api_route("/bedrock/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def bedrock_proxy_route( endpoint: str, diff --git a/litellm/tests/test_mlflow.py b/litellm/tests/test_mlflow.py new file mode 100644 index 000000000..ec23875ea --- /dev/null +++ b/litellm/tests/test_mlflow.py @@ -0,0 +1,29 @@ +import pytest + +import litellm + + +def test_mlflow_logging(): + litellm.success_callback = ["mlflow"] + litellm.failure_callback = ["mlflow"] + + litellm.completion( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "what llm are u"}], + max_tokens=10, + temperature=0.2, + user="test-user", + ) + +@pytest.mark.asyncio() +async def test_async_mlflow_logging(): + litellm.success_callback = ["mlflow"] + litellm.failure_callback = ["mlflow"] + + await litellm.acompletion( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "hi test from local arize"}], + mock_response="hello", + temperature=0.1, + user="OTEL_USER", + ) diff --git a/litellm/utils.py b/litellm/utils.py index fdb533e4e..f4f31e6cf 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2903,24 +2903,16 @@ def get_optional_params( # noqa: PLR0915 ) _check_valid_arg(supported_params=supported_params) - if stream: - optional_params["stream"] = stream - if temperature is not None: - optional_params["temperature"] = temperature - if top_p is not None: - optional_params["top_p"] = top_p - if max_tokens is not None: - optional_params["max_tokens"] = max_tokens - if frequency_penalty is not None: - optional_params["frequency_penalty"] = frequency_penalty - if stop is not None: - optional_params["stop"] = stop - if tools is not None: - optional_params["tools"] = tools - if tool_choice is not None: - optional_params["tool_choice"] = tool_choice - if response_format is not None: - optional_params["response_format"] = response_format + optional_params = litellm.TogetherAIConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), + ) elif custom_llm_provider == "ai21": ## check if unsupported param passed 
in supported_params = get_supported_openai_params( diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py index bea066865..c9527c830 100644 --- a/tests/llm_translation/test_optional_params.py +++ b/tests/llm_translation/test_optional_params.py @@ -923,6 +923,14 @@ def test_watsonx_text_top_k(): assert optional_params["top_k"] == 10 + +def test_together_ai_model_params(): + optional_params = get_optional_params( + model="together_ai", custom_llm_provider="together_ai", logprobs=1 + ) + print(optional_params) + assert optional_params["logprobs"] == 1 + def test_forward_user_param(): from litellm.utils import get_supported_openai_params, get_optional_params diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index 881185b74..3ce4cb7d7 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -406,8 +406,13 @@ def test_completion_claude_3_empty_response(): "content": "I was hoping we could chat a bit", }, ] - response = litellm.completion(model="claude-3-opus-20240229", messages=messages) - print(response) + try: + response = litellm.completion(model="claude-3-opus-20240229", messages=messages) + print(response) + except litellm.InternalServerError as e: + pytest.skip(f"InternalServerError - {str(e)}") + except Exception as e: + pytest.fail(f"Error occurred: {e}") def test_completion_claude_3(): @@ -434,6 +439,8 @@ def test_completion_claude_3(): ) # Add any assertions, here to check response args print(response) + except litellm.InternalServerError as e: + pytest.skip(f"InternalServerError - {str(e)}") except Exception as e: pytest.fail(f"Error occurred: {e}") @@ -917,6 +924,9 @@ def test_completion_base64(model): except litellm.ServiceUnavailableError as e: print("got service unavailable error: ", e) pass + except litellm.InternalServerError as e: + print("got internal server error: ", e) + pass except Exception as e: if "500 Internal error encountered.'" in str(e): pass @@ -1055,7 +1065,6 @@ def test_completion_mistral_api(): cost = litellm.completion_cost(completion_response=response) print("cost to make mistral completion=", cost) assert cost > 0.0 - assert response.model == "mistral/mistral-tiny" except Exception as e: pytest.fail(f"Error occurred: {e}") diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py index 209b38423..0bc6953f9 100644 --- a/tests/local_testing/test_streaming.py +++ b/tests/local_testing/test_streaming.py @@ -3333,8 +3333,8 @@ async def test_acompletion_function_call_with_streaming(model): validate_final_streaming_function_calling_chunk(chunk=chunk) idx += 1 # raise Exception("it worked! 
") - except litellm.InternalServerError: - pass + except litellm.InternalServerError as e: + pytest.skip(f"InternalServerError - {str(e)}") except litellm.ServiceUnavailableError: pass except Exception as e: diff --git a/tests/logging_callback_tests/test_otel_logging.py b/tests/logging_callback_tests/test_otel_logging.py index ffc58416d..ecfc305f9 100644 --- a/tests/logging_callback_tests/test_otel_logging.py +++ b/tests/logging_callback_tests/test_otel_logging.py @@ -144,6 +144,7 @@ def validate_raw_gen_ai_request_openai_streaming(span): "model", ["anthropic/claude-3-opus-20240229"], ) +@pytest.mark.flaky(retries=6, delay=2) def test_completion_claude_3_function_call_with_otel(model): litellm.set_verbose = True diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py index ebc7dd33c..38883fa38 100644 --- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py +++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py @@ -31,6 +31,7 @@ from litellm.integrations.datadog.datadog_llm_obs import DataDogLLMObsLogger from litellm.integrations.gcs_bucket.gcs_bucket import GCSBucketLogger from litellm.integrations.opik.opik import OpikLogger from litellm.integrations.opentelemetry import OpenTelemetry +from litellm.integrations.mlflow import MlflowLogger from litellm.integrations.argilla import ArgillaLogger from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler from unittest.mock import patch @@ -59,6 +60,7 @@ callback_class_str_to_classType = { "logfire": OpenTelemetry, "arize": OpenTelemetry, "langtrace": OpenTelemetry, + "mlflow": MlflowLogger, } expected_env_vars = {