refactor(litellm_logging.py): refactors how slack_alerting generates langfuse trace url

gets the url from logging object
This commit is contained in:
Krrish Dholakia 2024-06-21 16:12:25 -07:00
parent 174b345766
commit c7b06c42b7
5 changed files with 207 additions and 41 deletions

View file

@ -24,6 +24,7 @@ import litellm.types
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.proxy._types import AlertType, CallInfo, UserAPIKeyAuth, WebhookEvent
from litellm.types.router import LiteLLM_Params
@ -229,7 +230,7 @@ class SlackAlerting(CustomLogger):
"db_exceptions",
]
def _add_langfuse_trace_id_to_alert(
async def _add_langfuse_trace_id_to_alert(
self,
request_data: Optional[dict] = None,
) -> Optional[str]:
@ -242,21 +243,19 @@ class SlackAlerting(CustomLogger):
-> litellm_call_id
"""
# do nothing for now
if request_data is not None:
trace_id = None
if (
request_data.get("metadata", {}).get("existing_trace_id", None)
is not None
):
trace_id = request_data["metadata"]["existing_trace_id"]
elif request_data.get("metadata", {}).get("trace_id", None) is not None:
trace_id = request_data["metadata"]["trace_id"]
elif request_data.get("litellm_logging_obj", None) is not None and hasattr(
request_data["litellm_logging_obj"], "model_call_details"
):
trace_id = request_data["litellm_logging_obj"].model_call_details[
"litellm_call_id"
]
if (
request_data is not None
and request_data.get("litellm_logging_obj", None) is not None
):
trace_id: Optional[str] = None
litellm_logging_obj: Logging = request_data["litellm_logging_obj"]
for _ in range(3):
trace_id = litellm_logging_obj._get_trace_id(service_name="langfuse")
if trace_id is not None:
break
await asyncio.sleep(3) # wait 3s before retrying for trace id
if litellm.litellm_core_utils.litellm_logging.langFuseLogger is not None:
base_url = (
litellm.litellm_core_utils.litellm_logging.langFuseLogger.Langfuse.base_url
@ -645,7 +644,7 @@ class SlackAlerting(CustomLogger):
)
if "langfuse" in litellm.success_callback:
langfuse_url = self._add_langfuse_trace_id_to_alert(
langfuse_url = await self._add_langfuse_trace_id_to_alert(
request_data=request_data,
)

View file

@ -10,7 +10,7 @@ import sys
import time
import traceback
import uuid
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Literal, Optional
import litellm
from litellm import (
@ -19,7 +19,7 @@ from litellm import (
turn_off_message_logging,
verbose_logger,
)
from litellm.caching import S3Cache
from litellm.caching import InMemoryCache, S3Cache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_logging,
@ -111,7 +111,25 @@ additional_details: Optional[Dict[str, str]] = {}
local_cache: Optional[Dict[str, str]] = {}
last_fetched_at = None
last_fetched_at_keys = None
####
class ServiceTraceIDCache:
def __init__(self) -> None:
self.cache = InMemoryCache()
def get_cache(self, litellm_call_id: str, service_name: str) -> Optional[str]:
key_name = "{}:{}".format(service_name, litellm_call_id)
response = self.cache.get_cache(key=key_name)
return response
def set_cache(self, litellm_call_id: str, service_name: str, trace_id: str) -> None:
key_name = "{}:{}".format(service_name, litellm_call_id)
self.cache.set_cache(key=key_name, value=trace_id)
return None
in_memory_trace_id_cache = ServiceTraceIDCache()
class Logging:
@ -821,7 +839,7 @@ class Logging:
langfuse_secret=self.langfuse_secret,
langfuse_host=self.langfuse_host,
)
langFuseLogger.log_event(
_response = langFuseLogger.log_event(
kwargs=kwargs,
response_obj=result,
start_time=start_time,
@ -829,6 +847,14 @@ class Logging:
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
if _response is not None and isinstance(_response, dict):
_trace_id = _response.get("trace_id", None)
if _trace_id is not None:
in_memory_trace_id_cache.set_cache(
litellm_call_id=self.litellm_call_id,
service_name="langfuse",
trace_id=_trace_id,
)
if callback == "datadog":
global dataDogLogger
verbose_logger.debug("reaches datadog for success logging!")
@ -1607,7 +1633,7 @@ class Logging:
langfuse_secret=self.langfuse_secret,
langfuse_host=self.langfuse_host,
)
langFuseLogger.log_event(
_response = langFuseLogger.log_event(
start_time=start_time,
end_time=end_time,
response_obj=None,
@ -1617,6 +1643,14 @@ class Logging:
level="ERROR",
kwargs=self.model_call_details,
)
if _response is not None and isinstance(_response, dict):
_trace_id = _response.get("trace_id", None)
if _trace_id is not None:
in_memory_trace_id_cache.set_cache(
litellm_call_id=self.litellm_call_id,
service_name="langfuse",
trace_id=_trace_id,
)
if callback == "traceloop":
traceloopLogger.log_event(
start_time=start_time,
@ -1721,6 +1755,24 @@ class Logging:
)
)
def _get_trace_id(self, service_name: Literal["langfuse"]) -> Optional[str]:
"""
For the given service (e.g. langfuse), return the trace_id actually logged.
Used for constructing the url in slack alerting.
Returns:
- str: The logged trace id
- None: If trace id not yet emitted.
"""
trace_id: Optional[str] = None
if service_name == "langfuse":
trace_id = in_memory_trace_id_cache.get_cache(
litellm_call_id=self.litellm_call_id, service_name=service_name
)
return trace_id
def set_callbacks(callback_list, function_id=None):
"""

View file

@ -463,7 +463,7 @@ class ProxyLogging:
alerting_metadata = {}
if request_data is not None:
_url = self.slack_alerting_instance._add_langfuse_trace_id_to_alert(
_url = await self.slack_alerting_instance._add_langfuse_trace_id_to_alert(
request_data=request_data
)

View file

@ -1,33 +1,37 @@
# What is this?
## Tests slack alerting on proxy logging object
import sys, json, uuid, random, httpx
import asyncio
import io
import json
import os
import io, asyncio
import random
import sys
import time
import uuid
from datetime import datetime, timedelta
from typing import Optional
import httpx
# import logging
# logging.basicConfig(level=logging.DEBUG)
sys.path.insert(0, os.path.abspath("../.."))
from litellm.proxy.utils import ProxyLogging
from litellm.caching import DualCache, RedisCache
import litellm
import pytest
import asyncio
from unittest.mock import patch, MagicMock
from litellm.utils import get_api_base
from litellm.caching import DualCache
from litellm.integrations.slack_alerting import SlackAlerting, DeploymentMetrics
import unittest.mock
from unittest.mock import AsyncMock
import pytest
from litellm.router import AlertingConfig, Router
from litellm.proxy._types import CallInfo
from openai import APIError
from litellm.router import AlertingConfig
import litellm
import os
import unittest.mock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from openai import APIError
import litellm
from litellm.caching import DualCache, RedisCache
from litellm.integrations.slack_alerting import DeploymentMetrics, SlackAlerting
from litellm.proxy._types import CallInfo
from litellm.proxy.utils import ProxyLogging
from litellm.router import AlertingConfig, Router
from litellm.utils import get_api_base
@pytest.mark.parametrize(
@ -123,8 +127,8 @@ def test_init():
print("passed testing slack alerting init")
from unittest.mock import patch, AsyncMock
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, patch
@pytest.fixture
@ -805,3 +809,53 @@ async def test_alerting():
pass
finally:
await asyncio.sleep(3)
@pytest.mark.asyncio
async def test_langfuse_trace_id():
"""
- Unit test for `_add_langfuse_trace_id_to_alert` function in slack_alerting.py
"""
from litellm.litellm_core_utils.litellm_logging import Logging
litellm.success_callback = ["langfuse"]
litellm_logging_obj = Logging(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
stream=False,
call_type="acompletion",
litellm_call_id="1234",
start_time=datetime.now(),
function_id="1234",
)
litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey how's it going?"}],
mock_response="Hey!",
litellm_logging_obj=litellm_logging_obj,
)
await asyncio.sleep(3)
assert litellm_logging_obj._get_trace_id(service_name="langfuse") is not None
slack_alerting = SlackAlerting(
alerting_threshold=32,
alerting=["slack"],
alert_types=["llm_exceptions"],
internal_usage_cache=DualCache(),
)
trace_url = await slack_alerting._add_langfuse_trace_id_to_alert(
request_data={"litellm_logging_obj": litellm_logging_obj}
)
assert trace_url is not None
returned_trace_id = int(trace_url.split("/")[-1])
assert returned_trace_id == int(
litellm_logging_obj._get_trace_id(service_name="langfuse")
)

View file

@ -1,5 +1,6 @@
import copy
import sys
import time
from datetime import datetime
from unittest import mock
@ -548,3 +549,63 @@ def test_get_llm_provider_ft_models():
model, custom_llm_provider, _, _ = get_llm_provider(model="ft:gpt-4o-2024-05-13")
assert custom_llm_provider == "openai"
@pytest.mark.parametrize("langfuse_trace_id", [None, "my-unique-trace-id"])
@pytest.mark.parametrize(
"langfuse_existing_trace_id", [None, "my-unique-existing-trace-id"]
)
def test_logging_trace_id(langfuse_trace_id, langfuse_existing_trace_id):
"""
- Unit test for `_get_trace_id` function in Logging obj
"""
from litellm.litellm_core_utils.litellm_logging import Logging
litellm.success_callback = ["langfuse"]
litellm_call_id = "my-unique-call-id"
litellm_logging_obj = Logging(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
stream=False,
call_type="acompletion",
litellm_call_id=litellm_call_id,
start_time=datetime.now(),
function_id="1234",
)
metadata = {}
if langfuse_trace_id is not None:
metadata["trace_id"] = langfuse_trace_id
if langfuse_existing_trace_id is not None:
metadata["existing_trace_id"] = langfuse_existing_trace_id
litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey how's it going?"}],
mock_response="Hey!",
litellm_logging_obj=litellm_logging_obj,
metadata=metadata,
)
time.sleep(3)
assert litellm_logging_obj._get_trace_id(service_name="langfuse") is not None
## if existing_trace_id exists
if langfuse_existing_trace_id is not None:
assert (
litellm_logging_obj._get_trace_id(service_name="langfuse")
== langfuse_existing_trace_id
)
## if trace_id exists
elif langfuse_trace_id is not None:
assert (
litellm_logging_obj._get_trace_id(service_name="langfuse")
== langfuse_trace_id
)
## if existing_trace_id exists
else:
assert (
litellm_logging_obj._get_trace_id(service_name="langfuse")
== litellm_call_id
)