Krrish Dholakia 2024-05-21 17:24:51 -07:00
parent 1ed4e2a301
commit c989b92801
3 changed files with 127 additions and 25 deletions

@@ -376,7 +376,7 @@ class Router:
self.lowesttpm_logger = LowestTPMLoggingHandler(
router_cache=self.cache,
model_list=self.model_list,
-routing_args=routing_strategy_args
+routing_args=routing_strategy_args,
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
@@ -384,7 +384,7 @@ class Router:
self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
router_cache=self.cache,
model_list=self.model_list,
-routing_args=routing_strategy_args
+routing_args=routing_strategy_args,
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowesttpm_logger_v2) # type: ignore
@@ -3207,7 +3207,7 @@ class Router:
model: str,
healthy_deployments: List,
messages: List[Dict[str, str]],
-allowed_model_region: Optional[Literal["eu"]] = None,
+request_kwargs: Optional[dict] = None,
):
"""
Filter out model in model group, if:
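
For context: the hunk above replaces the dedicated allowed_model_region parameter with a generic request_kwargs dict, so the region preference now travels inside the request's kwargs. A minimal sketch of the new call shape, assuming an already-configured Router instance named `router` (argument values are placeholders, not part of the commit):

# Sketch only: `router` is an assumed Router instance; the deployment list
# and messages below are placeholders.
healthy = router._pre_call_checks(
    model="gpt-3.5-turbo",
    healthy_deployments=[{"model_name": "gpt-3.5-turbo", "litellm_params": {}}],
    messages=[{"role": "user", "content": "hi"}],
    # the region preference now rides inside request_kwargs
    request_kwargs={"allowed_model_region": "eu"},
)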
@@ -3299,7 +3299,11 @@
continue
## REGION CHECK ##
-if allowed_model_region is not None:
+if (
+request_kwargs is not None
+and request_kwargs.get("allowed_model_region") is not None
+and request_kwargs["allowed_model_region"] == "eu"
+):
if _litellm_params.get("region_name") is not None and isinstance(
_litellm_params["region_name"], str
):
@@ -3313,13 +3317,37 @@
else:
verbose_router_logger.debug(
"Filtering out model - {}, as model_region=None, and allowed_model_region={}".format(
-model_id, allowed_model_region
+model_id, request_kwargs.get("allowed_model_region")
)
)
# filter out since region unknown, and user wants to filter for specific region
invalid_model_indices.append(idx)
continue
+## INVALID PARAMS ## -> catch 'gpt-3.5-turbo-16k' not supporting 'response_object' param
+if request_kwargs is not None and litellm.drop_params == False:
+# get supported params
+model, custom_llm_provider, _, _ = litellm.get_llm_provider(
+model=model, litellm_params=LiteLLM_Params(**_litellm_params)
+)
+supported_openai_params = litellm.get_supported_openai_params(
+model=model, custom_llm_provider=custom_llm_provider
+)
+if supported_openai_params is None:
+continue
+else:
+# check the non-default openai params in request kwargs
+non_default_params = litellm.utils.get_non_default_params(
+passed_params=request_kwargs
+)
+# check if all params are supported
+for k, v in non_default_params.items():
+if k not in supported_openai_params:
+# if not -> invalid model
+invalid_model_indices.append(idx)
if len(invalid_model_indices) == len(_returned_deployments):
"""
- no healthy deployments available b/c context window checks or rate limit error
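
The new "INVALID PARAMS" block filters out deployments that cannot honor the caller's non-default OpenAI params, using litellm.get_llm_provider, litellm.get_supported_openai_params, and litellm.utils.get_non_default_params. A self-contained sketch of the same decision; OPENAI_DEFAULTS, get_non_default_params, and deployment_supports here are simplified stand-ins, not litellm's actual helpers:

from typing import Dict, List, Optional

# Simplified stand-in for litellm's default-params table (the real one is larger).
OPENAI_DEFAULTS: Dict[str, object] = {
    "temperature": None,
    "top_p": None,
    "n": None,
    "response_format": None,
}

def get_non_default_params(passed_params: dict) -> dict:
    # Keep only recognized OpenAI params whose value differs from the default.
    return {
        k: v
        for k, v in passed_params.items()
        if k in OPENAI_DEFAULTS and v != OPENAI_DEFAULTS[k]
    }

def deployment_supports(passed_params: dict, supported: Optional[List[str]]) -> bool:
    # Mirrors the diff: if the supported-params list is unknown, skip the check.
    if supported is None:
        return True
    return all(k in supported for k in get_non_default_params(passed_params))

# e.g. a request forcing JSON output is rejected by a provider that does not
# list "response_format" among its supported params:
assert not deployment_supports(
    {"response_format": {"type": "json_object"}}, ["temperature", "top_p", "n"]
)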
@@ -3469,25 +3497,14 @@
if request_kwargs is not None
else None
)
if self.enable_pre_call_checks and messages is not None:
-if _allowed_model_region == "eu":
-healthy_deployments = self._pre_call_checks(
-model=model,
-healthy_deployments=healthy_deployments,
-messages=messages,
-allowed_model_region=_allowed_model_region,
-)
-else:
-verbose_router_logger.debug(
-"Ignoring given 'allowed_model_region'={}. Only 'eu' is allowed".format(
-_allowed_model_region
-)
-)
-healthy_deployments = self._pre_call_checks(
-model=model,
-healthy_deployments=healthy_deployments,
-messages=messages,
-)
+healthy_deployments = self._pre_call_checks(
+model=model,
+healthy_deployments=healthy_deployments,
+messages=messages,
+request_kwargs=request_kwargs,
+)
if len(healthy_deployments) == 0:
if _allowed_model_region is None:
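
Net effect of the last hunk: the eu-only branching (and the debug log for any other region value) moves out of the call site and into _pre_call_checks itself, which now receives the raw request_kwargs. A standalone sketch of the resulting control flow (illustrative, not litellm code; deployment dicts are flattened for brevity):

from typing import List, Optional

def pre_call_checks(
    deployments: List[dict], request_kwargs: Optional[dict] = None
) -> List[dict]:
    region = (request_kwargs or {}).get("allowed_model_region")
    if region != "eu":  # only "eu" is acted on; any other value is ignored
        return deployments
    # a deployment with an unknown region is filtered out, as in the diff
    return [d for d in deployments if d.get("region_name") == "eu"]

deployments = [{"id": "a", "region_name": "eu"}, {"id": "b", "region_name": None}]
assert pre_call_checks(deployments, {"allowed_model_region": "eu"}) == [deployments[0]]
assert pre_call_checks(deployments, {"allowed_model_region": "us"}) == deployments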