Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00
(testing) Router add testing coverage (#6253)
* test: add more router code coverage
* test: additional router testing coverage
* fix: fix linting error
* test: fix tests for ci/cd
* test: fix test
* test: handle flaky tests

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
parent b72a47d092
commit dee6de0105
7 changed files with 706 additions and 106 deletions
```diff
@@ -861,11 +861,23 @@ class Router:
                 self.fail_calls[model_name] += 1
             raise e
 
+    def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
+        """
+        Adds default litellm params to kwargs, if set.
+        """
+        for k, v in self.default_litellm_params.items():
+            if (
+                k not in kwargs and v is not None
+            ):  # prioritize model-specific params > default router params
+                kwargs[k] = v
+            elif k == "metadata":
+                kwargs[k].update(v)
+
     def _update_kwargs_with_deployment(self, deployment: dict, kwargs: dict) -> None:
         """
-        Adds selected deployment, model_info and api_base to kwargs["metadata"]
-
-        This is used in litellm logging callbacks
+        2 jobs:
+        - Adds selected deployment, model_info and api_base to kwargs["metadata"] (used for logging)
+        - Adds default litellm params to kwargs, if set.
         """
         kwargs.setdefault("metadata", {}).update(
             {
```
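
The extracted helper centralizes one merge rule: per-call kwargs win over router-wide defaults, except `metadata`, which is merged rather than replaced. A minimal standalone sketch of that rule (plain dicts, no litellm imports; the values are illustrative):

```python
def apply_defaults(kwargs: dict, defaults: dict) -> None:
    """Mimics _update_kwargs_with_default_litellm_params: call-site kwargs
    take priority over router defaults; 'metadata' dicts are merged."""
    for k, v in defaults.items():
        if k not in kwargs and v is not None:
            kwargs[k] = v  # fill in a missing default
        elif k == "metadata":
            kwargs[k].update(v)  # merge, don't overwrite

call_kwargs = {"temperature": 0.2, "metadata": {"request_id": "abc"}}
router_defaults = {"temperature": 0.7, "timeout": 30, "metadata": {"team": "search"}}
apply_defaults(call_kwargs, router_defaults)
assert call_kwargs["temperature"] == 0.2  # call-site value kept
assert call_kwargs["timeout"] == 30  # default filled in
assert call_kwargs["metadata"] == {"request_id": "abc", "team": "search"}
```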
```diff
@@ -875,13 +887,7 @@ class Router:
             }
         )
         kwargs["model_info"] = deployment.get("model_info", {})
-        for k, v in self.default_litellm_params.items():
-            if (
-                k not in kwargs and v is not None
-            ):  # prioritize model-specific params > default router params
-                kwargs[k] = v
-            elif k == "metadata":
-                kwargs[k].update(v)
+        self._update_kwargs_with_default_litellm_params(kwargs=kwargs)
 
     def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
         """
```
```diff
@@ -910,6 +916,7 @@ class Router:
         return model_client
 
     def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]:
+        """Helper to get timeout from kwargs or deployment params"""
         timeout = (
             data.get(
                 "timeout", None
```
```diff
@@ -3414,11 +3421,10 @@ class Router:
     ):
         """
         Track remaining tpm/rpm quota for model in model_list
+
+        Currently, only updates TPM usage.
         """
         try:
-            """
-            Update TPM usage on success
-            """
             if kwargs["litellm_params"].get("metadata") is None:
                 pass
             else:
```
```diff
@@ -3459,6 +3465,8 @@ class Router:
                     deployment_id=id,
                 )
 
+            return tpm_key
+
         except Exception as e:
             verbose_router_logger.exception(
                 "litellm.router.Router::deployment_callback_on_success(): Exception occured - {}".format(
```
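
Returning `tpm_key` instead of swallowing it makes the side effect observable, which is exactly what the added test coverage needs. A standalone sketch of per-minute TPM bucketing with an observable key (the key shape is an assumption for illustration, not litellm's actual format):

```python
import time
from collections import defaultdict

_tpm: defaultdict[str, int] = defaultdict(int)  # stand-in for the router cache

def track_tokens(model_group: str, deployment_id: str, total_tokens: int) -> str:
    """Increment this minute's token bucket and return the key that was touched."""
    minute = time.strftime("%H-%M")  # one bucket per wall-clock minute
    tpm_key = f"{model_group}:{deployment_id}:tpm:{minute}"  # assumed key shape
    _tpm[tpm_key] += total_tokens
    return tpm_key  # callers and tests can assert on exactly this entry

key = track_tokens("gpt-4o", "deployment-1", total_tokens=512)
assert _tpm[key] == 512
```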
```diff
@@ -3473,7 +3481,14 @@ class Router:
         completion_response,  # response from completion
         start_time,
         end_time,  # start/end time
-    ):
+    ) -> Optional[str]:
+        """
+        Tracks the number of successes for a deployment in the current minute (using in-memory cache)
+
+        Returns:
+        - key: str - The key used to increment the cache
+        - None: if no key is found
+        """
         id = None
         if kwargs["litellm_params"].get("metadata") is None:
             pass
```
```diff
@@ -3482,15 +3497,18 @@ class Router:
             model_info = kwargs["litellm_params"].get("model_info", {}) or {}
             id = model_info.get("id", None)
             if model_group is None or id is None:
-                return
+                return None
             elif isinstance(id, int):
                 id = str(id)
 
             if id is not None:
-                increment_deployment_successes_for_current_minute(
+                key = increment_deployment_successes_for_current_minute(
                     litellm_router_instance=self,
                     deployment_id=id,
                 )
+                return key
+
+        return None
 
     def deployment_callback_on_failure(
         self,
```
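
The sync success callback now has an explicit `Optional[str]` contract: the cache key when a success was tracked, `None` on every early exit. A small sketch of that contract (a hypothetical helper with an assumed key shape):

```python
from typing import Optional

def success_key(metadata: Optional[dict], model_info: dict) -> Optional[str]:
    """Explicit Optional contract: return the cache key when a success was
    tracked, None when a required identifier is missing."""
    if metadata is None:
        return None
    model_group = metadata.get("model_group")
    deployment_id = model_info.get("id")
    if model_group is None or deployment_id is None:
        return None
    return f"{deployment_id}:successes"  # assumed key shape, for illustration

assert success_key(None, {"id": "d1"}) is None
assert success_key({"model_group": "gpt-4o"}, {}) is None
assert success_key({"model_group": "gpt-4o"}, {"id": "d1"}) == "d1:successes"
```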
```diff
@@ -3498,15 +3516,19 @@ class Router:
         completion_response,  # response from completion
         start_time,
         end_time,  # start/end time
-    ):
+    ) -> bool:
+        """
+        2 jobs:
+        - Tracks the number of failures for a deployment in the current minute (using in-memory cache)
+        - Puts the deployment in cooldown if it exceeds the allowed fails / minute
+
+        Returns:
+        - True if the deployment should be put in cooldown
+        - False if the deployment should not be put in cooldown
+        """
         try:
             exception = kwargs.get("exception", None)
             exception_status = getattr(exception, "status_code", "")
             model_name = kwargs.get("model", None)  # i.e. gpt35turbo
             custom_llm_provider = kwargs.get("litellm_params", {}).get(
                 "custom_llm_provider", None
             )  # i.e. azure
             kwargs.get("litellm_params", {}).get("metadata", None)
             _model_info = kwargs.get("litellm_params", {}).get("model_info", {})
 
             exception_headers = litellm.litellm_core_utils.exception_mapping_utils._get_response_headers(
```
```diff
@@ -3535,15 +3557,17 @@ class Router:
                     litellm_router_instance=self,
                     deployment_id=deployment_id,
                 )
-                _set_cooldown_deployments(
+                result = _set_cooldown_deployments(
                     litellm_router_instance=self,
                     exception_status=exception_status,
                     original_exception=exception,
                     deployment=deployment_id,
                     time_to_cooldown=_time_to_cooldown,
                 )  # setting deployment_id in cooldown deployments
-                if custom_llm_provider:
-                    model_name = f"{custom_llm_provider}/{model_name}"
+
+                return result
+            else:
+                return False
 
         except Exception as e:
             raise e
```
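
The failure callback now reports its decision instead of returning nothing. A standalone sketch of the fails-per-minute cooldown contract the new docstring describes (the threshold and key shape are illustrative, not litellm's defaults):

```python
import time
from collections import defaultdict

ALLOWED_FAILS_PER_MINUTE = 3  # illustrative threshold, not litellm's default

_fail_counts: defaultdict[str, int] = defaultdict(int)

def on_failure(deployment_id: str) -> bool:
    """Mimics the callback contract: track the failure, return True
    iff the deployment should be put in cooldown."""
    minute = time.strftime("%H-%M")
    key = f"{deployment_id}:fails:{minute}"  # assumed key shape
    _fail_counts[key] += 1
    return _fail_counts[key] > ALLOWED_FAILS_PER_MINUTE

for _ in range(3):
    assert on_failure("deployment-1") is False  # under the limit
assert on_failure("deployment-1") is True  # 4th failure trips cooldown
```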
```diff
@@ -3582,9 +3606,12 @@ class Router:
         except Exception as e:
             raise e
 
-    def _update_usage(self, deployment_id: str):
+    def _update_usage(self, deployment_id: str) -> int:
         """
         Update deployment rpm for that minute
+
+        Returns:
+        - int: request count
         """
         rpm_key = deployment_id
 
```
```diff
@@ -3600,6 +3627,8 @@ class Router:
             key=rpm_key, value=request_count, local_only=True
         )  # don't change existing ttl
 
+        return request_count
+
     def _is_cooldown_required(
         self,
         model_id: str,
```
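
`_update_usage` now returns the per-minute request count it just wrote. A minimal sketch of the underlying pattern, including the "don't change existing ttl" behavior the comment points at (a plain dict stands in for the router cache):

```python
import time

# key -> (count, expiry_timestamp): increment within a window without
# resetting that window's TTL.
_store: dict[str, tuple[int, float]] = {}

def update_usage(deployment_id: str, ttl: float = 60.0) -> int:
    now = time.time()
    count, expiry = _store.get(deployment_id, (0, now + ttl))
    if now >= expiry:  # window elapsed: start a fresh one
        count, expiry = 0, now + ttl
    count += 1
    _store[deployment_id] = (count, expiry)  # expiry preserved within a window
    return count  # caller gets this minute's request count

assert update_usage("d1") == 1
assert update_usage("d1") == 2
```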
```diff
@@ -3778,7 +3807,7 @@ class Router:
             for _callback in litellm.callbacks:
                 if isinstance(_callback, CustomLogger):
                     try:
-                        _ = await _callback.async_pre_call_check(deployment)
+                        await _callback.async_pre_call_check(deployment)
                     except litellm.RateLimitError as e:
                         ## LOG FAILURE EVENT
                         if logging_obj is not None:
```
```diff
@@ -3848,10 +3877,23 @@ class Router:
         return hash_object.hexdigest()
 
     def _create_deployment(
-        self, model: dict, _model_name: str, _litellm_params: dict, _model_info: dict
-    ):
+        self,
+        deployment_info: dict,
+        _model_name: str,
+        _litellm_params: dict,
+        _model_info: dict,
+    ) -> Optional[Deployment]:
+        """
+        Create a deployment object and add it to the model list
+
+        If the deployment is not active for the current environment, it is ignored
+
+        Returns:
+        - Deployment: The deployment object
+        - None: If the deployment is not active for the current environment (if 'supported_environments' is set in litellm_params)
+        """
         deployment = Deployment(
-            **model,
+            **deployment_info,
             model_name=_model_name,
             litellm_params=LiteLLM_Params(**_litellm_params),
             model_info=_model_info,
```
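
With `Optional[Deployment]` as the return type, callers can distinguish "added" from "skipped for this environment". A sketch of that pattern with a stand-in `Deployment` dataclass (not litellm's actual class):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Deployment:  # stand-in for litellm's Deployment type
    model_name: str
    supported_environments: Optional[list] = None

def create_deployment(info: dict, current_env: str) -> Optional[Deployment]:
    envs = info.get("supported_environments")
    if envs is not None and current_env not in envs:
        return None  # inactive for this environment: skip, and say so
    return Deployment(model_name=info["model_name"], supported_environments=envs)

deployments = [
    {"model_name": "gpt-4o", "supported_environments": ["production"]},
    {"model_name": "gpt-4o-mini"},  # no restriction: active everywhere
]
active = [d for info in deployments if (d := create_deployment(info, "development"))]
assert [d.model_name for d in active] == ["gpt-4o-mini"]
```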
```diff
@@ -3870,18 +3912,18 @@ class Router:
         )
 
         ## Check if LLM Deployment is allowed for this deployment
-        if deployment.model_info and "supported_environments" in deployment.model_info:
-            if (
-                self.deployment_is_active_for_environment(deployment=deployment)
-                is not True
-            ):
-                return
+        if self.deployment_is_active_for_environment(deployment=deployment) is not True:
+            verbose_router_logger.warning(
+                f"Ignoring deployment {deployment.model_name} as it is not active for environment {deployment.model_info['supported_environments']}"
+            )
+            return None
 
         deployment = self._add_deployment(deployment=deployment)
 
         model = deployment.to_json(exclude_none=True)
 
         self.model_list.append(model)
+        return deployment
 
     def deployment_is_active_for_environment(self, deployment: Deployment) -> bool:
         """
```
```diff
@@ -3896,6 +3938,12 @@ class Router:
         - ValueError: If LITELLM_ENVIRONMENT is not set in .env or not one of the valid values
         - ValueError: If supported_environments is not set in model_info or not one of the valid values
         """
+        if (
+            deployment.model_info is None
+            or "supported_environments" not in deployment.model_info
+            or deployment.model_info["supported_environments"] is None
+        ):
+            return True
         litellm_environment = get_secret_str(secret_name="LITELLM_ENVIRONMENT")
         if litellm_environment is None:
             raise ValueError(
```
```diff
@@ -3913,7 +3961,6 @@ class Router:
                     f"supported_environments must be one of {VALID_LITELLM_ENVIRONMENTS}. but set as: {_env} for deployment: {deployment}"
                 )
 
-        # validate litellm_environment is one of LiteLLMEnvironment
         if litellm_environment in deployment.model_info["supported_environments"]:
             return True
         return False
```
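
The added guard makes the common case explicit: no `supported_environments` in `model_info` means the deployment is active everywhere, and only restricted deployments require `LITELLM_ENVIRONMENT` to be set. A standalone sketch of the whole check (the environment values are assumed; the diff only shows that a `VALID_LITELLM_ENVIRONMENTS` constant exists):

```python
import os
from typing import Optional

VALID_LITELLM_ENVIRONMENTS = [  # assumed values, mirroring the diff's constant
    "development", "staging", "production",
]

def is_active_for_environment(model_info: Optional[dict]) -> bool:
    """No 'supported_environments' means active everywhere; otherwise
    LITELLM_ENVIRONMENT must be set and appear in the deployment's list."""
    if not model_info or model_info.get("supported_environments") is None:
        return True  # the new early-return: unrestricted deployment
    env = os.environ.get("LITELLM_ENVIRONMENT")
    if env is None:
        raise ValueError("LITELLM_ENVIRONMENT is not set in .env")
    if env not in VALID_LITELLM_ENVIRONMENTS:
        raise ValueError(f"LITELLM_ENVIRONMENT must be one of {VALID_LITELLM_ENVIRONMENTS}")
    return env in model_info["supported_environments"]

assert is_active_for_environment(None) is True
os.environ["LITELLM_ENVIRONMENT"] = "staging"
assert is_active_for_environment({"supported_environments": ["production"]}) is False
```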
```diff
@@ -3946,14 +3993,14 @@ class Router:
                 for org in _litellm_params["organization"]:
                     _litellm_params["organization"] = org
                     self._create_deployment(
-                        model=model,
+                        deployment_info=model,
                         _model_name=_model_name,
                         _litellm_params=_litellm_params,
                         _model_info=_model_info,
                     )
             else:
                 self._create_deployment(
-                    model=model,
+                    deployment_info=model,
                     _model_name=_model_name,
                     _litellm_params=_litellm_params,
                     _model_info=_model_info,
```
```diff
@@ -4118,9 +4165,9 @@ class Router:
 
         if removal_idx is not None:
             self.model_list.pop(removal_idx)
-        else:
-            # if the model_id is not in router
-            self.add_deployment(deployment=deployment)
+
+        # if the model_id is not in router
+        self.add_deployment(deployment=deployment)
         return deployment
 
     def delete_deployment(self, id: str) -> Optional[Deployment]:
```
```diff
@@ -4628,16 +4675,13 @@ class Router:
         from collections import defaultdict
 
         access_groups = defaultdict(list)
-        if self.access_groups:
-            return self.access_groups
-
         if self.model_list:
             for m in self.model_list:
                 for group in m.get("model_info", {}).get("access_groups", []):
                     model_name = m["model_name"]
                     access_groups[group].append(model_name)
-            # set access groups
-            self.access_groups = access_groups
+
         return access_groups
 
     def get_settings(self):
```
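
The loop builds a group-to-models index; dropping the `self.access_groups` cache means it is now recomputed on every call, so tests see fresh state. The index itself, using the diff's own loop with illustrative model names:

```python
from collections import defaultdict

model_list = [
    {"model_name": "gpt-4o", "model_info": {"access_groups": ["beta", "internal"]}},
    {"model_name": "gpt-4o-mini", "model_info": {"access_groups": ["beta"]}},
    {"model_name": "claude-3", "model_info": {}},  # no groups: not indexed
]

# group name -> list of model names, exactly as the method builds it
access_groups = defaultdict(list)
for m in model_list:
    for group in m.get("model_info", {}).get("access_groups", []):
        access_groups[group].append(m["model_name"])

assert access_groups["beta"] == ["gpt-4o", "gpt-4o-mini"]
assert access_groups["internal"] == ["gpt-4o"]
```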
```diff
@@ -4672,6 +4716,9 @@ class Router:
         return _settings_to_return
 
     def update_settings(self, **kwargs):
+        """
+        Update the router settings.
+        """
         # only the following settings are allowed to be configured
         _allowed_settings = [
             "routing_strategy_args",
```
```diff
@@ -5367,66 +5414,16 @@ class Router:
         return healthy_deployments
 
     def _track_deployment_metrics(self, deployment, response=None):
+        """
+        Tracks successful requests rpm usage.
+        """
         try:
-            litellm_params = deployment["litellm_params"]
-            api_base = litellm_params.get("api_base", "")
-            model = litellm_params.get("model", "")
-
             model_id = deployment.get("model_info", {}).get("id", None)
             if response is None:
 
                 # update self.deployment_stats
                 if model_id is not None:
                     self._update_usage(model_id)  # update in-memory cache for tracking
-                    if model_id in self.deployment_stats:
-                        # only update num_requests
-                        self.deployment_stats[model_id]["num_requests"] += 1
-                    else:
-                        self.deployment_stats[model_id] = {
-                            "api_base": api_base,
-                            "model": model,
-                            "num_requests": 1,
-                        }
-            else:
-                # check response_ms and update num_successes
-                if isinstance(response, dict):
-                    response_ms = response.get("_response_ms", 0)
-                else:
-                    response_ms = 0
-                if model_id is not None:
-                    if model_id in self.deployment_stats:
-                        # check if avg_latency exists
-                        if "avg_latency" in self.deployment_stats[model_id]:
-                            # update avg_latency
-                            self.deployment_stats[model_id]["avg_latency"] = (
-                                self.deployment_stats[model_id]["avg_latency"]
-                                + response_ms
-                            ) / self.deployment_stats[model_id]["num_successes"]
-                        else:
-                            self.deployment_stats[model_id]["avg_latency"] = response_ms
-
-                        # check if num_successes exists
-                        if "num_successes" in self.deployment_stats[model_id]:
-                            self.deployment_stats[model_id]["num_successes"] += 1
-                        else:
-                            self.deployment_stats[model_id]["num_successes"] = 1
-                    else:
-                        self.deployment_stats[model_id] = {
-                            "api_base": api_base,
-                            "model": model,
-                            "num_successes": 1,
-                            "avg_latency": response_ms,
-                        }
-            if self.set_verbose is True and self.debug_level == "DEBUG":
-                from pprint import pformat
-
-                # Assuming self.deployment_stats is your dictionary
-                formatted_stats = pformat(self.deployment_stats)
-
-                # Assuming verbose_router_logger is your logger
-                verbose_router_logger.info(
-                    "self.deployment_stats: \n%s", formatted_stats
-                )
         except Exception as e:
             verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}")
```
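
One reason deleting the stats block is defensible: its average update, `(avg + response_ms) / num_successes`, is not a running mean and drifts from the true average as samples accumulate. For comparison, a standard incremental mean:

```python
def update_mean(mean: float, n: int, x: float) -> tuple[float, int]:
    """Welford-style single-value update: returns (new_mean, new_count)."""
    n += 1
    mean += (x - mean) / n  # pull the mean toward the new sample by 1/n
    return mean, n

mean, n = 0.0, 0
for latency_ms in (100.0, 200.0, 300.0):
    mean, n = update_mean(mean, n, latency_ms)
assert mean == 200.0 and n == 3
```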
```diff
@@ -5442,6 +5439,7 @@ class Router:
         """
         # if we can find the exception then in the retry policy -> return the number of retries
+        retry_policy: Optional[RetryPolicy] = self.retry_policy
 
         if (
             self.model_group_retry_policy is not None
             and model_group is not None
```
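
The annotated `retry_policy` line documents the lookup order: a model-group policy, when present, overrides the router-wide one, and the exception type selects the retry count. A sketch of that precedence (policy contents and field names are illustrative):

```python
from typing import Optional

class RateLimitError(Exception): ...

router_policy = {"RateLimitError": 4}  # router-wide default
model_group_policy = {"my-group": {"RateLimitError": 0}}  # per-group override

def num_retries(exc: Exception, model_group: Optional[str]) -> Optional[int]:
    """Model-group policy wins when one exists; else fall back to the router's."""
    if model_group is not None and model_group in model_group_policy:
        policy = model_group_policy[model_group]
    else:
        policy = router_policy
    return policy.get(type(exc).__name__)  # None when the exception isn't covered

assert num_retries(RateLimitError(), None) == 4
assert num_retries(RateLimitError(), "my-group") == 0
```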
```diff
@@ -5540,7 +5538,9 @@ class Router:
             litellm.success_callback.append(
                 _slack_alerting_logger.response_taking_too_long_callback
             )
-        print("\033[94m\nInitialized Alerting for litellm.Router\033[0m\n")  # noqa
+        verbose_router_logger.info(
+            "\033[94m\nInitialized Alerting for litellm.Router\033[0m\n"
+        )
 
     def set_custom_routing_strategy(
         self, CustomRoutingStrategy: CustomRoutingStrategyBase
```