Krrish Dholakia 2024-05-21 17:24:51 -07:00
parent 1ed4e2a301
commit c989b92801
3 changed files with 127 additions and 25 deletions

View file

@@ -376,7 +376,7 @@ class Router:
             self.lowesttpm_logger = LowestTPMLoggingHandler(
                 router_cache=self.cache,
                 model_list=self.model_list,
-                routing_args=routing_strategy_args
+                routing_args=routing_strategy_args,
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger)  # type: ignore
@@ -384,7 +384,7 @@ class Router:
             self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
                 router_cache=self.cache,
                 model_list=self.model_list,
-                routing_args=routing_strategy_args
+                routing_args=routing_strategy_args,
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger_v2)  # type: ignore
@@ -3207,7 +3207,7 @@ class Router:
         model: str,
         healthy_deployments: List,
         messages: List[Dict[str, str]],
-        allowed_model_region: Optional[Literal["eu"]] = None,
+        request_kwargs: Optional[dict] = None,
     ):
         """
         Filter out model in model group, if:
@@ -3299,7 +3299,11 @@ class Router:
                         continue
             ## REGION CHECK ##
-            if allowed_model_region is not None:
+            if (
+                request_kwargs is not None
+                and request_kwargs.get("allowed_model_region") is not None
+                and request_kwargs["allowed_model_region"] == "eu"
+            ):
                 if _litellm_params.get("region_name") is not None and isinstance(
                     _litellm_params["region_name"], str
                 ):
@@ -3313,13 +3317,37 @@ class Router:
                 else:
                     verbose_router_logger.debug(
                         "Filtering out model - {}, as model_region=None, and allowed_model_region={}".format(
-                            model_id, allowed_model_region
+                            model_id, request_kwargs.get("allowed_model_region")
                         )
                     )
                     # filter out since region unknown, and user wants to filter for specific region
                     invalid_model_indices.append(idx)
                     continue

+            ## INVALID PARAMS ## -> catch 'gpt-3.5-turbo-16k' not supporting 'response_object' param
+            if request_kwargs is not None and litellm.drop_params == False:
+                # get supported params
+                model, custom_llm_provider, _, _ = litellm.get_llm_provider(
+                    model=model, litellm_params=LiteLLM_Params(**_litellm_params)
+                )
+
+                supported_openai_params = litellm.get_supported_openai_params(
+                    model=model, custom_llm_provider=custom_llm_provider
+                )
+
+                if supported_openai_params is None:
+                    continue
+                else:
+                    # check the non-default openai params in request kwargs
+                    non_default_params = litellm.utils.get_non_default_params(
+                        passed_params=request_kwargs
+                    )
+                    # check if all params are supported
+                    for k, v in non_default_params.items():
+                        if k not in supported_openai_params:
+                            # if not -> invalid model
+                            invalid_model_indices.append(idx)
+
         if len(invalid_model_indices) == len(_returned_deployments):
             """
             - no healthy deployments available b/c context window checks or rate limit error
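As a reading aid (not part of the commit): the param-compatibility check added in the hunk above boils down to the standalone sketch below, with a made-up supported-params table standing in for litellm.get_supported_openai_params and plain dicts standing in for the router's deployment objects.

from typing import Dict, List, Optional

# Hypothetical stand-in for litellm.get_supported_openai_params: model -> supported OpenAI params.
SUPPORTED_PARAMS: Dict[str, Optional[List[str]]] = {
    "gpt-3.5-turbo": ["temperature", "max_tokens", "response_format"],
    "gpt-3.5-turbo-16k": ["temperature", "max_tokens"],
}

def filter_by_supported_params(deployments: List[dict], request_kwargs: dict) -> List[dict]:
    """Drop deployments whose underlying model can't honor a non-default request param."""
    non_default = {k: v for k, v in request_kwargs.items() if v is not None}
    kept = []
    for d in deployments:
        supported = SUPPORTED_PARAMS.get(d["litellm_params"]["model"])
        if supported is None:
            kept.append(d)  # unmapped provider -> check is skipped, deployment kept
            continue
        if all(k in supported for k in non_default):
            kept.append(d)
    return kept

deployments = [
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo-16k"}},
]
print(filter_by_supported_params(deployments, {"response_format": {"type": "json_object"}}))
# only the gpt-3.5-turbo deployment survives; gpt-3.5-turbo-16k is filtered out

Keeping deployments whose provider is unmapped mirrors the `continue` on `supported_openai_params is None` in the real code above.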
@@ -3469,25 +3497,14 @@ class Router:
                 if request_kwargs is not None
                 else None
             )
+
             if self.enable_pre_call_checks and messages is not None:
-                if _allowed_model_region == "eu":
-                    healthy_deployments = self._pre_call_checks(
-                        model=model,
-                        healthy_deployments=healthy_deployments,
-                        messages=messages,
-                        allowed_model_region=_allowed_model_region,
-                    )
-                else:
-                    verbose_router_logger.debug(
-                        "Ignoring given 'allowed_model_region'={}. Only 'eu' is allowed".format(
-                            _allowed_model_region
-                        )
-                    )
-                    healthy_deployments = self._pre_call_checks(
-                        model=model,
-                        healthy_deployments=healthy_deployments,
-                        messages=messages,
-                    )
+                healthy_deployments = self._pre_call_checks(
+                    model=model,
+                    healthy_deployments=healthy_deployments,
+                    messages=messages,
+                    request_kwargs=request_kwargs,
+                )

             if len(healthy_deployments) == 0:
                 if _allowed_model_region is None:

View file

@@ -689,6 +689,44 @@ def test_router_context_window_check_pre_call_check_out_group():
         pytest.fail(f"Got unexpected exception on router! - {str(e)}")


+def test_filter_invalid_params_pre_call_check():
+    """
+    - gpt-3.5-turbo supports 'response_object'
+    - gpt-3.5-turbo-16k doesn't support 'response_object'
+
+    run pre-call check -> assert returned list doesn't include gpt-3.5-turbo-16k
+    """
+    try:
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",  # openai model name
+                "litellm_params": {  # params for litellm completion/embedding call
+                    "model": "gpt-3.5-turbo",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo-16k",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ]
+
+        router = Router(model_list=model_list, set_verbose=True, enable_pre_call_checks=True, num_retries=0)  # type: ignore
+
+        filtered_deployments = router._pre_call_checks(
+            model="gpt-3.5-turbo",
+            healthy_deployments=model_list,
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            request_kwargs={"response_format": {"type": "json_object"}},
+        )
+        assert len(filtered_deployments) == 1
+    except Exception as e:
+        pytest.fail(f"Got unexpected exception on router! - {str(e)}")
+
+
 @pytest.mark.parametrize("allowed_model_region", ["eu", None])
 def test_router_region_pre_call_check(allowed_model_region):
     """

View file

@@ -5811,7 +5811,7 @@ def get_optional_params(
         "mistralai/Mistral-7B-Instruct-v0.1",
         "mistralai/Mixtral-8x7B-Instruct-v0.1",
     ]:
-        supported_params += [
+        supported_params += [  # type: ignore
            "functions",
            "function_call",
            "tools",
@@ -6061,6 +6061,47 @@ def get_optional_params(
     return optional_params


+def get_non_default_params(passed_params: dict) -> dict:
+    default_params = {
+        "functions": None,
+        "function_call": None,
+        "temperature": None,
+        "top_p": None,
+        "n": None,
+        "stream": None,
+        "stream_options": None,
+        "stop": None,
+        "max_tokens": None,
+        "presence_penalty": None,
+        "frequency_penalty": None,
+        "logit_bias": None,
+        "user": None,
+        "model": None,
+        "custom_llm_provider": "",
+        "response_format": None,
+        "seed": None,
+        "tools": None,
+        "tool_choice": None,
+        "max_retries": None,
+        "logprobs": None,
+        "top_logprobs": None,
+        "extra_headers": None,
+    }
+    # filter out those parameters that were passed with non-default values
+    non_default_params = {
+        k: v
+        for k, v in passed_params.items()
+        if (
+            k != "model"
+            and k != "custom_llm_provider"
+            and k in default_params
+            and v != default_params[k]
+        )
+    }
+
+    return non_default_params
+
+
 def calculate_max_parallel_requests(
     max_parallel_requests: Optional[int],
     rpm: Optional[int],
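A quick illustration of the new helper's behavior (assuming litellm is installed at or after this commit, so litellm.utils.get_non_default_params is importable):

from litellm.utils import get_non_default_params

passed = {
    "model": "gpt-3.5-turbo",
    "temperature": 0.2,
    "response_format": {"type": "json_object"},
    "stream": None,  # equals its default, so it is dropped
}
print(get_non_default_params(passed_params=passed))
# expected: {'temperature': 0.2, 'response_format': {'type': 'json_object'}}
# "model" and "custom_llm_provider" are always excluded.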
@@ -6287,7 +6328,7 @@ def get_first_chars_messages(kwargs: dict) -> str:
     return ""


-def get_supported_openai_params(model: str, custom_llm_provider: str):
+def get_supported_openai_params(model: str, custom_llm_provider: str) -> Optional[list]:
     """
     Returns the supported openai params for a given model + provider
@@ -6295,6 +6336,10 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
     ```
     get_supported_openai_params(model="anthropic.claude-3", custom_llm_provider="bedrock")
     ```
+
+    Returns:
+    - List if custom_llm_provider is mapped
+    - None if unmapped
     """
     if custom_llm_provider == "bedrock":
         if model.startswith("anthropic.claude-3"):
@@ -6534,6 +6579,8 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
     elif custom_llm_provider == "watsonx":
         return litellm.IBMWatsonXAIConfig().get_supported_openai_params()

+    return None
+

 def get_formatted_prompt(
     data: dict,