Merge branch 'main' into litellm_bedrock_command_r_support

Krish Dholakia 2024-05-11 21:24:42 -07:00 committed by GitHub
commit 1d651c6049
82 changed files with 3661 additions and 605 deletions


@@ -48,6 +48,7 @@ from litellm.types.router import (
AlertingConfig,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
class Router:
@@ -605,6 +606,33 @@ class Router:
self.fail_calls[model_name] += 1
raise e
async def abatch_completion(
self, models: List[str], messages: List[Dict[str, str]], **kwargs
):
async def _async_completion_no_exceptions(
model: str, messages: List[Dict[str, str]], **kwargs
):
"""
Wrapper around self.acompletion that catches exceptions and returns them as the result instead of raising
"""
try:
return await self.acompletion(model=model, messages=messages, **kwargs)
except Exception as e:
return e
_tasks = []
for model in models:
# schedule each call; failures are captured and returned instead of raised
_tasks.append(
_async_completion_no_exceptions(
model=model, messages=messages, **kwargs
)
)
response = await asyncio.gather(*_tasks)
return response
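For context, a minimal usage sketch of the new abatch_completion helper (standalone illustration, not part of this diff; the model names, the Bedrock model id, and the assumption that provider credentials are set in the environment are mine):

import asyncio

from litellm import Router

# Hypothetical deployments; the Bedrock Command R model id is illustrative.
router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "command-r", "litellm_params": {"model": "bedrock/cohere.command-r-v1:0"}},
    ]
)

async def main():
    # Fan one prompt out across model groups; per-model failures come back as
    # exception objects in the result list instead of raising.
    responses = await router.abatch_completion(
        models=["gpt-3.5-turbo", "command-r"],
        messages=[{"role": "user", "content": "hello"}],
    )
    for model, resp in zip(["gpt-3.5-turbo", "command-r"], responses):
        if isinstance(resp, Exception):
            print(f"{model} failed: {resp}")
        else:
            print(f"{model}: {resp.choices[0].message.content}")

asyncio.run(main())

Because failures are returned in place, one bad deployment does not sink the whole batch.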
def image_generation(self, prompt: str, model: str, **kwargs):
try:
kwargs["model"] = model
@@ -1480,26 +1508,29 @@ class Router:
except Exception as e:
original_exception = e
"""
- Check if available deployments - `get_healthy_deployments() -> List`
- if not, check if available fallbacks - `is_fallback(model_group: str, exception) -> bool`
- if not, back off and retry up to num_retries - `_router_should_retry -> float`
Retry Logic
"""
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
if (
isinstance(original_exception, litellm.ContextWindowExceededError)
and context_window_fallbacks is not None
) or (
isinstance(original_exception, openai.RateLimitError)
and fallbacks is not None
):
raise original_exception
### RETRY
_healthy_deployments = await self._async_get_healthy_deployments(
model=kwargs.get("model"),
)
_timeout = self._router_should_retry(
# raises an exception if this error should not be retried
self.should_retry_this_error(
error=e,
healthy_deployments=_healthy_deployments,
context_window_fallbacks=context_window_fallbacks,
)
# decides how long to sleep before retry
_timeout = self._time_to_sleep_before_retry(
e=original_exception,
remaining_retries=num_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
)
# sleeps for the length of the timeout
await asyncio.sleep(_timeout)
if (
@@ -1533,10 +1564,14 @@ class Router:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt
_timeout = self._router_should_retry(
_healthy_deployments = await self._async_get_healthy_deployments(
model=kwargs.get("model"),
)
_timeout = self._time_to_sleep_before_retry(
e=original_exception,
remaining_retries=remaining_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
)
await asyncio.sleep(_timeout)
try:
@@ -1545,6 +1580,40 @@ class Router:
pass
raise original_exception
def should_retry_this_error(
self,
error: Exception,
healthy_deployments: Optional[List] = None,
context_window_fallbacks: Optional[List] = None,
):
"""
1. raise the error for ContextWindowExceededError if no context_window_fallbacks are set
2. raise the error for RateLimitError / AuthenticationError if
- there are no other healthy deployments left in the same model group
"""
_num_healthy_deployments = 0
if healthy_deployments is not None and isinstance(healthy_deployments, list):
_num_healthy_deployments = len(healthy_deployments)
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
if (
isinstance(error, litellm.ContextWindowExceededError)
and context_window_fallbacks is None
):
raise error
# errors we should only retry if there are other healthy deployments to route to
if isinstance(error, openai.RateLimitError) or isinstance(
error, openai.AuthenticationError
):
if _num_healthy_deployments <= 0:
raise error
return True
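A standalone sketch of how the new should_retry_this_error gate behaves (illustration only; the placeholder api_key and the synthetic 429 are my assumptions, and no live request is made):

import httpx
import openai

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},  # placeholder key
        }
    ]
)

# Build a synthetic 429 so the gate can be exercised without calling a provider.
_request = httpx.Request("POST", "https://example.invalid/v1/chat/completions")
rate_limit_err = openai.RateLimitError(
    "rate limited", response=httpx.Response(429, request=_request), body=None
)

# No healthy deployments left -> the error is re-raised instead of retried.
try:
    router.should_retry_this_error(
        error=rate_limit_err, healthy_deployments=[], context_window_fallbacks=None
    )
except openai.RateLimitError:
    print("not retried: no healthy deployments remain")

# At least one healthy deployment in the group -> the same error is retryable.
assert router.should_retry_this_error(
    error=rate_limit_err, healthy_deployments=[{"model_info": {"id": "a"}}]
)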
def function_with_fallbacks(self, *args, **kwargs):
"""
Try calling the function_with_retries
@@ -1633,12 +1702,27 @@ class Router:
raise e
raise original_exception
def _router_should_retry(
self, e: Exception, remaining_retries: int, num_retries: int
def _time_to_sleep_before_retry(
self,
e: Exception,
remaining_retries: int,
num_retries: int,
healthy_deployments: Optional[List] = None,
) -> Union[int, float]:
"""
Calculate how long to sleep before the next retry.
Retry immediately (return 0) only when:
1. there are healthy deployments in the same model group
2. there are fallbacks for the completion call
"""
if (
healthy_deployments is not None
and isinstance(healthy_deployments, list)
and len(healthy_deployments) > 0
):
return 0
if hasattr(e, "response") and hasattr(e.response, "headers"):
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
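For intuition, the sleep time follows the usual retry-after / exponential back-off shape. A generic sketch of that idea (my own illustration under stated assumptions, not litellm's actual _calculate_retry_after):

import random
from typing import Optional

def sketch_time_to_sleep(remaining_retries: int, num_retries: int, retry_after: Optional[float] = None) -> float:
    # Honor an explicit Retry-After value when the provider sends one.
    if retry_after is not None:
        return retry_after
    # Otherwise back off exponentially across attempts, capped, with a little jitter.
    attempt = num_retries - remaining_retries
    return min(2 ** attempt, 60) + random.random()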
@@ -1675,23 +1759,29 @@ class Router:
except Exception as e:
original_exception = e
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
if (
isinstance(original_exception, litellm.ContextWindowExceededError)
and context_window_fallbacks is not None
) or (
isinstance(original_exception, openai.RateLimitError)
and fallbacks is not None
):
raise original_exception
## LOGGING
if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
### RETRY
_timeout = self._router_should_retry(
_healthy_deployments = self._get_healthy_deployments(
model=kwargs.get("model"),
)
# raises an exception if this error should not be retried
self.should_retry_this_error(
error=e,
healthy_deployments=_healthy_deployments,
context_window_fallbacks=context_window_fallbacks,
)
# decides how long to sleep before retry
_timeout = self._time_to_sleep_before_retry(
e=original_exception,
remaining_retries=num_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
)
## LOGGING
if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
time.sleep(_timeout)
for current_attempt in range(num_retries):
verbose_router_logger.debug(
@@ -1705,11 +1795,15 @@ class Router:
except Exception as e:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e)
_healthy_deployments = self._get_healthy_deployments(
model=kwargs.get("model"),
)
remaining_retries = num_retries - current_attempt
_timeout = self._router_should_retry(
_timeout = self._time_to_sleep_before_retry(
e=e,
remaining_retries=remaining_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
)
time.sleep(_timeout)
raise original_exception
@@ -1912,6 +2006,47 @@ class Router:
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
def _get_healthy_deployments(self, model: str):
_all_deployments: list = []
try:
_, _all_deployments = self._common_checks_available_deployment( # type: ignore
model=model,
)
if type(_all_deployments) == dict:
return []
except:
pass
unhealthy_deployments = self._get_cooldown_deployments()
healthy_deployments: list = []
for deployment in _all_deployments:
if deployment["model_info"]["id"] in unhealthy_deployments:
continue
else:
healthy_deployments.append(deployment)
return healthy_deployments
async def _async_get_healthy_deployments(self, model: str):
_all_deployments: list = []
try:
_, _all_deployments = self._common_checks_available_deployment( # type: ignore
model=model,
)
if type(_all_deployments) == dict:
return []
except:
pass
unhealthy_deployments = await self._async_get_cooldown_deployments()
healthy_deployments: list = []
for deployment in _all_deployments:
if deployment["model_info"]["id"] in unhealthy_deployments:
continue
else:
healthy_deployments.append(deployment)
return healthy_deployments
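Conceptually, both helpers reduce to the same filter: a deployment is "healthy" when its id is not currently in the cooldown list. A tiny standalone illustration (the sample ids are made up):

all_deployments = [{"model_info": {"id": "a"}}, {"model_info": {"id": "b"}}]
cooldown_ids = ["b"]

# Keep only deployments that are not cooling down.
healthy = [d for d in all_deployments if d["model_info"]["id"] not in cooldown_ids]
assert [d["model_info"]["id"] for d in healthy] == ["a"]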
def routing_strategy_pre_call_checks(self, deployment: dict):
"""
Mimics 'async_routing_strategy_pre_call_checks'
@@ -2120,6 +2255,10 @@ class Router:
raise ValueError(
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
)
azure_ad_token = litellm_params.get("azure_ad_token")
if azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
if api_version is None:
api_version = "2023-07-01-preview"
@@ -2131,6 +2270,7 @@ class Router:
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=timeout,
@@ -2155,6 +2295,7 @@ class Router:
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=timeout,
@@ -2179,6 +2320,7 @@ class Router:
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
@@ -2203,6 +2345,7 @@ class Router:
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
@@ -2235,6 +2378,7 @@ class Router:
"api_key": api_key,
"azure_endpoint": api_base,
"api_version": api_version,
"azure_ad_token": azure_ad_token,
}
from litellm.llms.azure import select_azure_base_url_or_endpoint
@@ -2334,7 +2478,7 @@ class Router:
) # cache for 1 hr
else:
_api_key = api_key
_api_key = api_key # type: ignore
if _api_key is not None and isinstance(_api_key, str):
# only show first 5 chars of api_key
_api_key = _api_key[:8] + "*" * 15
@@ -2562,23 +2706,25 @@ class Router:
# init OpenAI, Azure clients
self.set_client(model=deployment.to_json(exclude_none=True))
# set region (if azure model)
_auto_infer_region = os.environ.get("AUTO_INFER_REGION", False)
if _auto_infer_region == True or _auto_infer_region == "True":
# set region (if azure model) ## PREVIEW FEATURE ##
if litellm.enable_preview_features == True:
print("Auto inferring region") # noqa
"""
Hidden behind a feature flag:
when there is a large number of LLM deployments, this lookup blows up startup times
"""
try:
if "azure" in deployment.litellm_params.model:
if (
"azure" in deployment.litellm_params.model
and deployment.litellm_params.region_name is None
):
region = litellm.utils.get_model_region(
litellm_params=deployment.litellm_params, mode=None
)
deployment.litellm_params.region_name = region
except Exception as e:
verbose_router_logger.error(
verbose_router_logger.debug(
"Unable to get the region for azure model - {}, {}".format(
deployment.litellm_params.model, str(e)
)
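To opt back into region auto-inference, a minimal sketch using the new flag (the flag name comes from this diff; the deployment config is illustrative):

import litellm
from litellm import Router

# Opt in to preview features so Azure deployments without an explicit region_name
# get their region inferred at Router init (previously gated by AUTO_INFER_REGION).
litellm.enable_preview_features = True

router = Router(
    model_list=[
        {
            "model_name": "azure-gpt",  # hypothetical
            "litellm_params": {
                "model": "azure/my-gpt-deployment",  # hypothetical
                "api_base": "https://my-endpoint.openai.azure.com",
                "api_key": "azure-api-key-placeholder",  # placeholder
            },
        }
    ]
)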
@@ -2956,7 +3102,7 @@ class Router:
):
# check if in allowed_model_region
if (
_is_region_eu(model_region=_litellm_params["region_name"])
_is_region_eu(litellm_params=LiteLLM_Params(**_litellm_params))
== False
):
invalid_model_indices.append(idx)