feat(router.py): enable filtering model group by 'allowed_model_region'

This commit is contained in:
Krrish Dholakia 2024-05-08 22:10:17 -07:00
parent db666b01e5
commit 3d18897d69
11 changed files with 417 additions and 35 deletions

View file

@ -32,6 +32,7 @@ from litellm.utils import (
CustomStreamWrapper,
get_utc_datetime,
calculate_max_parallel_requests,
_is_region_eu,
)
import copy
from litellm._logging import verbose_router_logger
@ -1999,7 +2000,11 @@ class Router:
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
api_key = litellm_params.get("api_key") or default_api_key
if api_key and api_key.startswith("os.environ/"):
if (
api_key
and isinstance(api_key, str)
and api_key.startswith("os.environ/")
):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = litellm.get_secret(api_key_env_name)
litellm_params["api_key"] = api_key
@ -2023,6 +2028,7 @@ class Router:
if (
is_azure_ai_studio_model == True
and api_base is not None
and isinstance(api_base, str)
and not api_base.endswith("/v1/")
):
# check if it ends with a trailing slash
@ -2103,13 +2109,14 @@ class Router:
organization = litellm.get_secret(organization_env_name)
litellm_params["organization"] = organization
if "azure" in model_name:
if api_base is None:
if "azure" in model_name and isinstance(api_key, str):
if api_base is None or not isinstance(api_base, str):
raise ValueError(
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
)
if api_version is None:
api_version = "2023-07-01-preview"
if "gateway.ai.cloudflare.com" in api_base:
if not api_base.endswith("/"):
api_base += "/"
@ -2532,7 +2539,7 @@ class Router:
self.default_deployment = deployment.to_json(exclude_none=True)
# Azure GPT-Vision Enhancements, users can pass os.environ/
data_sources = deployment.litellm_params.get("dataSources", [])
data_sources = deployment.litellm_params.get("dataSources", []) or []
for data_source in data_sources:
params = data_source.get("parameters", {})
@ -2549,6 +2556,22 @@ class Router:
# init OpenAI, Azure clients
self.set_client(model=deployment.to_json(exclude_none=True))
# set region (if azure model)
try:
if "azure" in deployment.litellm_params.model:
region = litellm.utils.get_model_region(
litellm_params=deployment.litellm_params, mode=None
)
deployment.litellm_params.region_name = region
except Exception as e:
verbose_router_logger.error(
"Unable to get the region for azure model - {}, {}".format(
deployment.litellm_params.model, str(e)
)
)
pass # [NON-BLOCKING]
return deployment
def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
@ -2820,14 +2843,17 @@ class Router:
model: str,
healthy_deployments: List,
messages: List[Dict[str, str]],
allowed_model_region: Optional[Literal["eu"]] = None,
):
"""
Filter out model in model group, if:
- model context window < message length
- filter models above rpm limits
- if region given, filter out models not in that region / unknown region
- [TODO] function call and model doesn't support function calling
"""
verbose_router_logger.debug(
f"Starting Pre-call checks for deployments in model={model}"
)
@ -2878,9 +2904,9 @@ class Router:
except Exception as e:
verbose_router_logger.debug("An error occurs - {}".format(str(e)))
## RPM CHECK ##
_litellm_params = deployment.get("litellm_params", {})
model_id = deployment.get("model_info", {}).get("id", "")
## RPM CHECK ##
### get local router cache ###
current_request_cache_local = (
self.cache.get_cache(key=model_id, local_only=True) or 0
@ -2908,6 +2934,28 @@ class Router:
_rate_limit_error = True
continue
## REGION CHECK ##
if allowed_model_region is not None:
if _litellm_params.get("region_name") is not None and isinstance(
_litellm_params["region_name"], str
):
# check if in allowed_model_region
if (
_is_region_eu(model_region=_litellm_params["region_name"])
== False
):
invalid_model_indices.append(idx)
continue
else:
verbose_router_logger.debug(
"Filtering out model - {}, as model_region=None, and allowed_model_region={}".format(
model_id, allowed_model_region
)
)
# filter out since region unknown, and user wants to filter for specific region
invalid_model_indices.append(idx)
continue
if len(invalid_model_indices) == len(_returned_deployments):
"""
- no healthy deployments available b/c context window checks or rate limit error
@ -3047,10 +3095,31 @@ class Router:
# filter pre-call checks
if self.enable_pre_call_checks and messages is not None:
healthy_deployments = self._pre_call_checks(
model=model, healthy_deployments=healthy_deployments, messages=messages
_allowed_model_region = (
request_kwargs.get("allowed_model_region")
if request_kwargs is not None
else None
)
if _allowed_model_region == "eu":
healthy_deployments = self._pre_call_checks(
model=model,
healthy_deployments=healthy_deployments,
messages=messages,
allowed_model_region=_allowed_model_region,
)
else:
verbose_router_logger.debug(
"Ignoring given 'allowed_model_region'={}. Only 'eu' is allowed".format(
_allowed_model_region
)
)
healthy_deployments = self._pre_call_checks(
model=model,
healthy_deployments=healthy_deployments,
messages=messages,
)
if len(healthy_deployments) == 0:
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, passed model={model}"