from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import litellm
from litellm._logging import verbose_router_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.router_utils.add_retry_fallback_headers import (
    add_fallback_headers_to_response,
)
from litellm.types.router import LiteLLMParamsTypedDict

if TYPE_CHECKING:
    from litellm.router import Router as _Router

    LitellmRouter = _Router
else:
    LitellmRouter = Any


def _check_stripped_model_group(model_group: str, fallback_key: str) -> bool:
    """
    Handles the wildcard routing scenario where fallbacks are set like:
        [{"gpt-3.5-turbo": ["claude-3-haiku"]}]

    but the model_group is provider-prefixed, like:
        "openai/gpt-3.5-turbo"

    Returns:
        - True if the model group, stripped of its provider prefix, == fallback_key
    """
    for provider in litellm.provider_list:
        if isinstance(provider, Enum):
            _provider = provider.value
        else:
            _provider = provider
        if model_group.startswith(f"{_provider}/"):
            stripped_model_group = model_group.replace(f"{_provider}/", "")
            if stripped_model_group == fallback_key:
                return True
    return False
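

# Illustrative sketch only (added for documentation; it is not called anywhere in
# litellm, and the model names are hypothetical). It shows how the helper above
# matches a provider-prefixed model group against a fallback key once the provider
# prefix is stripped.
def _example_check_stripped_model_group() -> None:
    # "openai/" is a known provider prefix, so the stripped group matches the key
    assert _check_stripped_model_group(
        model_group="openai/gpt-3.5-turbo", fallback_key="gpt-3.5-turbo"
    )
    # no provider prefix on the model group, and the names differ -> no match
    assert not _check_stripped_model_group(
        model_group="gpt-3.5-turbo", fallback_key="claude-3-haiku"
    )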


def get_fallback_model_group(
    fallbacks: List[Any], model_group: str
) -> Tuple[Optional[List[str]], Optional[int]]:
    """
    Returns:
        - fallback_model_group: List[str] of fallback model groups. example: ["gpt-4", "gpt-3.5-turbo"]
        - generic_fallback_idx: int index of the generic fallback ("*") in the fallbacks list.

    Checks, in order:
        - exact match
        - stripped model group match
        - generic fallback
    """
    generic_fallback_idx: Optional[int] = None
    stripped_model_fallback: Optional[List[str]] = None
    fallback_model_group: Optional[List[str]] = None
    ## check for model group-specific fallbacks
    for idx, item in enumerate(fallbacks):
        if isinstance(item, dict):
            if list(item.keys())[0] == model_group:  # check exact match
                fallback_model_group = item[model_group]
                break
            elif _check_stripped_model_group(
                model_group=model_group, fallback_key=list(item.keys())[0]
            ):  # check stripped model group match
                stripped_model_fallback = item[list(item.keys())[0]]
            elif list(item.keys())[0] == "*":  # check generic fallback
                generic_fallback_idx = idx
        elif isinstance(item, str):
            fallback_model_group = [fallbacks.pop(idx)]  # returns single-item list
    ## if none, check for generic fallback
    if fallback_model_group is None:
        if stripped_model_fallback is not None:
            fallback_model_group = stripped_model_fallback
        elif generic_fallback_idx is not None:
            fallback_model_group = fallbacks[generic_fallback_idx]["*"]

    return fallback_model_group, generic_fallback_idx
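

# Illustrative sketch only (not called by litellm; the fallback configuration and
# model names are hypothetical). It walks the resolution order documented above:
# exact match wins, then a stripped provider-prefix match, then the generic "*".
def _example_get_fallback_model_group() -> None:
    fallbacks = [
        {"gpt-4": ["gpt-3.5-turbo"]},
        {"gpt-3.5-turbo": ["claude-3-haiku"]},
        {"*": ["gpt-4o-mini"]},
    ]
    # exact match on the model group
    groups, _ = get_fallback_model_group(fallbacks=fallbacks, model_group="gpt-4")
    assert groups == ["gpt-3.5-turbo"]
    # provider-prefixed group resolves via the stripped match
    groups, _ = get_fallback_model_group(
        fallbacks=fallbacks, model_group="openai/gpt-3.5-turbo"
    )
    assert groups == ["claude-3-haiku"]
    # unknown group resolves via the generic "*" entry
    groups, generic_idx = get_fallback_model_group(
        fallbacks=fallbacks, model_group="command-r"
    )
    assert groups == ["gpt-4o-mini"] and generic_idx == 2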


async def run_async_fallback(
    *args: Tuple[Any],
    litellm_router: LitellmRouter,
    fallback_model_group: List[str],
    original_model_group: str,
    original_exception: Exception,
    max_fallbacks: int,
    fallback_depth: int,
    **kwargs,
) -> Any:
    """
    Loops through all the fallback model groups and calls kwargs["original_function"] with the arguments and keyword arguments provided.

    If a call succeeds, it logs the success and returns the response.
    If a call fails, it logs the failure and continues to the next fallback model group.
    If all fallback model groups fail, it raises the most recent exception.

    Args:
        litellm_router: The litellm router instance.
        *args: Positional arguments.
        fallback_model_group: List[str] of fallback model groups. example: ["gpt-4", "gpt-3.5-turbo"]
        original_model_group: The original model group. example: "gpt-3.5-turbo"
        original_exception: The original exception.
        max_fallbacks: Maximum number of fallback attempts allowed.
        fallback_depth: Number of fallback attempts made so far.
        **kwargs: Keyword arguments.

    Returns:
        The response from the successful fallback model group.
    Raises:
        The most recent exception if all fallback model groups fail.
    """

    ### BASE CASE ### MAX FALLBACK DEPTH REACHED
    if fallback_depth >= max_fallbacks:
        raise original_exception

    error_from_fallbacks = original_exception

    # Handle mid-stream fallbacks by preserving already generated content
    is_mid_stream = kwargs.pop("is_mid_stream_fallback", False)
    previous_content = kwargs.pop("previous_content", "")

    # If this is a mid-stream fallback and we have previous content, prepare messages
    if is_mid_stream and previous_content and "messages" in kwargs:
        messages = kwargs.get("messages", [])

        if isinstance(messages, list) and len(messages) > 0:
            if previous_content.strip():
                # Check for an existing system message
                system_msg_idx = None
                for i, msg in enumerate(messages):
                    if msg.get("role") == "system":
                        system_msg_idx = i
                        break

                continuation_text = f"The following is the beginning of an assistant's response. Continue from where it left off: '{previous_content}'"

                if system_msg_idx is not None:
                    # Append to the existing system message
                    messages[system_msg_idx]["content"] = (
                        messages[system_msg_idx].get("content", "")
                        + "\n\n"
                        + continuation_text
                    )
                else:
                    # Add a new system message
                    messages.insert(
                        0, {"role": "system", "content": continuation_text}
                    )

        # Update kwargs with the modified messages
        kwargs["messages"] = messages
        # Add to metadata to track that this was a mid-stream fallback
        kwargs.setdefault("metadata", {}).update(
            {
                "is_mid_stream_fallback": True,
                "fallback_depth": fallback_depth,
                "previous_content_length": len(previous_content),
            }
        )

    for mg in fallback_model_group:
        if mg == original_model_group:
            continue
        try:
            # LOGGING
            kwargs = litellm_router.log_retry(kwargs=kwargs, e=original_exception)
            verbose_router_logger.info(f"Falling back to model_group = {mg}")
            if isinstance(mg, str):
                kwargs["model"] = mg
            elif isinstance(mg, dict):
                kwargs.update(mg)
            kwargs.setdefault("metadata", {}).update(
                {"model_group": kwargs.get("model", None)}
            )  # update model_group used, if fallbacks are done
            fallback_depth = fallback_depth + 1
            kwargs["fallback_depth"] = fallback_depth
            kwargs["max_fallbacks"] = max_fallbacks
            response = await litellm_router.async_function_with_fallbacks(
                *args, **kwargs
            )
            verbose_router_logger.info("Successful fallback b/w models.")
            response = add_fallback_headers_to_response(
                response=response,
                attempted_fallbacks=fallback_depth,
            )

            # If this was a mid-stream fallback, flag it in metadata and in
            # additional_headers for header propagation
            if is_mid_stream and hasattr(response, "_hidden_params"):
                response._hidden_params.setdefault("metadata", {})[
                    "mid_stream_fallback"
                ] = True
                additional_headers = response._hidden_params.setdefault(
                    "additional_headers", {}
                )
                additional_headers["x-litellm-mid-stream-fallback"] = True
                additional_headers["x-litellm-previous-content-length"] = len(
                    previous_content
                )

            # callback for successful_fallback_event():
            await log_success_fallback_event(
                original_model_group=original_model_group,
                kwargs=kwargs,
                original_exception=original_exception,
            )
            return response
        except Exception as e:
            error_from_fallbacks = e
            await log_failure_fallback_event(
                original_model_group=original_model_group,
                kwargs=kwargs,
                original_exception=original_exception,
            )
    raise error_from_fallbacks
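

# Illustrative sketch only (not called by litellm). run_async_fallback is an
# internal helper that Router.async_function_with_fallbacks invokes after a call
# fails; in normal usage the loop above is driven entirely by Router configuration
# like the hypothetical example below (model names, deployments, and messages are
# made up, and API keys / error handling are omitted).
async def _example_router_driven_fallback() -> None:
    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            },
            {
                "model_name": "claude-3-haiku",
                "litellm_params": {"model": "claude-3-haiku-20240307"},
            },
        ],
        fallbacks=[{"gpt-3.5-turbo": ["claude-3-haiku"]}],
    )
    # If the gpt-3.5-turbo deployment raises, the router retries the request
    # against "claude-3-haiku" via run_async_fallback before surfacing an error.
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(response)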


async def log_success_fallback_event(
    original_model_group: str, kwargs: dict, original_exception: Exception
):
    """
    Log a successful fallback event to all registered callbacks.

    This function iterates through all registered callbacks, initializing a
    CustomLogger-compatible instance for known custom-logger-compatible callbacks
    if needed, and calls the log_success_fallback_event method on each
    CustomLogger instance.

    Args:
        original_model_group (str): The original model group before fallback.
        kwargs (dict): kwargs for the request.
        original_exception (Exception): The exception that triggered the fallback.

    Note:
        Errors during logging are caught and reported but do not interrupt the process.
    """
    from litellm.litellm_core_utils.litellm_logging import (
        _init_custom_logger_compatible_class,
    )

    for _callback in litellm.callbacks:
        if isinstance(_callback, CustomLogger) or (
            _callback in litellm._known_custom_logger_compatible_callbacks
        ):
            try:
                _callback_custom_logger: Optional[CustomLogger] = None
                if _callback in litellm._known_custom_logger_compatible_callbacks:
                    _callback_custom_logger = _init_custom_logger_compatible_class(
                        logging_integration=_callback,  # type: ignore
                        llm_router=None,
                        internal_usage_cache=None,
                    )
                elif isinstance(_callback, CustomLogger):
                    _callback_custom_logger = _callback
                else:
                    verbose_router_logger.exception(
                        f"{_callback} logger not found / initialized properly"
                    )
                    continue

                if _callback_custom_logger is None:
                    verbose_router_logger.exception(
                        f"{_callback} logger not found / initialized properly, callback is None"
                    )
                    continue

                await _callback_custom_logger.log_success_fallback_event(
                    original_model_group=original_model_group,
                    kwargs=kwargs,
                    original_exception=original_exception,
                )
            except Exception as e:
                verbose_router_logger.error(
                    f"Error in log_success_fallback_event: {str(e)}"
                )
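

# Illustrative sketch only (not part of litellm; the class name and print
# statements are hypothetical). It shows the shape of a custom callback whose
# hooks are invoked by log_success_fallback_event above and by
# log_failure_fallback_event below.
class _ExampleFallbackLogger(CustomLogger):
    async def log_success_fallback_event(
        self, original_model_group: str, kwargs: dict, original_exception: Exception
    ):
        print(f"fallback succeeded, original model group: {original_model_group}")

    async def log_failure_fallback_event(
        self, original_model_group: str, kwargs: dict, original_exception: Exception
    ):
        print(f"fallback failed, original model group: {original_model_group}")


# It would be registered like any other custom callback (hypothetical usage):
#   litellm.callbacks = [_ExampleFallbackLogger()]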


async def log_failure_fallback_event(
    original_model_group: str, kwargs: dict, original_exception: Exception
):
    """
    Log a failed fallback event to all registered callbacks.

    This function iterates through all registered callbacks, initializing a
    CustomLogger-compatible instance for known custom-logger-compatible callbacks
    if needed, and calls the log_failure_fallback_event method on each
    CustomLogger instance.

    Args:
        original_model_group (str): The original model group before fallback.
        kwargs (dict): kwargs for the request.
        original_exception (Exception): The exception that triggered the fallback.

    Note:
        Errors during logging are caught and reported but do not interrupt the process.
    """
    from litellm.litellm_core_utils.litellm_logging import (
        _init_custom_logger_compatible_class,
    )

    for _callback in litellm.callbacks:
        if isinstance(_callback, CustomLogger) or (
            _callback in litellm._known_custom_logger_compatible_callbacks
        ):
            try:
                _callback_custom_logger: Optional[CustomLogger] = None
                if _callback in litellm._known_custom_logger_compatible_callbacks:
                    _callback_custom_logger = _init_custom_logger_compatible_class(
                        logging_integration=_callback,  # type: ignore
                        llm_router=None,
                        internal_usage_cache=None,
                    )
                elif isinstance(_callback, CustomLogger):
                    _callback_custom_logger = _callback
                else:
                    verbose_router_logger.exception(
                        f"{_callback} logger not found / initialized properly"
                    )
                    continue

                if _callback_custom_logger is None:
                    verbose_router_logger.exception(
                        f"{_callback} logger not found / initialized properly, callback is None"
                    )
                    continue

                await _callback_custom_logger.log_failure_fallback_event(
                    original_model_group=original_model_group,
                    kwargs=kwargs,
                    original_exception=original_exception,
                )
            except Exception as e:
                verbose_router_logger.error(
                    f"Error in log_failure_fallback_event: {str(e)}"
                )


def _check_non_standard_fallback_format(fallbacks: Optional[List[Any]]) -> bool:
    """
    Checks if the fallbacks list is in a non-standard format: a list of strings
    or a list of litellm-params dictionaries.

    Returns True for:
        - List[str]: e.g. ["claude-3-haiku", "openai/o-1"]
        - List[Dict[<LiteLLMParamsTypedDict>, Any]]: e.g. [{"model": "claude-3-haiku", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}]

    Returns False for the standard format, e.g. [{"gpt-3.5-turbo": ["claude-3-haiku"]}].
    """
    if fallbacks is None or not isinstance(fallbacks, list) or len(fallbacks) == 0:
        return False
    if all(isinstance(item, str) for item in fallbacks):
        return True
    elif all(isinstance(item, dict) for item in fallbacks):
        for key in LiteLLMParamsTypedDict.__annotations__.keys():
            if key in fallbacks[0].keys():
                return True

    return False
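

# Illustrative sketch only (not called by litellm; the values are hypothetical).
# It contrasts the "non-standard" fallback formats detected above with the
# standard {model_group: [fallback, ...]} mapping.
def _example_check_non_standard_fallback_format() -> None:
    # non-standard: a plain list of model names
    assert _check_non_standard_fallback_format(["claude-3-haiku", "openai/o-1"])
    # non-standard: a list of litellm-params dicts (keyed by fields like "model")
    assert _check_non_standard_fallback_format(
        [{"model": "claude-3-haiku", "messages": [{"role": "user", "content": "hi"}]}]
    )
    # standard: {model_group: [fallback model groups]} -> not flagged
    assert not _check_non_standard_fallback_format(
        [{"gpt-3.5-turbo": ["claude-3-haiku"]}]
    )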


def run_non_standard_fallback_format(
    fallbacks: Union[List[str], List[Dict[str, Any]]], model_group: str
):
    pass