(testing) Router add testing coverage (#6253)

* test: add more router code coverage

* test: additional router testing coverage

* fix: fix linting error

* test: fix tests for ci/cd

* test: fix test

* test: handle flaky tests

---------

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
Ishaan Jaff 2024-10-16 20:02:27 +05:30 committed by GitHub
parent b72a47d092
commit dee6de0105
7 changed files with 706 additions and 106 deletions


@@ -861,11 +861,23 @@ class Router:
self.fail_calls[model_name] += 1
raise e
def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
"""
Adds default litellm params to kwargs, if set.
"""
for k, v in self.default_litellm_params.items():
if (
k not in kwargs and v is not None
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
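For illustration, a minimal standalone sketch of the merge rule above (the dict values are made up, not real Router state): per-call kwargs take priority over the router-level defaults, while "metadata" is merged in place rather than replaced.

    # Hypothetical values; mirrors the precedence rule in the method above.
    default_litellm_params = {"request_timeout": 600, "metadata": {"team": "infra"}}
    kwargs = {"request_timeout": 30, "metadata": {"trace_id": "abc"}}

    for k, v in default_litellm_params.items():
        if k not in kwargs and v is not None:
            kwargs[k] = v          # default only fills a missing key
        elif k == "metadata":
            kwargs[k].update(v)    # metadata dicts are merged in place

    print(kwargs)
    # {'request_timeout': 30, 'metadata': {'trace_id': 'abc', 'team': 'infra'}}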
def _update_kwargs_with_deployment(self, deployment: dict, kwargs: dict) -> None:
"""
Adds selected deployment, model_info and api_base to kwargs["metadata"]
This is used in litellm logging callbacks
2 jobs:
- Adds selected deployment, model_info and api_base to kwargs["metadata"] (used for logging)
- Adds default litellm params to kwargs, if set.
"""
kwargs.setdefault("metadata", {}).update(
{
@@ -875,13 +887,7 @@ class Router:
}
)
kwargs["model_info"] = deployment.get("model_info", {})
for k, v in self.default_litellm_params.items():
if (
k not in kwargs and v is not None
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
self._update_kwargs_with_default_litellm_params(kwargs=kwargs)
def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
"""
@@ -910,6 +916,7 @@ class Router:
return model_client
def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]:
"""Helper to get timeout from kwargs or deployment params"""
timeout = (
data.get(
"timeout", None
@@ -3414,11 +3421,10 @@ class Router:
):
"""
Track remaining tpm/rpm quota for model in model_list
Currently, only updates TPM usage.
"""
try:
"""
Update TPM usage on success
"""
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
@@ -3459,6 +3465,8 @@ class Router:
deployment_id=id,
)
return tpm_key
except Exception as e:
verbose_router_logger.exception(
"litellm.router.Router::deployment_callback_on_success(): Exception occured - {}".format(
@@ -3473,7 +3481,14 @@ class Router:
completion_response, # response from completion
start_time,
end_time, # start/end time
):
) -> Optional[str]:
"""
Tracks the number of successes for a deployment in the current minute (using in-memory cache)
Returns:
- key: str - The key used to increment the cache
- None: if no key is found
"""
id = None
if kwargs["litellm_params"].get("metadata") is None:
pass
@@ -3482,15 +3497,18 @@ class Router:
model_info = kwargs["litellm_params"].get("model_info", {}) or {}
id = model_info.get("id", None)
if model_group is None or id is None:
return
return None
elif isinstance(id, int):
id = str(id)
if id is not None:
increment_deployment_successes_for_current_minute(
key = increment_deployment_successes_for_current_minute(
litellm_router_instance=self,
deployment_id=id,
)
return key
return None
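For illustration only (this is not litellm's increment_deployment_successes_for_current_minute helper), a per-minute success counter keyed by deployment id could look like the sketch below; returning the key is what makes the new behaviour directly assertable in tests. The key format and storage here are assumptions.

    import datetime

    _success_counts: dict = {}

    def increment_successes(deployment_id: str) -> str:
        # One counter per deployment per minute (hypothetical key format).
        minute = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
        key = f"{deployment_id}:successes:{minute}"
        _success_counts[key] = _success_counts.get(key, 0) + 1
        return key

    key = increment_successes("model-123")
    print(key, _success_counts[key])   # e.g. model-123:successes:2024-10-16-20-02 1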
def deployment_callback_on_failure(
self,
@@ -3498,15 +3516,19 @@ class Router:
completion_response, # response from completion
start_time,
end_time, # start/end time
):
) -> bool:
"""
2 jobs:
- Tracks the number of failures for a deployment in the current minute (using in-memory cache)
- Puts the deployment in cooldown if it exceeds the allowed fails / minute
Returns:
- True if the deployment should be put in cooldown
- False if the deployment should not be put in cooldown
"""
try:
exception = kwargs.get("exception", None)
exception_status = getattr(exception, "status_code", "")
model_name = kwargs.get("model", None) # i.e. gpt35turbo
custom_llm_provider = kwargs.get("litellm_params", {}).get(
"custom_llm_provider", None
) # i.e. azure
kwargs.get("litellm_params", {}).get("metadata", None)
_model_info = kwargs.get("litellm_params", {}).get("model_info", {})
exception_headers = litellm.litellm_core_utils.exception_mapping_utils._get_response_headers(
@@ -3535,15 +3557,17 @@ class Router:
litellm_router_instance=self,
deployment_id=deployment_id,
)
_set_cooldown_deployments(
result = _set_cooldown_deployments(
litellm_router_instance=self,
exception_status=exception_status,
original_exception=exception,
deployment=deployment_id,
time_to_cooldown=_time_to_cooldown,
) # setting deployment_id in cooldown deployments
if custom_llm_provider:
model_name = f"{custom_llm_provider}/{model_name}"
return result
else:
return False
except Exception as e:
raise e
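Returning a bool makes the cooldown decision directly testable. A simplified sketch of that decision as a pure function (the threshold logic is illustrative, not litellm's _set_cooldown_deployments internals):

    def should_cooldown(fails_this_minute: int, allowed_fails_per_minute: int) -> bool:
        # Cooldown once the per-minute failure count exceeds the allowed budget.
        return fails_this_minute > allowed_fails_per_minute

    print(should_cooldown(fails_this_minute=6, allowed_fails_per_minute=5))   # True
    print(should_cooldown(fails_this_minute=2, allowed_fails_per_minute=5))   # False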
@@ -3582,9 +3606,12 @@ class Router:
except Exception as e:
raise e
def _update_usage(self, deployment_id: str):
def _update_usage(self, deployment_id: str) -> int:
"""
Update deployment rpm for that minute
Returns:
- int: request count
"""
rpm_key = deployment_id
@@ -3600,6 +3627,8 @@ class Router:
key=rpm_key, value=request_count, local_only=True
) # don't change existing ttl
return request_count
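A standalone sketch of the per-minute request counting that _update_usage performs, using a plain dict instead of the router's in-memory TTL cache (the cache and TTL handling are simplified assumptions):

    _usage: dict = {}

    def update_usage(deployment_id: str) -> int:
        # Increment and return the running request count for this deployment.
        request_count = _usage.get(deployment_id, 0) + 1
        _usage[deployment_id] = request_count   # the real cache keeps its existing TTL
        return request_count

    print(update_usage("model-123"))   # 1
    print(update_usage("model-123"))   # 2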
def _is_cooldown_required(
self,
model_id: str,
@@ -3778,7 +3807,7 @@ class Router:
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger):
try:
_ = await _callback.async_pre_call_check(deployment)
await _callback.async_pre_call_check(deployment)
except litellm.RateLimitError as e:
## LOG FAILURE EVENT
if logging_obj is not None:
@@ -3848,10 +3877,23 @@ class Router:
return hash_object.hexdigest()
def _create_deployment(
self, model: dict, _model_name: str, _litellm_params: dict, _model_info: dict
):
self,
deployment_info: dict,
_model_name: str,
_litellm_params: dict,
_model_info: dict,
) -> Optional[Deployment]:
"""
Create a deployment object and add it to the model list
If the deployment is not active for the current environment, it is ignored
Returns:
- Deployment: The deployment object
- None: If the deployment is not active for the current environment (if 'supported_environments' is set in model_info)
"""
deployment = Deployment(
**model,
**deployment_info,
model_name=_model_name,
litellm_params=LiteLLM_Params(**_litellm_params),
model_info=_model_info,
@@ -3870,18 +3912,18 @@ class Router:
)
## Check if LLM Deployment is allowed for this deployment
if deployment.model_info and "supported_environments" in deployment.model_info:
if (
self.deployment_is_active_for_environment(deployment=deployment)
is not True
):
return
if self.deployment_is_active_for_environment(deployment=deployment) is not True:
verbose_router_logger.warning(
f"Ignoring deployment {deployment.model_name} as it is not active for environment {deployment.model_info['supported_environments']}"
)
return None
deployment = self._add_deployment(deployment=deployment)
model = deployment.to_json(exclude_none=True)
self.model_list.append(model)
return deployment
def deployment_is_active_for_environment(self, deployment: Deployment) -> bool:
"""
@@ -3896,6 +3938,12 @@ class Router:
- ValueError: If LITELLM_ENVIRONMENT is not set in .env or not one of the valid values
- ValueError: If supported_environments is not set in model_info or not one of the valid values
"""
if (
deployment.model_info is None
or "supported_environments" not in deployment.model_info
or deployment.model_info["supported_environments"] is None
):
return True
litellm_environment = get_secret_str(secret_name="LITELLM_ENVIRONMENT")
if litellm_environment is None:
raise ValueError(
@@ -3913,7 +3961,6 @@ class Router:
f"supported_environments must be one of {VALID_LITELLM_ENVIRONMENTS}. but set as: {_env} for deployment: {deployment}"
)
# validate litellm_environment is one of LiteLLMEnvironment
if litellm_environment in deployment.model_info["supported_environments"]:
return True
return False
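A simplified, self-contained sketch of the environment gate described above (the model_info values are made up, and the ValueError handling for unset or invalid values is omitted):

    import os

    def is_active_for_environment(model_info: dict, litellm_environment: str) -> bool:
        supported = model_info.get("supported_environments")
        if not supported:
            return True                    # no restriction -> deployment is always active
        return litellm_environment in supported

    model_info = {"supported_environments": ["staging", "production"]}
    env = os.environ.get("LITELLM_ENVIRONMENT", "development")
    print(is_active_for_environment(model_info, env))  # False unless env is staging/production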
@@ -3946,14 +3993,14 @@ class Router:
for org in _litellm_params["organization"]:
_litellm_params["organization"] = org
self._create_deployment(
model=model,
deployment_info=model,
_model_name=_model_name,
_litellm_params=_litellm_params,
_model_info=_model_info,
)
else:
self._create_deployment(
model=model,
deployment_info=model,
_model_name=_model_name,
_litellm_params=_litellm_params,
_model_info=_model_info,
@@ -4118,9 +4165,9 @@ class Router:
if removal_idx is not None:
self.model_list.pop(removal_idx)
else:
# if the model_id is not in router
self.add_deployment(deployment=deployment)
# if the model_id is not in router
self.add_deployment(deployment=deployment)
return deployment
def delete_deployment(self, id: str) -> Optional[Deployment]:
@@ -4628,16 +4675,13 @@ class Router:
from collections import defaultdict
access_groups = defaultdict(list)
if self.access_groups:
return self.access_groups
if self.model_list:
for m in self.model_list:
for group in m.get("model_info", {}).get("access_groups", []):
model_name = m["model_name"]
access_groups[group].append(model_name)
# set access groups
self.access_groups = access_groups
return access_groups
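A standalone sketch of the access-group mapping built above, with a made-up model_list and without the new caching on self.access_groups:

    from collections import defaultdict

    model_list = [
        {"model_name": "gpt-4", "model_info": {"access_groups": ["prod-models"]}},
        {"model_name": "claude-3", "model_info": {"access_groups": ["prod-models", "beta"]}},
    ]

    access_groups = defaultdict(list)
    for m in model_list:
        for group in m.get("model_info", {}).get("access_groups", []):
            access_groups[group].append(m["model_name"])

    print(dict(access_groups))
    # {'prod-models': ['gpt-4', 'claude-3'], 'beta': ['claude-3']}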
def get_settings(self):
@@ -4672,6 +4716,9 @@ class Router:
return _settings_to_return
def update_settings(self, **kwargs):
"""
Update the router settings.
"""
# only the following settings are allowed to be configured
_allowed_settings = [
"routing_strategy_args",
@@ -5367,66 +5414,16 @@ class Router:
return healthy_deployments
def _track_deployment_metrics(self, deployment, response=None):
"""
Tracks successful requests rpm usage.
"""
try:
litellm_params = deployment["litellm_params"]
api_base = litellm_params.get("api_base", "")
model = litellm_params.get("model", "")
model_id = deployment.get("model_info", {}).get("id", None)
if response is None:
# update self.deployment_stats
if model_id is not None:
self._update_usage(model_id) # update in-memory cache for tracking
if model_id in self.deployment_stats:
# only update num_requests
self.deployment_stats[model_id]["num_requests"] += 1
else:
self.deployment_stats[model_id] = {
"api_base": api_base,
"model": model,
"num_requests": 1,
}
else:
# check response_ms and update num_successes
if isinstance(response, dict):
response_ms = response.get("_response_ms", 0)
else:
response_ms = 0
if model_id is not None:
if model_id in self.deployment_stats:
# check if avg_latency exists
if "avg_latency" in self.deployment_stats[model_id]:
# update avg_latency
self.deployment_stats[model_id]["avg_latency"] = (
self.deployment_stats[model_id]["avg_latency"]
+ response_ms
) / self.deployment_stats[model_id]["num_successes"]
else:
self.deployment_stats[model_id]["avg_latency"] = response_ms
# check if num_successes exists
if "num_successes" in self.deployment_stats[model_id]:
self.deployment_stats[model_id]["num_successes"] += 1
else:
self.deployment_stats[model_id]["num_successes"] = 1
else:
self.deployment_stats[model_id] = {
"api_base": api_base,
"model": model,
"num_successes": 1,
"avg_latency": response_ms,
}
if self.set_verbose is True and self.debug_level == "DEBUG":
from pprint import pformat
# Assuming self.deployment_stats is your dictionary
formatted_stats = pformat(self.deployment_stats)
# Assuming verbose_router_logger is your logger
verbose_router_logger.info(
"self.deployment_stats: \n%s", formatted_stats
)
except Exception as e:
verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}")
@@ -5442,6 +5439,7 @@ class Router:
"""
# if we can find the exception then in the retry policy -> return the number of retries
retry_policy: Optional[RetryPolicy] = self.retry_policy
if (
self.model_group_retry_policy is not None
and model_group is not None
@@ -5540,7 +5538,9 @@ class Router:
litellm.success_callback.append(
_slack_alerting_logger.response_taking_too_long_callback
)
print("\033[94m\nInitialized Alerting for litellm.Router\033[0m\n") # noqa
verbose_router_logger.info(
"\033[94m\nInitialized Alerting for litellm.Router\033[0m\n"
)
def set_custom_routing_strategy(
self, CustomRoutingStrategy: CustomRoutingStrategyBase