Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

Merge branch 'main' into litellm_bedrock_command_r_support

This commit is contained in: commit 1d651c6049

82 changed files with 3661 additions and 605 deletions
@@ -48,6 +48,7 @@ from litellm.types.router import (
    AlertingConfig,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc


class Router:
@@ -605,6 +606,33 @@ class Router:
                    self.fail_calls[model_name] += 1
            raise e

    async def abatch_completion(
        self, models: List[str], messages: List[Dict[str, str]], **kwargs
    ):

        async def _async_completion_no_exceptions(
            model: str, messages: List[Dict[str, str]], **kwargs
        ):
            """
            Wrapper around self.async_completion that catches exceptions and returns them as a result
            """
            try:
                return await self.acompletion(model=model, messages=messages, **kwargs)
            except Exception as e:
                return e

        _tasks = []
        for model in models:
            # add each task but if the task fails
            _tasks.append(
                _async_completion_no_exceptions(
                    model=model, messages=messages, **kwargs
                )
            )

        response = await asyncio.gather(*_tasks)
        return response

    def image_generation(self, prompt: str, model: str, **kwargs):
        try:
            kwargs["model"] = model
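For reference, a minimal usage sketch of the new abatch_completion helper added above. The model names, deployment params, and prompt are hypothetical placeholders; it assumes a Router built from an ordinary model_list:

import asyncio
from litellm import Router

# hypothetical deployments - any valid model_list entries would do
router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "command-r", "litellm_params": {"model": "bedrock/cohere.command-r-v1:0"}},
    ]
)

async def main():
    # fan the same prompt out to several model groups; per the wrapper above,
    # a failed call comes back as an Exception object instead of raising
    responses = await router.abatch_completion(
        models=["gpt-3.5-turbo", "command-r"],
        messages=[{"role": "user", "content": "hello"}],
    )
    for model, resp in zip(["gpt-3.5-turbo", "command-r"], responses):
        if isinstance(resp, Exception):
            print(f"{model} failed: {resp}")
        else:
            print(f"{model}: {resp.choices[0].message.content}")

asyncio.run(main())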
@@ -1480,26 +1508,29 @@ class Router:
        except Exception as e:
            original_exception = e
            """
            - Check if available deployments - 'get_healthy_deployments() -> List`
            - if no, Check if available fallbacks - `is_fallback(model_group: str, exception) -> bool`
            - if no, back-off and retry up till num_retries - `_router_should_retry -> float`
            Retry Logic

            """
            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
            if (
                isinstance(original_exception, litellm.ContextWindowExceededError)
                and context_window_fallbacks is not None
            ) or (
                isinstance(original_exception, openai.RateLimitError)
                and fallbacks is not None
            ):
                raise original_exception
            ### RETRY
            _healthy_deployments = await self._async_get_healthy_deployments(
                model=kwargs.get("model"),
            )

            _timeout = self._router_should_retry(
            # raises an exception if this error should not be retries
            self.should_retry_this_error(
                error=e,
                healthy_deployments=_healthy_deployments,
                context_window_fallbacks=context_window_fallbacks,
            )

            # decides how long to sleep before retry
            _timeout = self._time_to_sleep_before_retry(
                e=original_exception,
                remaining_retries=num_retries,
                num_retries=num_retries,
                healthy_deployments=_healthy_deployments,
            )

            # sleeps for the length of the timeout
            await asyncio.sleep(_timeout)

            if (
@@ -1533,10 +1564,14 @@ class Router:
                except Exception as e:
                    ## LOGGING
                    kwargs = self.log_retry(kwargs=kwargs, e=e)
                    remaining_retries = num_retries - current_attempt
                    _timeout = self._router_should_retry(
                    _healthy_deployments = await self._async_get_healthy_deployments(
                        model=kwargs.get("model"),
                    )
                    _timeout = self._time_to_sleep_before_retry(
                        e=original_exception,
                        remaining_retries=remaining_retries,
                        num_retries=num_retries,
                        healthy_deployments=_healthy_deployments,
                    )
                    await asyncio.sleep(_timeout)
            try:
@@ -1545,6 +1580,40 @@ class Router:
                        pass
        raise original_exception

    def should_retry_this_error(
        self,
        error: Exception,
        healthy_deployments: Optional[List] = None,
        context_window_fallbacks: Optional[List] = None,
    ):
        """
        1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None

        2. raise an exception for RateLimitError if
        - there are no fallbacks
        - there are no healthy deployments in the same model group
        """

        _num_healthy_deployments = 0
        if healthy_deployments is not None and isinstance(healthy_deployments, list):
            _num_healthy_deployments = len(healthy_deployments)

        ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
        if (
            isinstance(error, litellm.ContextWindowExceededError)
            and context_window_fallbacks is None
        ):
            raise error

        # Error we should only retry if there are other deployments
        if isinstance(error, openai.RateLimitError) or isinstance(
            error, openai.AuthenticationError
        ):
            if _num_healthy_deployments <= 0:
                raise error

        return True

    def function_with_fallbacks(self, *args, **kwargs):
        """
        Try calling the function_with_retries
@@ -1633,12 +1702,27 @@ class Router:
                        raise e
        raise original_exception

    def _router_should_retry(
        self, e: Exception, remaining_retries: int, num_retries: int
    def _time_to_sleep_before_retry(
        self,
        e: Exception,
        remaining_retries: int,
        num_retries: int,
        healthy_deployments: Optional[List] = None,
    ) -> Union[int, float]:
        """
        Calculate back-off, then retry

        It should instantly retry only when:
        1. there are healthy deployments in the same model group
        2. there are fallbacks for the completion call
        """
        if (
            healthy_deployments is not None
            and isinstance(healthy_deployments, list)
            and len(healthy_deployments) > 0
        ):
            return 0

        if hasattr(e, "response") and hasattr(e.response, "headers"):
            timeout = litellm._calculate_retry_after(
                remaining_retries=remaining_retries,
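The call into litellm._calculate_retry_after is cut off above. As a rough illustration only (not the litellm implementation, and the helper name is hypothetical), the back-off policy the docstring describes looks roughly like this: return immediately when other deployments can take the call, honor a Retry-After header when one is present, otherwise back off exponentially with jitter.

import random

def sketch_time_to_sleep(e, remaining_retries, num_retries, healthy_deployments=None):
    # other deployments in the model group can serve the call right away - no back-off
    if healthy_deployments:
        return 0

    # honor an explicit Retry-After header from the provider when one is present
    headers = getattr(getattr(e, "response", None), "headers", None)
    if headers is not None and headers.get("retry-after"):
        try:
            return float(headers.get("retry-after"))
        except ValueError:
            pass

    # otherwise use jittered exponential back-off, capped at 60 seconds
    attempt = num_retries - remaining_retries
    return min(2 ** attempt, 60) * (0.5 + random.random() / 2)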
@@ -1675,23 +1759,29 @@ class Router:
        except Exception as e:
            original_exception = e
            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
            if (
                isinstance(original_exception, litellm.ContextWindowExceededError)
                and context_window_fallbacks is not None
            ) or (
                isinstance(original_exception, openai.RateLimitError)
                and fallbacks is not None
            ):
                raise original_exception
            ## LOGGING
            if num_retries > 0:
                kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
            ### RETRY
            _timeout = self._router_should_retry(
            _healthy_deployments = self._get_healthy_deployments(
                model=kwargs.get("model"),
            )

            # raises an exception if this error should not be retries
            self.should_retry_this_error(
                error=e,
                healthy_deployments=_healthy_deployments,
                context_window_fallbacks=context_window_fallbacks,
            )

            # decides how long to sleep before retry
            _timeout = self._time_to_sleep_before_retry(
                e=original_exception,
                remaining_retries=num_retries,
                num_retries=num_retries,
                healthy_deployments=_healthy_deployments,
            )

            ## LOGGING
            if num_retries > 0:
                kwargs = self.log_retry(kwargs=kwargs, e=original_exception)

            time.sleep(_timeout)
            for current_attempt in range(num_retries):
                verbose_router_logger.debug(
@@ -1705,11 +1795,15 @@ class Router:
                except Exception as e:
                    ## LOGGING
                    kwargs = self.log_retry(kwargs=kwargs, e=e)
                    _healthy_deployments = self._get_healthy_deployments(
                        model=kwargs.get("model"),
                    )
                    remaining_retries = num_retries - current_attempt
                    _timeout = self._router_should_retry(
                    _timeout = self._time_to_sleep_before_retry(
                        e=e,
                        remaining_retries=remaining_retries,
                        num_retries=num_retries,
                        healthy_deployments=_healthy_deployments,
                    )
                    time.sleep(_timeout)
        raise original_exception
@@ -1912,6 +2006,47 @@ class Router:
        verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
        return cooldown_models

    def _get_healthy_deployments(self, model: str):
        _all_deployments: list = []
        try:
            _, _all_deployments = self._common_checks_available_deployment(  # type: ignore
                model=model,
            )
            if type(_all_deployments) == dict:
                return []
        except:
            pass

        unhealthy_deployments = self._get_cooldown_deployments()
        healthy_deployments: list = []
        for deployment in _all_deployments:
            if deployment["model_info"]["id"] in unhealthy_deployments:
                continue
            else:
                healthy_deployments.append(deployment)

        return healthy_deployments

    async def _async_get_healthy_deployments(self, model: str):
        _all_deployments: list = []
        try:
            _, _all_deployments = self._common_checks_available_deployment(  # type: ignore
                model=model,
            )
            if type(_all_deployments) == dict:
                return []
        except:
            pass

        unhealthy_deployments = await self._async_get_cooldown_deployments()
        healthy_deployments: list = []
        for deployment in _all_deployments:
            if deployment["model_info"]["id"] in unhealthy_deployments:
                continue
            else:
                healthy_deployments.append(deployment)
        return healthy_deployments

    def routing_strategy_pre_call_checks(self, deployment: dict):
        """
        Mimics 'async_routing_strategy_pre_call_checks'
@@ -2120,6 +2255,10 @@ class Router:
                    raise ValueError(
                        f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
                    )
                azure_ad_token = litellm_params.get("azure_ad_token")
                if azure_ad_token is not None:
                    if azure_ad_token.startswith("oidc/"):
                        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                if api_version is None:
                    api_version = "2023-07-01-preview"
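The oidc/ handling above means a deployment can carry an OIDC reference instead of a raw Azure AD token; the reference is exchanged for a real token via get_azure_ad_token_from_oidc() while the clients are being set up. A hypothetical model_list entry (endpoint, deployment name, and OIDC path are placeholders) might look like:

model_list = [
    {
        "model_name": "azure-gpt-4",
        "litellm_params": {
            "model": "azure/gpt-4",                                # placeholder deployment
            "api_base": "https://my-endpoint.openai.azure.com",    # placeholder endpoint
            # an "oidc/..." value is swapped for a real Azure AD token
            # by get_azure_ad_token_from_oidc() at client-setup time
            "azure_ad_token": "oidc/my-provider/",                  # placeholder OIDC path
        },
    }
]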
@@ -2131,6 +2270,7 @@ class Router:
                    cache_key = f"{model_id}_async_client"
                    _client = openai.AsyncAzureOpenAI(
                        api_key=api_key,
                        azure_ad_token=azure_ad_token,
                        base_url=api_base,
                        api_version=api_version,
                        timeout=timeout,
@@ -2155,6 +2295,7 @@ class Router:
                    cache_key = f"{model_id}_client"
                    _client = openai.AzureOpenAI(  # type: ignore
                        api_key=api_key,
                        azure_ad_token=azure_ad_token,
                        base_url=api_base,
                        api_version=api_version,
                        timeout=timeout,
@@ -2179,6 +2320,7 @@ class Router:
                    cache_key = f"{model_id}_stream_async_client"
                    _client = openai.AsyncAzureOpenAI(  # type: ignore
                        api_key=api_key,
                        azure_ad_token=azure_ad_token,
                        base_url=api_base,
                        api_version=api_version,
                        timeout=stream_timeout,
@@ -2203,6 +2345,7 @@ class Router:
                    cache_key = f"{model_id}_stream_client"
                    _client = openai.AzureOpenAI(  # type: ignore
                        api_key=api_key,
                        azure_ad_token=azure_ad_token,
                        base_url=api_base,
                        api_version=api_version,
                        timeout=stream_timeout,
@@ -2235,6 +2378,7 @@ class Router:
                        "api_key": api_key,
                        "azure_endpoint": api_base,
                        "api_version": api_version,
                        "azure_ad_token": azure_ad_token,
                    }
                    from litellm.llms.azure import select_azure_base_url_or_endpoint
@@ -2334,7 +2478,7 @@ class Router:
                ) # cache for 1 hr

            else:
                _api_key = api_key
                _api_key = api_key  # type: ignore
                if _api_key is not None and isinstance(_api_key, str):
                    # only show first 5 chars of api_key
                    _api_key = _api_key[:8] + "*" * 15
@@ -2562,23 +2706,25 @@ class Router:
        # init OpenAI, Azure clients
        self.set_client(model=deployment.to_json(exclude_none=True))

        # set region (if azure model)
        _auto_infer_region = os.environ.get("AUTO_INFER_REGION", False)
        if _auto_infer_region == True or _auto_infer_region == "True":
        # set region (if azure model) ## PREVIEW FEATURE ##
        if litellm.enable_preview_features == True:
            print("Auto inferring region") # noqa
            """
            Hiding behind a feature flag
            When there is a large amount of LLM deployments this makes startup times blow up
            """
            try:
                if "azure" in deployment.litellm_params.model:
                if (
                    "azure" in deployment.litellm_params.model
                    and deployment.litellm_params.region_name is None
                ):
                    region = litellm.utils.get_model_region(
                        litellm_params=deployment.litellm_params, mode=None
                    )

                    deployment.litellm_params.region_name = region
            except Exception as e:
                verbose_router_logger.error(
                verbose_router_logger.debug(
                    "Unable to get the region for azure model - {}, {}".format(
                        deployment.litellm_params.model, str(e)
                    )
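Since region auto-inference is now gated on litellm.enable_preview_features rather than the AUTO_INFER_REGION environment variable, callers who want it have to opt in before the Router is built. A minimal sketch, with a hypothetical Azure deployment as the placeholder entry:

import litellm
from litellm import Router

# opt in to the preview feature before constructing the Router,
# since region inference runs while deployments are being set up
litellm.enable_preview_features = True

router = Router(
    model_list=[
        {
            "model_name": "azure-gpt-35",                            # hypothetical entry
            "litellm_params": {
                "model": "azure/my-gpt-35-deployment",               # placeholder deployment
                "api_base": "https://my-endpoint.openai.azure.com",  # placeholder endpoint
                "api_key": "sk-placeholder",
                # region_name omitted on purpose - with the preview flag on,
                # the router tries to infer it for azure deployments
            },
        }
    ]
)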
@@ -2956,7 +3102,7 @@ class Router:
            ):
                # check if in allowed_model_region
                if (
                    _is_region_eu(model_region=_litellm_params["region_name"])
                    _is_region_eu(litellm_params=LiteLLM_Params(**_litellm_params))
                    == False
                ):
                    invalid_model_indices.append(idx)