forked from phoenix/litellm-mirror
Litellm router disable fallbacks (#6743)
* bump: version 1.52.6 → 1.52.7 * feat(router.py): enable dynamically disabling fallbacks Allows for enabling/disabling fallbacks per key * feat(litellm_pre_call_utils.py): support setting 'disable_fallbacks' on litellm key * test: fix test
This commit is contained in:
parent
02b6f69004
commit
68d81f88f9
5 changed files with 122 additions and 34 deletions
|
@ -274,6 +274,51 @@ class LiteLLMProxyRequestSetup:
|
||||||
)
|
)
|
||||||
return user_api_key_logged_metadata
|
return user_api_key_logged_metadata
|
||||||
|
|
||||||
|
@staticmethod
def add_key_level_controls(
    key_metadata: dict, data: dict, _metadata_variable_name: str
) -> dict:
    """Merge key-level controls from ``key_metadata`` into the request ``data``.

    Applies, in order:
      - key-level cache controls (only keys in ``SupportedCacheControls``),
      - key-level tags (appended after any request-supplied tags),
      - key-level spend-logs metadata (request-supplied k-v pairs win),
      - key-level ``disable_fallbacks`` flag.

    Args:
        key_metadata: metadata stored on the litellm key.
        data: the incoming request body; not mutated (a shallow copy is
            returned, and nested containers that are modified are rebuilt
            rather than mutated in place).
        _metadata_variable_name: the key under which request metadata lives
            in ``data`` (e.g. ``"metadata"``).

    Returns:
        A new dict with the key-level controls applied.
    """
    data = data.copy()

    ## KEY-LEVEL CACHE CONTROLS
    if "cache" in key_metadata:
        data["cache"] = {}
        if isinstance(key_metadata["cache"], dict):
            for k, v in key_metadata["cache"].items():
                if k in SupportedCacheControls:
                    data["cache"][k] = v

    ## KEY-LEVEL SPEND LOGS / TAGS
    if "tags" in key_metadata and key_metadata["tags"] is not None:
        # rebuild the metadata dict so the caller's nested dict/list are
        # never mutated through the shallow copy above
        request_metadata = dict(data.get(_metadata_variable_name) or {})
        existing_tags = request_metadata.get("tags")
        if isinstance(existing_tags, list):
            # new list instead of .extend() — avoids mutating the caller's list
            request_metadata["tags"] = existing_tags + list(key_metadata["tags"])
        else:
            request_metadata["tags"] = key_metadata["tags"]
        data[_metadata_variable_name] = request_metadata

    if "spend_logs_metadata" in key_metadata and isinstance(
        key_metadata["spend_logs_metadata"], dict
    ):
        request_metadata = dict(data.get(_metadata_variable_name) or {})
        existing_slm = request_metadata.get("spend_logs_metadata")
        if isinstance(existing_slm, dict):
            # don't override k-v pairs sent by the request (user request):
            # request-supplied entries overwrite key-level ones in the merge
            request_metadata["spend_logs_metadata"] = {
                **key_metadata["spend_logs_metadata"],
                **existing_slm,
            }
        else:
            request_metadata["spend_logs_metadata"] = key_metadata[
                "spend_logs_metadata"
            ]
        data[_metadata_variable_name] = request_metadata

    ## KEY-LEVEL DISABLE FALLBACKS
    if "disable_fallbacks" in key_metadata and isinstance(
        key_metadata["disable_fallbacks"], bool
    ):
        data["disable_fallbacks"] = key_metadata["disable_fallbacks"]
    return data
|
||||||
|
|
||||||
|
|
||||||
async def add_litellm_data_to_request( # noqa: PLR0915
|
async def add_litellm_data_to_request( # noqa: PLR0915
|
||||||
data: dict,
|
data: dict,
|
||||||
|
@ -389,37 +434,11 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
||||||
|
|
||||||
### KEY-LEVEL Controls
|
### KEY-LEVEL Controls
|
||||||
key_metadata = user_api_key_dict.metadata
|
key_metadata = user_api_key_dict.metadata
|
||||||
if "cache" in key_metadata:
|
data = LiteLLMProxyRequestSetup.add_key_level_controls(
|
||||||
data["cache"] = {}
|
key_metadata=key_metadata,
|
||||||
if isinstance(key_metadata["cache"], dict):
|
data=data,
|
||||||
for k, v in key_metadata["cache"].items():
|
_metadata_variable_name=_metadata_variable_name,
|
||||||
if k in SupportedCacheControls:
|
)
|
||||||
data["cache"][k] = v
|
|
||||||
|
|
||||||
## KEY-LEVEL SPEND LOGS / TAGS
|
|
||||||
if "tags" in key_metadata and key_metadata["tags"] is not None:
|
|
||||||
if "tags" in data[_metadata_variable_name] and isinstance(
|
|
||||||
data[_metadata_variable_name]["tags"], list
|
|
||||||
):
|
|
||||||
data[_metadata_variable_name]["tags"].extend(key_metadata["tags"])
|
|
||||||
else:
|
|
||||||
data[_metadata_variable_name]["tags"] = key_metadata["tags"]
|
|
||||||
if "spend_logs_metadata" in key_metadata and isinstance(
|
|
||||||
key_metadata["spend_logs_metadata"], dict
|
|
||||||
):
|
|
||||||
if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
|
|
||||||
data[_metadata_variable_name]["spend_logs_metadata"], dict
|
|
||||||
):
|
|
||||||
for key, value in key_metadata["spend_logs_metadata"].items():
|
|
||||||
if (
|
|
||||||
key not in data[_metadata_variable_name]["spend_logs_metadata"]
|
|
||||||
): # don't override k-v pair sent by request (user request)
|
|
||||||
data[_metadata_variable_name]["spend_logs_metadata"][key] = value
|
|
||||||
else:
|
|
||||||
data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[
|
|
||||||
"spend_logs_metadata"
|
|
||||||
]
|
|
||||||
|
|
||||||
## TEAM-LEVEL SPEND LOGS/TAGS
|
## TEAM-LEVEL SPEND LOGS/TAGS
|
||||||
team_metadata = user_api_key_dict.team_metadata or {}
|
team_metadata = user_api_key_dict.team_metadata or {}
|
||||||
if "tags" in team_metadata and team_metadata["tags"] is not None:
|
if "tags" in team_metadata and team_metadata["tags"] is not None:
|
||||||
|
|
|
@ -2610,6 +2610,7 @@ class Router:
|
||||||
If it fails after num_retries, fall back to another model group
|
If it fails after num_retries, fall back to another model group
|
||||||
"""
|
"""
|
||||||
model_group: Optional[str] = kwargs.get("model")
|
model_group: Optional[str] = kwargs.get("model")
|
||||||
|
disable_fallbacks: Optional[bool] = kwargs.pop("disable_fallbacks", False)
|
||||||
fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks)
|
fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks)
|
||||||
context_window_fallbacks: Optional[List] = kwargs.get(
|
context_window_fallbacks: Optional[List] = kwargs.get(
|
||||||
"context_window_fallbacks", self.context_window_fallbacks
|
"context_window_fallbacks", self.context_window_fallbacks
|
||||||
|
@ -2637,7 +2638,7 @@ class Router:
|
||||||
original_model_group: Optional[str] = kwargs.get("model") # type: ignore
|
original_model_group: Optional[str] = kwargs.get("model") # type: ignore
|
||||||
fallback_failure_exception_str = ""
|
fallback_failure_exception_str = ""
|
||||||
|
|
||||||
if original_model_group is None:
|
if disable_fallbacks is True or original_model_group is None:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
input_kwargs = {
|
input_kwargs = {
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "litellm"
|
name = "litellm"
|
||||||
version = "1.52.6"
|
version = "1.52.7"
|
||||||
description = "Library to easily interface with LLM API providers"
|
description = "Library to easily interface with LLM API providers"
|
||||||
authors = ["BerriAI"]
|
authors = ["BerriAI"]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
[tool.commitizen]
|
[tool.commitizen]
|
||||||
version = "1.52.6"
|
version = "1.52.7"
|
||||||
version_files = [
|
version_files = [
|
||||||
"pyproject.toml:^version"
|
"pyproject.toml:^version"
|
||||||
]
|
]
|
||||||
|
|
|
@ -1455,3 +1455,46 @@ async def test_router_fallbacks_default_and_model_specific_fallbacks(sync_mode):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
exc_info.value, litellm.AuthenticationError
|
exc_info.value, litellm.AuthenticationError
|
||||||
), f"Expected AuthenticationError, but got {type(exc_info.value).__name__}"
|
), f"Expected AuthenticationError, but got {type(exc_info.value).__name__}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_router_disable_fallbacks_dynamically():
    """A request passing ``disable_fallbacks=True`` must fail fast on the bad
    deployment instead of falling back to ``good-model``.

    ``log_retry`` is patched and asserted never called as the signal that no
    retry/fallback path ran.
    NOTE(review): relies on router internals invoking ``log_retry`` on that
    path — confirm against `Router` if its retry flow changes.
    """
    # (removed unused `from litellm.router import run_async_fallback` import)
    router = Router(
        model_list=[
            {
                "model_name": "bad-model",
                "litellm_params": {
                    "model": "openai/my-bad-model",
                    "api_key": "my-bad-api-key",
                },
            },
            {
                "model_name": "good-model",
                "litellm_params": {
                    "model": "gpt-4o",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            },
        ],
        fallbacks=[{"bad-model": ["good-model"]}],
        default_fallbacks=["good-model"],
    )

    with patch.object(
        router,
        "log_retry",
        new=MagicMock(return_value=None),
    ) as mock_client:
        try:
            resp = await router.acompletion(
                model="bad-model",
                messages=[{"role": "user", "content": "Hey, how's it going?"}],
                disable_fallbacks=True,
            )
            print(resp)
        except Exception as e:
            # the bad api key is expected to raise; the exception itself is
            # not under test — only that fallbacks were skipped
            print(e)

        mock_client.assert_not_called()
|
||||||
|
|
|
@ -1500,6 +1500,31 @@ async def test_add_callback_via_key_litellm_pre_call_utils(
|
||||||
assert new_data["failure_callback"] == expected_failure_callbacks
|
assert new_data["failure_callback"] == expected_failure_callbacks
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "disable_fallbacks_set",
    [
        True,
        False,
    ],
)
async def test_disable_fallbacks_by_key(disable_fallbacks_set):
    """Key-level `disable_fallbacks` metadata is copied onto the request data
    by `add_key_level_controls`, for both True and False values."""
    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

    request_payload = {
        "model": "azure/chatgpt-v-2",
        "messages": [{"role": "user", "content": "write 1 sentence poem"}],
    }
    updated_payload = LiteLLMProxyRequestSetup.add_key_level_controls(
        key_metadata={"disable_fallbacks": disable_fallbacks_set},
        data=request_payload,
        _metadata_variable_name="metadata",
    )

    assert updated_payload["disable_fallbacks"] == disable_fallbacks_set
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"callback_type, expected_success_callbacks, expected_failure_callbacks",
|
"callback_type, expected_success_callbacks, expected_failure_callbacks",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue