mirror of https://github.com/BerriAI/litellm.git
fix: fallbacks don't recurse on the same fallback
Parent: fb16ff2335
Commit: 078fe97053
3 changed files with 122 additions and 87 deletions
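Background, not part of the diff: the recursion named in the commit title shows up when a failing model group's fallback list points back at the same group, which is easy to hit with a global default fallback because it applies to every group, including the fallback group itself. A hypothetical configuration with that shape (model name and key are placeholders):

import litellm
from litellm import Router

# Every failed call falls back to this one group, including a failed call
# to "very-bad-model" itself - the shape the new guard has to break.
litellm.default_fallbacks = ["very-bad-model"]

router = Router(
    model_list=[
        {
            "model_name": "very-bad-model",
            "litellm_params": {
                "model": "openai/very-bad-model",
                "api_key": "sk-placeholder",
            },
        },
    ],
)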
@@ -62,6 +62,7 @@ from litellm.router_utils.cooldown_callbacks import router_cooldown_handler
 from litellm.router_utils.fallback_event_handlers import (
     log_failure_fallback_event,
     log_success_fallback_event,
+    run_async_fallback,
 )
 from litellm.router_utils.handle_error import send_llm_exception_alert
 from litellm.scheduler import FlowItem, Scheduler
@@ -2383,34 +2384,16 @@ class Router:
                     if fallback_model_group is None:
                         raise original_exception

-                    for mg in fallback_model_group:
-                        """
-                        Iterate through the model groups and try calling that deployment
-                        """
-                        try:
-                            kwargs["model"] = mg
-                            kwargs.setdefault("metadata", {}).update(
-                                {"model_group": mg}
-                            )  # update model_group used, if fallbacks are done
-                            response = await self.async_function_with_retries(
-                                *args, **kwargs
-                            )
-                            verbose_router_logger.info(
-                                "Successful fallback b/w models."
-                            )
-                            # callback for successfull_fallback_event():
-                            await log_success_fallback_event(
-                                original_model_group=original_model_group,
-                                kwargs=kwargs,
-                            )
-
-                            return response
-                        except Exception as e:
-                            await log_failure_fallback_event(
-                                original_model_group=original_model_group,
-                                kwargs=kwargs,
-                            )
-                            pass
+                    response = await run_async_fallback(
+                        *args,
+                        litellm_router=self,
+                        fallback_model_group=fallback_model_group,
+                        original_model_group=original_model_group,
+                        original_exception=original_exception,
+                        **kwargs,
+                    )
+                    return response
                 else:
                     error_message = "model={}. context_window_fallbacks={}. fallbacks={}.\n\nSet 'context_window_fallback' - https://docs.litellm.ai/docs/routing#fallbacks".format(
                         model_group, context_window_fallbacks, fallbacks
@@ -2436,33 +2419,15 @@ class Router:
                     if fallback_model_group is None:
                         raise original_exception

-                    for mg in fallback_model_group:
-                        """
-                        Iterate through the model groups and try calling that deployment
-                        """
-                        try:
-                            kwargs["model"] = mg
-                            kwargs.setdefault("metadata", {}).update(
-                                {"model_group": mg}
-                            )  # update model_group used, if fallbacks are done
-                            response = await self.async_function_with_retries(
-                                *args, **kwargs
-                            )
-                            verbose_router_logger.info(
-                                "Successful fallback b/w models."
-                            )
-                            # callback for successfull_fallback_event():
-                            await log_success_fallback_event(
-                                original_model_group=original_model_group,
-                                kwargs=kwargs,
-                            )
-                            return response
-                        except Exception as e:
-                            await log_failure_fallback_event(
-                                original_model_group=original_model_group,
-                                kwargs=kwargs,
-                            )
-                            pass
+                    response = await run_async_fallback(
+                        *args,
+                        litellm_router=self,
+                        fallback_model_group=fallback_model_group,
+                        original_model_group=original_model_group,
+                        original_exception=original_exception,
+                        **kwargs,
+                    )
+                    return response
                 else:
                     error_message = "model={}. content_policy_fallback={}. fallbacks={}.\n\nSet 'content_policy_fallback' - https://docs.litellm.ai/docs/routing#fallbacks".format(
                         model_group, content_policy_fallbacks, fallbacks
@@ -2502,39 +2467,16 @@ class Router:
                     if hasattr(original_exception, "message"):
                         original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
                     raise original_exception
-                    for mg in fallback_model_group:
-                        """
-                        Iterate through the model groups and try calling that deployment
-                        """
-                        try:
-                            ## LOGGING
-                            kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
-                            verbose_router_logger.info(
-                                f"Falling back to model_group = {mg}"
-                            )
-                            kwargs["model"] = mg
-                            kwargs.setdefault("metadata", {}).update(
-                                {"model_group": mg}
-                            )  # update model_group used, if fallbacks are done
-                            response = await self.async_function_with_fallbacks(
-                                *args, **kwargs
-                            )
-                            verbose_router_logger.info(
-                                "Successful fallback b/w models."
-                            )
-                            # callback for successfull_fallback_event():
-                            await log_success_fallback_event(
-                                original_model_group=original_model_group,
-                                kwargs=kwargs,
-                            )
-
-                            return response
-                        except Exception as e:
-                            await log_failure_fallback_event(
-                                original_model_group=original_model_group,
-                                kwargs=kwargs,
-                            )
-                            raise e
+                    response = await run_async_fallback(
+                        *args,
+                        litellm_router=self,
+                        fallback_model_group=fallback_model_group,
+                        original_model_group=original_model_group,
+                        original_exception=original_exception,
+                        **kwargs,
+                    )
+                    return response
            except Exception as new_exception:
                verbose_router_logger.error(
                    "litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(
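The three rewritten branches above handle the router's context-window, content-policy, and general fallbacks; each now delegates to the shared run_async_fallback helper instead of carrying its own copy of the retry loop. As a hedged sketch of how those three settings are configured on a Router (model names are placeholders; the documented options are at https://docs.litellm.ai/docs/routing#fallbacks):

from litellm import Router

router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}},
    ],
    # tried when a call to the keyed model group fails outright
    fallbacks=[{"gpt-3.5-turbo": ["gpt-4"]}],
    # tried when the prompt exceeds the group's context window
    context_window_fallbacks=[{"gpt-3.5-turbo": ["gpt-4"]}],
    # tried when the provider rejects the request on content-policy grounds
    content_policy_fallbacks=[{"gpt-3.5-turbo": ["gpt-4"]}],
)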
@@ -1,9 +1,56 @@
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple

 import litellm
 from litellm._logging import verbose_router_logger
 from litellm.integrations.custom_logger import CustomLogger

+if TYPE_CHECKING:
+    from litellm.router import Router as _Router
+
+    LitellmRouter = _Router
+else:
+    LitellmRouter = Any
+
+
+async def run_async_fallback(
+    litellm_router: LitellmRouter,
+    *args: Tuple[Any],
+    fallback_model_group: List[str],
+    original_model_group: str,
+    original_exception: Exception,
+    **kwargs,
+) -> Any:
+    """
+    Iterate through the model groups and try calling that deployment.
+    """
+    error_from_fallbacks = original_exception
+    for mg in fallback_model_group:
+        if mg == original_model_group:
+            continue
+        try:
+            # LOGGING
+            kwargs = litellm_router.log_retry(kwargs=kwargs, e=original_exception)
+            verbose_router_logger.info(f"Falling back to model_group = {mg}")
+            kwargs["model"] = mg
+            kwargs.setdefault("metadata", {}).update(
+                {"model_group": mg}
+            )  # update model_group used, if fallbacks are done
+            response = await litellm_router.async_function_with_fallbacks(
+                *args, **kwargs
+            )
+            verbose_router_logger.info("Successful fallback b/w models.")
+            # callback for successfull_fallback_event():
+            await log_success_fallback_event(
+                original_model_group=original_model_group, kwargs=kwargs
+            )
+            return response
+        except Exception as e:
+            error_from_fallbacks = e
+            await log_failure_fallback_event(
+                original_model_group=original_model_group, kwargs=kwargs
+            )
+    raise error_from_fallbacks


 async def log_success_fallback_event(original_model_group: str, kwargs: dict):
     for _callback in litellm.callbacks:
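For illustration only, standalone Python rather than litellm code: the point of the `if mg == original_model_group: continue` guard above is that a group is never retried as its own fallback. The toy version below shows the difference the guard makes when a fallback list contains the group that just failed; all names are hypothetical.

import asyncio
from typing import Dict, List

# hypothetical fallback map: "bad-model" lists itself as its first fallback
FALLBACKS: Dict[str, List[str]] = {"bad-model": ["bad-model", "backup-model"]}


async def call_with_fallbacks(model: str, depth: int = 0) -> str:
    if model == "backup-model":  # pretend only this deployment works
        return f"response from {model}"
    if depth > 10:  # stand-in for looping forever / blowing the stack
        raise RecursionError("fallbacks recursed on the same model group")
    for mg in FALLBACKS.get(model, []):
        if mg == model:  # the guard: never fall back to the group that just failed
            continue
        return await call_with_fallbacks(mg, depth + 1)
    raise RuntimeError(f"no working fallback for {model}")


print(asyncio.run(call_with_fallbacks("bad-model")))  # -> response from backup-model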
@@ -1185,3 +1185,49 @@ async def test_router_content_policy_fallbacks(
     )

     assert response.model == "my-fake-model"
+
+
+@pytest.mark.parametrize("sync_mode", [False])
+@pytest.mark.asyncio
+async def test_using_default_fallback(sync_mode):
+    """
+    Tests Client Side Fallbacks
+
+    User can pass "fallbacks": ["gpt-3.5-turbo"] and this should work
+
+    """
+    litellm.set_verbose = True
+
+    import logging
+
+    from litellm._logging import verbose_logger, verbose_router_logger
+
+    verbose_logger.setLevel(logging.DEBUG)
+    verbose_router_logger.setLevel(logging.DEBUG)
+    litellm.default_fallbacks = ["very-bad-model"]
+    router = Router(
+        model_list=[
+            {
+                "model_name": "openai/*",
+                "litellm_params": {
+                    "model": "openai/*",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ],
+    )
+    try:
+        if sync_mode:
+            response = router.completion(
+                model="openai/foo",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            )
+        else:
+            response = await router.acompletion(
+                model="openai/foo",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            )
+        print("got response=", response)
+        pytest.fail(f"Expected call to fail we passed model=openai/foo")
+    except Exception as e:
+        print("got exception = ", e)
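The new test exercises that guard end to end: litellm.default_fallbacks = ["very-bad-model"] means every failing group, including "very-bad-model" itself, falls back to the same single group, which is presumably the recursion the commit title refers to. With the guard in run_async_fallback the helper skips the group it is already falling back from, the accumulated exception is re-raised, and the except branch that prints the error is the passing path of the test; the request to model="openai/foo" is expected to raise rather than loop.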