(fix) - proxy reliability, ensure duplicate callbacks are not added to proxy (#8067)

* refactor _add_callbacks_from_db_config

* fix check for _custom_logger_exists_in_litellm_callbacks

* move loc of test utils

* run ci/cd again

* test_add_custom_logger_callback_to_specific_event_with_duplicates_callbacks

* fix _custom_logger_class_exists_in_success_callbacks

* unit testing for test_add_callbacks_from_db_config

* test_custom_logger_exists_in_callbacks_individual_functions

* fix config.yml

* fix test test_stream_chunk_builder_openai_audio_output_usage - use direct dict comparison
This commit is contained in:
Ishaan Jaff 2025-01-28 21:01:56 -08:00 committed by GitHub
parent a9e6c09776
commit 05111a1f1c
5 changed files with 398 additions and 15 deletions

View file

@@ -691,6 +691,7 @@ jobs:
pip install "pytest-cov==5.0.0" pip install "pytest-cov==5.0.0"
pip install "google-generativeai==0.3.2" pip install "google-generativeai==0.3.2"
pip install "google-cloud-aiplatform==1.43.0" pip install "google-cloud-aiplatform==1.43.0"
pip install numpydoc
# Run pytest and generate JUnit XML report # Run pytest and generate JUnit XML report
- run: - run:
name: Run tests name: Run tests

View file

@@ -280,7 +280,7 @@ from litellm.types.router import RouterGeneralSettings, updateDeployment
from litellm.types.utils import CustomHuggingfaceTokenizer from litellm.types.utils import CustomHuggingfaceTokenizer
from litellm.types.utils import ModelInfo as ModelMapInfo from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.utils import StandardLoggingPayload from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking from litellm.utils import _add_custom_logger_callback_to_specific_event
try: try:
from litellm._version import version from litellm._version import version
@@ -2401,13 +2401,12 @@ class ProxyConfig:
added_models += 1 added_models += 1
return added_models return added_models
async def _update_llm_router( # noqa: PLR0915 async def _update_llm_router(
self, self,
new_models: list, new_models: list,
proxy_logging_obj: ProxyLogging, proxy_logging_obj: ProxyLogging,
): ):
global llm_router, llm_model_list, master_key, general_settings global llm_router, llm_model_list, master_key, general_settings
import base64
try: try:
if llm_router is None and master_key is not None: if llm_router is None and master_key is not None:
@@ -2463,21 +2462,60 @@ class ProxyConfig:
# check if user set any callbacks in Config Table # check if user set any callbacks in Config Table
config_data = await proxy_config.get_config() config_data = await proxy_config.get_config()
self._add_callbacks_from_db_config(config_data)
# we need to set env variables too
self._add_environment_variables_from_db_config(config_data)
# router settings
await self._add_router_settings_from_db_config(
config_data=config_data, llm_router=llm_router, prisma_client=prisma_client
)
# general settings
self._add_general_settings_from_db_config(
config_data=config_data,
general_settings=general_settings,
proxy_logging_obj=proxy_logging_obj,
)
def _add_callbacks_from_db_config(self, config_data: dict) -> None:
"""
Adds callbacks from DB config to litellm
"""
litellm_settings = config_data.get("litellm_settings", {}) or {} litellm_settings = config_data.get("litellm_settings", {}) or {}
success_callbacks = litellm_settings.get("success_callback", None) success_callbacks = litellm_settings.get("success_callback", None)
failure_callbacks = litellm_settings.get("failure_callback", None) failure_callbacks = litellm_settings.get("failure_callback", None)
if success_callbacks is not None and isinstance(success_callbacks, list): if success_callbacks is not None and isinstance(success_callbacks, list):
for success_callback in success_callbacks: for success_callback in success_callbacks:
if success_callback not in litellm.success_callback: if (
success_callback
in litellm._known_custom_logger_compatible_callbacks
):
_add_custom_logger_callback_to_specific_event(
success_callback, "success"
)
elif success_callback not in litellm.success_callback:
litellm.success_callback.append(success_callback) litellm.success_callback.append(success_callback)
# Add failure callbacks from DB to litellm # Add failure callbacks from DB to litellm
if failure_callbacks is not None and isinstance(failure_callbacks, list): if failure_callbacks is not None and isinstance(failure_callbacks, list):
for failure_callback in failure_callbacks: for failure_callback in failure_callbacks:
if failure_callback not in litellm.failure_callback: if (
failure_callback
in litellm._known_custom_logger_compatible_callbacks
):
_add_custom_logger_callback_to_specific_event(
failure_callback, "failure"
)
elif failure_callback not in litellm.failure_callback:
litellm.failure_callback.append(failure_callback) litellm.failure_callback.append(failure_callback)
# we need to set env variables too
def _add_environment_variables_from_db_config(self, config_data: dict) -> None:
"""
Adds environment variables from DB config to litellm
"""
environment_variables = config_data.get("environment_variables", {}) environment_variables = config_data.get("environment_variables", {})
for k, v in environment_variables.items(): for k, v in environment_variables.items():
try: try:
@@ -2489,7 +2527,15 @@ class ProxyConfig:
"Error setting env variable: %s - %s", k, str(e) "Error setting env variable: %s - %s", k, str(e)
) )
# router settings async def _add_router_settings_from_db_config(
self,
config_data: dict,
llm_router: Optional[Router],
prisma_client: Optional[PrismaClient],
) -> None:
"""
Adds router settings from DB config to litellm proxy
"""
if llm_router is not None and prisma_client is not None: if llm_router is not None and prisma_client is not None:
db_router_settings = await prisma_client.db.litellm_config.find_first( db_router_settings = await prisma_client.db.litellm_config.find_first(
where={"param_name": "router_settings"} where={"param_name": "router_settings"}
@@ -2501,7 +2547,17 @@ class ProxyConfig:
_router_settings = db_router_settings.param_value _router_settings = db_router_settings.param_value
llm_router.update_settings(**_router_settings) llm_router.update_settings(**_router_settings)
## ALERTING ## [TODO] move this to the _update_general_settings() block def _add_general_settings_from_db_config(
self, config_data: dict, general_settings: dict, proxy_logging_obj: ProxyLogging
) -> None:
"""
Adds general settings from DB config to litellm proxy
Args:
config_data: dict
general_settings: dict - global general_settings currently in use
proxy_logging_obj: ProxyLogging
"""
_general_settings = config_data.get("general_settings", {}) _general_settings = config_data.get("general_settings", {})
if "alerting" in _general_settings: if "alerting" in _general_settings:
if ( if (

View file

@@ -346,11 +346,12 @@ def _add_custom_logger_callback_to_specific_event(
llm_router=None, llm_router=None,
) )
# don't double add a callback if callback_class:
if callback_class is not None and not any( if (
isinstance(cb, type(callback_class)) for cb in litellm.callbacks # type: ignore logging_event == "success"
): and _custom_logger_class_exists_in_success_callbacks(callback_class)
if logging_event == "success": is False
):
litellm.success_callback.append(callback_class) litellm.success_callback.append(callback_class)
litellm._async_success_callback.append(callback_class) litellm._async_success_callback.append(callback_class)
if callback in litellm.success_callback: if callback in litellm.success_callback:
@@ -361,7 +362,11 @@ def _add_custom_logger_callback_to_specific_event(
litellm._async_success_callback.remove( litellm._async_success_callback.remove(
callback callback
) # remove the string from the callback list ) # remove the string from the callback list
elif logging_event == "failure": elif (
logging_event == "failure"
and _custom_logger_class_exists_in_failure_callbacks(callback_class)
is False
):
litellm.failure_callback.append(callback_class) litellm.failure_callback.append(callback_class)
litellm._async_failure_callback.append(callback_class) litellm._async_failure_callback.append(callback_class)
if callback in litellm.failure_callback: if callback in litellm.failure_callback:
@@ -374,6 +379,38 @@ def _add_custom_logger_callback_to_specific_event(
) # remove the string from the callback list ) # remove the string from the callback list
def _custom_logger_class_exists_in_success_callbacks(
callback_class: CustomLogger,
) -> bool:
"""
Returns True if an instance of the custom logger exists in litellm.success_callback or litellm._async_success_callback
e.g if `LangfusePromptManagement` is passed in, it will return True if an instance of `LangfusePromptManagement` exists in litellm.success_callback or litellm._async_success_callback
Prevents double adding a custom logger callback to the litellm callbacks
"""
return any(
isinstance(cb, type(callback_class))
for cb in litellm.success_callback + litellm._async_success_callback
)
def _custom_logger_class_exists_in_failure_callbacks(
callback_class: CustomLogger,
) -> bool:
"""
Returns True if an instance of the custom logger exists in litellm.failure_callback or litellm._async_failure_callback
e.g if `LangfusePromptManagement` is passed in, it will return True if an instance of `LangfusePromptManagement` exists in litellm.failure_callback or litellm._async_failure_callback
Prevents double adding a custom logger callback to the litellm callbacks
"""
return any(
isinstance(cb, type(callback_class))
for cb in litellm.failure_callback + litellm._async_failure_callback
)
def function_setup( # noqa: PLR0915 def function_setup( # noqa: PLR0915
original_function: str, rules_obj, start_time, *args, **kwargs original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc. ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.

View file

@@ -13,7 +13,7 @@ import os
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system-path
import pytest import pytest
import litellm import litellm
@@ -1529,6 +1529,215 @@ def test_add_custom_logger_callback_to_specific_event_e2e(monkeypatch):
assert len(litellm.failure_callback) == curr_len_failure_callback assert len(litellm.failure_callback) == curr_len_failure_callback
def test_custom_logger_exists_in_callbacks_individual_functions(monkeypatch):
"""
Test _custom_logger_class_exists_in_success_callbacks and _custom_logger_class_exists_in_failure_callbacks helper functions
Tests if logger is found in different callback lists
"""
from litellm.integrations.custom_logger import CustomLogger
from litellm.utils import (
_custom_logger_class_exists_in_failure_callbacks,
_custom_logger_class_exists_in_success_callbacks,
)
# Create a mock CustomLogger class
class MockCustomLogger(CustomLogger):
def log_success_event(self, kwargs, response_obj, start_time, end_time):
pass
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass
# Reset all callback lists
for list_name in [
"callbacks",
"_async_success_callback",
"_async_failure_callback",
"success_callback",
"failure_callback",
]:
monkeypatch.setattr(litellm, list_name, [])
mock_logger = MockCustomLogger()
# Test 1: No logger exists in any callback list
assert _custom_logger_class_exists_in_success_callbacks(mock_logger) == False
assert _custom_logger_class_exists_in_failure_callbacks(mock_logger) == False
# Test 2: Logger exists in success_callback
litellm.success_callback.append(mock_logger)
assert _custom_logger_class_exists_in_success_callbacks(mock_logger) == True
assert _custom_logger_class_exists_in_failure_callbacks(mock_logger) == False
# Reset callbacks
litellm.success_callback = []
# Test 3: Logger exists in _async_success_callback
litellm._async_success_callback.append(mock_logger)
assert _custom_logger_class_exists_in_success_callbacks(mock_logger) == True
assert _custom_logger_class_exists_in_failure_callbacks(mock_logger) == False
# Reset callbacks
litellm._async_success_callback = []
# Test 4: Logger exists in failure_callback
litellm.failure_callback.append(mock_logger)
assert _custom_logger_class_exists_in_success_callbacks(mock_logger) == False
assert _custom_logger_class_exists_in_failure_callbacks(mock_logger) == True
# Reset callbacks
litellm.failure_callback = []
# Test 5: Logger exists in _async_failure_callback
litellm._async_failure_callback.append(mock_logger)
assert _custom_logger_class_exists_in_success_callbacks(mock_logger) == False
assert _custom_logger_class_exists_in_failure_callbacks(mock_logger) == True
# Test 6: Logger exists in both success and failure callbacks
litellm.success_callback.append(mock_logger)
litellm.failure_callback.append(mock_logger)
assert _custom_logger_class_exists_in_success_callbacks(mock_logger) == True
assert _custom_logger_class_exists_in_failure_callbacks(mock_logger) == True
# Test 7: Different instance of same logger class
mock_logger_2 = MockCustomLogger()
assert _custom_logger_class_exists_in_success_callbacks(mock_logger_2) == True
assert _custom_logger_class_exists_in_failure_callbacks(mock_logger_2) == True
@pytest.mark.asyncio
async def test_add_custom_logger_callback_to_specific_event_with_duplicates(
monkeypatch,
):
"""
Test that when a callback exists in both success_callback and _async_success_callback,
it's not added again
"""
from litellm.integrations.langfuse.langfuse_prompt_management import (
LangfusePromptManagement,
)
# Reset all callback lists
monkeypatch.setattr(litellm, "callbacks", [])
monkeypatch.setattr(litellm, "_async_success_callback", [])
monkeypatch.setattr(litellm, "_async_failure_callback", [])
monkeypatch.setattr(litellm, "success_callback", [])
monkeypatch.setattr(litellm, "failure_callback", [])
# Add logger to both success_callback and _async_success_callback
langfuse_logger = LangfusePromptManagement()
litellm.success_callback.append(langfuse_logger)
litellm._async_success_callback.append(langfuse_logger)
# Get initial lengths
initial_success_callback_len = len(litellm.success_callback)
initial_async_success_callback_len = len(litellm._async_success_callback)
# Make a completion call
await litellm.acompletion(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "Hello, world!"}],
mock_response="Testing duplicate callbacks",
)
# Assert no new callbacks were added
assert len(litellm.success_callback) == initial_success_callback_len
assert len(litellm._async_success_callback) == initial_async_success_callback_len
@pytest.mark.asyncio
async def test_add_custom_logger_callback_to_specific_event_with_duplicates_success_callback(
monkeypatch,
):
"""
Test that when a callback exists in both success_callback and _async_success_callback,
it's not added again
"""
from litellm.integrations.langfuse.langfuse_prompt_management import (
LangfusePromptManagement,
)
# Reset all callback lists
monkeypatch.setattr(litellm, "callbacks", [])
monkeypatch.setattr(litellm, "_async_success_callback", [])
monkeypatch.setattr(litellm, "_async_failure_callback", [])
monkeypatch.setattr(litellm, "success_callback", [])
monkeypatch.setattr(litellm, "failure_callback", [])
# Add logger to both success_callback and _async_success_callback
langfuse_logger = LangfusePromptManagement()
litellm.success_callback.append(langfuse_logger)
# Get initial lengths
initial_success_callback_len = len(litellm.success_callback)
initial_async_success_callback_len = len(litellm._async_success_callback)
# Make a completion call
await litellm.acompletion(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "Hello, world!"}],
mock_response="Testing duplicate callbacks",
)
# Assert no new callbacks were added
assert len(litellm.success_callback) == initial_success_callback_len
assert len(litellm._async_success_callback) == initial_async_success_callback_len
@pytest.mark.asyncio
async def test_add_custom_logger_callback_to_specific_event_with_duplicates_callbacks(
monkeypatch,
):
"""
Test that when a callback exists in both success_callback and _async_success_callback,
it's not added again
"""
from litellm.integrations.langfuse.langfuse_prompt_management import (
LangfusePromptManagement,
)
# Reset all callback lists
monkeypatch.setattr(litellm, "callbacks", [])
monkeypatch.setattr(litellm, "_async_success_callback", [])
monkeypatch.setattr(litellm, "_async_failure_callback", [])
monkeypatch.setattr(litellm, "success_callback", [])
monkeypatch.setattr(litellm, "failure_callback", [])
# Add logger to both success_callback and _async_success_callback
langfuse_logger = LangfusePromptManagement()
litellm.callbacks.append(langfuse_logger)
# Make a completion call
await litellm.acompletion(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "Hello, world!"}],
mock_response="Testing duplicate callbacks",
)
# Assert no new callbacks were added
initial_callbacks_len = len(litellm.callbacks)
initial_async_success_callback_len = len(litellm._async_success_callback)
initial_success_callback_len = len(litellm.success_callback)
print(
f"Num callbacks before: litellm.callbacks: {len(litellm.callbacks)}, litellm._async_success_callback: {len(litellm._async_success_callback)}, litellm.success_callback: {len(litellm.success_callback)}"
)
for _ in range(10):
await litellm.acompletion(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "Hello, world!"}],
mock_response="Testing duplicate callbacks",
)
assert len(litellm.callbacks) == initial_callbacks_len
assert len(litellm._async_success_callback) == initial_async_success_callback_len
assert len(litellm.success_callback) == initial_success_callback_len
print(
f"Num callbacks after 10 mock calls: litellm.callbacks: {len(litellm.callbacks)}, litellm._async_success_callback: {len(litellm._async_success_callback)}, litellm.success_callback: {len(litellm.success_callback)}"
)
def test_add_custom_logger_callback_to_specific_event_e2e_failure(monkeypatch): def test_add_custom_logger_callback_to_specific_event_e2e_failure(monkeypatch):
from litellm.integrations.openmeter import OpenMeterLogger from litellm.integrations.openmeter import OpenMeterLogger

View file

@@ -184,3 +184,83 @@ async def test_multiple_includes():
# Verify original config settings remain # Verify original config settings remain
assert config["litellm_settings"]["callbacks"] == ["prometheus"] assert config["litellm_settings"]["callbacks"] == ["prometheus"]
def test_add_callbacks_from_db_config():
"""Test that callbacks are added correctly and duplicates are prevented"""
# Setup
from litellm.integrations.langfuse.langfuse_prompt_management import (
LangfusePromptManagement,
)
proxy_config = ProxyConfig()
# Reset litellm callbacks before test
litellm.success_callback = []
litellm.failure_callback = []
# Test Case 1: Add new callbacks
config_data = {
"litellm_settings": {
"success_callback": ["langfuse", "custom_callback_api"],
"failure_callback": ["langfuse"],
}
}
proxy_config._add_callbacks_from_db_config(config_data)
# 1 instance of LangfusePromptManagement should exist in litellm.success_callback
num_langfuse_instances = sum(
isinstance(callback, LangfusePromptManagement)
for callback in litellm.success_callback
)
assert num_langfuse_instances == 1
assert len(litellm.success_callback) == 2
assert len(litellm.failure_callback) == 1
# Test Case 2: Try adding duplicate callbacks
proxy_config._add_callbacks_from_db_config(config_data)
# Verify no duplicates were added
assert len(litellm.success_callback) == 2
assert len(litellm.failure_callback) == 1
# Cleanup
litellm.success_callback = []
litellm.failure_callback = []
litellm._known_custom_logger_compatible_callbacks = []
def test_add_callbacks_invalid_input():
"""Test handling of invalid input for callbacks"""
proxy_config = ProxyConfig()
# Reset callbacks
litellm.success_callback = []
litellm.failure_callback = []
# Test Case 1: Invalid callback format
config_data = {
"litellm_settings": {
"success_callback": "invalid_string_format", # Should be a list
"failure_callback": 123, # Should be a list
}
}
proxy_config._add_callbacks_from_db_config(config_data)
# Verify no callbacks were added with invalid input
assert len(litellm.success_callback) == 0
assert len(litellm.failure_callback) == 0
# Test Case 2: Missing litellm_settings
config_data = {}
proxy_config._add_callbacks_from_db_config(config_data)
# Verify no callbacks were added
assert len(litellm.success_callback) == 0
assert len(litellm.failure_callback) == 0
# Cleanup
litellm.success_callback = []
litellm.failure_callback = []