litellm-mirror/litellm/integrations/mlflow.py
Krish Dholakia 3beecfb0d4
LiteLLM Minor Fixes & Improvements (11/13/2024) (#6729)
2024-11-15 11:18:31 +05:30
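
"""MLflow Tracing integration for LiteLLM.

Logs each LiteLLM call as an MLflow trace/span via ``mlflow.tracking.MlflowClient``.
Non-streaming calls produce one span per request; streaming calls are aggregated
into a single span keyed by ``litellm_call_id``. Failures are recorded as spans
with ERROR status and the exception attached as a span event.
"""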


import json
import threading
from typing import Optional

from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger


class MlflowLogger(CustomLogger):
    def __init__(self):
        from mlflow.tracking import MlflowClient

        self._client = MlflowClient()

        self._stream_id_to_span = {}
        self._lock = threading.Lock()  # lock for _stream_id_to_span

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_success(kwargs, response_obj, start_time, end_time)

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_success(kwargs, response_obj, start_time, end_time)
    def _handle_success(self, kwargs, response_obj, start_time, end_time):
        """
        Log the success event as an MLflow span.
        Note that this method is called asynchronously in the background thread.
        """
        from mlflow.entities import SpanStatusCode

        try:
            verbose_logger.debug("MLflow logging start for success event")

            if kwargs.get("stream"):
                self._handle_stream_event(kwargs, response_obj, start_time, end_time)
            else:
                span = self._start_span_or_trace(kwargs, start_time)
                end_time_ns = int(end_time.timestamp() * 1e9)
                self._end_span_or_trace(
                    span=span,
                    outputs=response_obj,
                    status=SpanStatusCode.OK,
                    end_time_ns=end_time_ns,
                )
        except Exception:
            verbose_logger.debug("MLflow Logging Error", stack_info=True)
    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_failure(kwargs, response_obj, start_time, end_time)

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_failure(kwargs, response_obj, start_time, end_time)

    def _handle_failure(self, kwargs, response_obj, start_time, end_time):
        """
        Log the failure event as an MLflow span.
        Note that this method is called *synchronously* unlike the success handler.
        """
        from mlflow.entities import SpanEvent, SpanStatusCode

        try:
            span = self._start_span_or_trace(kwargs, start_time)
            end_time_ns = int(end_time.timestamp() * 1e9)

            # Record exception info as event
            if exception := kwargs.get("exception"):
                span.add_event(SpanEvent.from_exception(exception))

            self._end_span_or_trace(
                span=span,
                outputs=response_obj,
                status=SpanStatusCode.ERROR,
                end_time_ns=end_time_ns,
            )
        except Exception as e:
            verbose_logger.debug(f"MLflow Logging Error - {e}", stack_info=True)
    def _handle_stream_event(self, kwargs, response_obj, start_time, end_time):
        """
        Handle the success event for a streaming response. For streaming calls,
        the log_success_event handler is triggered for every chunk of the stream.
        We create a single span for the entire stream request as follows:
        1. For the first chunk, start a new span and store it in the map.
        2. For subsequent chunks, add the chunk as an event to the span.
        3. For the final chunk, end the span and remove it from the map.
        """
        from mlflow.entities import SpanStatusCode

        litellm_call_id = kwargs.get("litellm_call_id")

        if litellm_call_id not in self._stream_id_to_span:
            with self._lock:
                # Check again after acquiring the lock
                if litellm_call_id not in self._stream_id_to_span:
                    # Start a new span for the first chunk of the stream
                    span = self._start_span_or_trace(kwargs, start_time)
                    self._stream_id_to_span[litellm_call_id] = span

        # Add the chunk as an event to the span
        span = self._stream_id_to_span[litellm_call_id]
        self._add_chunk_events(span, response_obj)

        # If this is the final chunk, end the span. The final chunk carries
        # complete_streaming_response, which gathers the full response.
        if final_response := kwargs.get("complete_streaming_response"):
            end_time_ns = int(end_time.timestamp() * 1e9)
            self._end_span_or_trace(
                span=span,
                outputs=final_response,
                status=SpanStatusCode.OK,
                end_time_ns=end_time_ns,
            )

            # Remove the stream_id from the map
            with self._lock:
                self._stream_id_to_span.pop(litellm_call_id)
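
    # Streaming lifecycle sketch (illustrative call id): for litellm_call_id
    # "abc", the first chunk starts the span and records its delta as an event,
    # each subsequent chunk only appends a "streaming_chunk" event, and the
    # final chunk (which carries complete_streaming_response) ends the span with
    # the aggregated output and removes "abc" from _stream_id_to_span.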
    def _add_chunk_events(self, span, response_obj):
        from mlflow.entities import SpanEvent

        try:
            for choice in response_obj.choices:
                span.add_event(
                    SpanEvent(
                        name="streaming_chunk",
                        attributes={"delta": json.dumps(choice.delta.model_dump())},
                    )
                )
        except Exception:
            verbose_logger.debug("Error adding chunk events to span", stack_info=True)
    def _construct_input(self, kwargs):
        """Construct span inputs with optional parameters."""
        inputs = {"messages": kwargs.get("messages")}
        for key in ["functions", "tools", "stream", "tool_choice", "user"]:
            # pop() removes the key from optional_params; falsy values
            # (e.g. stream=False) are not recorded as inputs.
            if value := kwargs.get("optional_params", {}).pop(key, None):
                inputs[key] = value
        return inputs
    def _extract_attributes(self, kwargs):
        """
        Extract span attributes from kwargs.

        With the latest version of litellm, the standard_logging_object contains
        canonical information for logging. If it is not present, we extract a
        subset of attributes from other kwargs.
        """
        attributes = {
            "litellm_call_id": kwargs.get("litellm_call_id"),
            "call_type": kwargs.get("call_type"),
            "model": kwargs.get("model"),
        }

        standard_obj = kwargs.get("standard_logging_object")
        if standard_obj:
            attributes.update(
                {
                    "api_base": standard_obj.get("api_base"),
                    "cache_hit": standard_obj.get("cache_hit"),
                    "usage": {
                        "completion_tokens": standard_obj.get("completion_tokens"),
                        "prompt_tokens": standard_obj.get("prompt_tokens"),
                        "total_tokens": standard_obj.get("total_tokens"),
                    },
                    "raw_llm_response": standard_obj.get("response"),
                    "response_cost": standard_obj.get("response_cost"),
                    "saved_cache_cost": standard_obj.get("saved_cache_cost"),
                }
            )
        else:
            litellm_params = kwargs.get("litellm_params", {})
            attributes.update(
                {
                    "model": kwargs.get("model"),
                    "cache_hit": kwargs.get("cache_hit"),
                    "custom_llm_provider": kwargs.get("custom_llm_provider"),
                    "api_base": litellm_params.get("api_base"),
                    "response_cost": kwargs.get("response_cost"),
                }
            )
        return attributes
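
    # For reference, a successful non-streaming call with a standard_logging_object
    # yields attributes roughly like the sketch below (illustrative values only;
    # exact keys and values depend on the provider and request):
    #
    #   {
    #       "litellm_call_id": "abc-123",
    #       "call_type": "completion",
    #       "model": "gpt-4o",
    #       "api_base": "https://api.openai.com/v1",
    #       "cache_hit": False,
    #       "usage": {"completion_tokens": 12, "prompt_tokens": 9, "total_tokens": 21},
    #       "raw_llm_response": {...},
    #       "response_cost": 0.0002,
    #       "saved_cache_cost": 0.0,
    #   }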
    def _get_span_type(self, call_type: Optional[str]) -> str:
        from mlflow.entities import SpanType

        if call_type in ["completion", "acompletion"]:
            return SpanType.LLM
        elif call_type == "embeddings":
            return SpanType.EMBEDDING
        else:
            return SpanType.LLM
    def _start_span_or_trace(self, kwargs, start_time):
        """
        Start an MLflow span or a trace.

        If there is an active span, we start a new span as a child of
        that span. Otherwise, we start a new trace.
        """
        import mlflow

        call_type = kwargs.get("call_type", "completion")
        span_name = f"litellm-{call_type}"
        span_type = self._get_span_type(call_type)
        start_time_ns = int(start_time.timestamp() * 1e9)

        inputs = self._construct_input(kwargs)
        attributes = self._extract_attributes(kwargs)

        if active_span := mlflow.get_current_active_span():  # type: ignore
            return self._client.start_span(
                name=span_name,
                request_id=active_span.request_id,
                parent_id=active_span.span_id,
                span_type=span_type,
                inputs=inputs,
                attributes=attributes,
                start_time_ns=start_time_ns,
            )
        else:
            return self._client.start_trace(
                name=span_name,
                span_type=span_type,
                inputs=inputs,
                attributes=attributes,
                start_time_ns=start_time_ns,
            )
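
    # For example, if a completion call runs inside a function traced with
    # MLflow's fluent tracing APIs (e.g. a function decorated with @mlflow.trace),
    # get_current_active_span() returns that function's span and the litellm span
    # is attached as its child; otherwise a new standalone trace is started.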
    def _end_span_or_trace(self, span, outputs, end_time_ns, status):
        """End an MLflow span or a trace."""
        if span.parent_id is None:
            self._client.end_trace(
                request_id=span.request_id,
                outputs=outputs,
                status=status,
                end_time_ns=end_time_ns,
            )
        else:
            self._client.end_span(
                request_id=span.request_id,
                span_id=span.span_id,
                outputs=outputs,
                status=status,
                end_time_ns=end_time_ns,
            )
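

# A minimal usage sketch (assumptions: litellm accepts CustomLogger instances via
# `litellm.callbacks`, an MLflow tracking destination is configured, and provider
# credentials are set in the environment; the model name below is illustrative).
# It registers the logger and makes one completion call, which should appear as a
# single trace in the MLflow UI.
if __name__ == "__main__":
    import litellm
    import mlflow

    mlflow.set_experiment("litellm-tracing")  # where traces will be recorded
    litellm.callbacks = [MlflowLogger()]

    resp = litellm.completion(
        model="gpt-4o-mini",  # illustrative; any litellm-supported model works
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(resp.choices[0].message.content)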