mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)

feat(router.py): Fixes https://github.com/BerriAI/litellm/issues/3769

parent 1ed4e2a301
commit c989b92801

3 changed files with 127 additions and 25 deletions
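In short, this commit extends the router's pre-call checks in two ways: `allowed_model_region` now travels inside `request_kwargs` instead of as a dedicated argument, and deployments that don't support a non-default OpenAI param in the request (e.g. `response_format`) are filtered out when `litellm.drop_params` is off. A minimal sketch of the behavior this enables, using the same two-deployment setup as the new test further down (model names and the API key env var are illustrative):

```python
import os

import litellm
from litellm import Router

litellm.drop_params = False  # keep unsupported params -> filter deployments instead

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo-16k",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
    ],
    enable_pre_call_checks=True,
)

# 'response_format' is a non-default openai param; per the test added below,
# the pre-call check should route this away from the gpt-3.5-turbo-16k
# deployment, which doesn't support it.
response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Return a JSON object."}],
    response_format={"type": "json_object"},
)
```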
@@ -376,7 +376,7 @@ class Router:
             self.lowesttpm_logger = LowestTPMLoggingHandler(
                 router_cache=self.cache,
                 model_list=self.model_list,
-                routing_args=routing_strategy_args
+                routing_args=routing_strategy_args,
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger)  # type: ignore
@@ -384,7 +384,7 @@ class Router:
             self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
                 router_cache=self.cache,
                 model_list=self.model_list,
-                routing_args=routing_strategy_args
+                routing_args=routing_strategy_args,
            )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger_v2)  # type: ignore
@@ -3207,7 +3207,7 @@ class Router:
         model: str,
         healthy_deployments: List,
         messages: List[Dict[str, str]],
-        allowed_model_region: Optional[Literal["eu"]] = None,
+        request_kwargs: Optional[dict] = None,
     ):
         """
         Filter out model in model group, if:
@@ -3299,7 +3299,11 @@ class Router:
                 continue

             ## REGION CHECK ##
-            if allowed_model_region is not None:
+            if (
+                request_kwargs is not None
+                and request_kwargs.get("allowed_model_region") is not None
+                and request_kwargs["allowed_model_region"] == "eu"
+            ):
                 if _litellm_params.get("region_name") is not None and isinstance(
                     _litellm_params["region_name"], str
                 ):
@@ -3313,13 +3317,37 @@ class Router:
                 else:
                     verbose_router_logger.debug(
                         "Filtering out model - {}, as model_region=None, and allowed_model_region={}".format(
-                            model_id, allowed_model_region
+                            model_id, request_kwargs.get("allowed_model_region")
                         )
                     )
                     # filter out since region unknown, and user wants to filter for specific region
                     invalid_model_indices.append(idx)
                     continue

+            ## INVALID PARAMS ## -> catch 'gpt-3.5-turbo-16k' not supporting 'response_object' param
+            if request_kwargs is not None and litellm.drop_params == False:
+                # get supported params
+                model, custom_llm_provider, _, _ = litellm.get_llm_provider(
+                    model=model, litellm_params=LiteLLM_Params(**_litellm_params)
+                )
+
+                supported_openai_params = litellm.get_supported_openai_params(
+                    model=model, custom_llm_provider=custom_llm_provider
+                )
+
+                if supported_openai_params is None:
+                    continue
+                else:
+                    # check the non-default openai params in request kwargs
+                    non_default_params = litellm.utils.get_non_default_params(
+                        passed_params=request_kwargs
+                    )
+                    # check if all params are supported
+                    for k, v in non_default_params.items():
+                        if k not in supported_openai_params:
+                            # if not -> invalid model
+                            invalid_model_indices.append(idx)
+
         if len(invalid_model_indices) == len(_returned_deployments):
             """
             - no healthy deployments available b/c context window checks or rate limit error
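The new `## INVALID PARAMS ##` branch composes the two helpers this commit also touches in utils.py. Roughly, the per-deployment check reduces to the following standalone sketch (the provider and model values are illustrative):

```python
import litellm
from litellm.utils import get_non_default_params

# Params the caller actually set; defaults (None values etc.) are filtered out.
non_default_params = get_non_default_params(
    passed_params={"response_format": {"type": "json_object"}, "temperature": None}
)  # -> {"response_format": {"type": "json_object"}}

# Params this model + provider pair accepts; None means the provider is unmapped.
supported_openai_params = litellm.get_supported_openai_params(
    model="gpt-3.5-turbo-16k", custom_llm_provider="openai"
)

if supported_openai_params is not None and any(
    k not in supported_openai_params for k in non_default_params
):
    print("deployment would be marked invalid and filtered out")
```

Note that when `get_supported_openai_params` returns None, the loop `continue`s without marking the deployment invalid, i.e. an unmapped provider gets the benefit of the doubt.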
@@ -3469,25 +3497,14 @@ class Router:
                 if request_kwargs is not None
                 else None
             )

             if self.enable_pre_call_checks and messages is not None:
-                if _allowed_model_region == "eu":
-                    healthy_deployments = self._pre_call_checks(
-                        model=model,
-                        healthy_deployments=healthy_deployments,
-                        messages=messages,
-                        allowed_model_region=_allowed_model_region,
-                    )
-                else:
-                    verbose_router_logger.debug(
-                        "Ignoring given 'allowed_model_region'={}. Only 'eu' is allowed".format(
-                            _allowed_model_region
-                        )
-                    )
-                    healthy_deployments = self._pre_call_checks(
-                        model=model,
-                        healthy_deployments=healthy_deployments,
-                        messages=messages,
-                    )
+                healthy_deployments = self._pre_call_checks(
+                    model=model,
+                    healthy_deployments=healthy_deployments,
+                    messages=messages,
+                    request_kwargs=request_kwargs,
+                )

             if len(healthy_deployments) == 0:
                 if _allowed_model_region is None:
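With the region preference folded into `request_kwargs`, the branching on `_allowed_model_region` disappears here, and the old "Only 'eu' is allowed" debug path goes with it: any value other than "eu" now simply fails the region condition inside `_pre_call_checks`, so no region filtering happens. A hypothetical direct call, mirroring the new test's usage:

```python
# Sketch: region-filter deployments via request_kwargs (the updated signature).
filtered_deployments = router._pre_call_checks(
    model="gpt-3.5-turbo",
    healthy_deployments=model_list,
    messages=[{"role": "user", "content": "hi"}],
    request_kwargs={"allowed_model_region": "eu"},
)
```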
@@ -689,6 +689,44 @@ def test_router_context_window_check_pre_call_check_out_group():
         pytest.fail(f"Got unexpected exception on router! - {str(e)}")


+def test_filter_invalid_params_pre_call_check():
+    """
+    - gpt-3.5-turbo supports 'response_object'
+    - gpt-3.5-turbo-16k doesn't support 'response_object'
+
+    run pre-call check -> assert returned list doesn't include gpt-3.5-turbo-16k
+    """
+    try:
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",  # openai model name
+                "litellm_params": {  # params for litellm completion/embedding call
+                    "model": "gpt-3.5-turbo",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo-16k",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ]
+
+        router = Router(model_list=model_list, set_verbose=True, enable_pre_call_checks=True, num_retries=0)  # type: ignore
+
+        filtered_deployments = router._pre_call_checks(
+            model="gpt-3.5-turbo",
+            healthy_deployments=model_list,
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            request_kwargs={"response_format": {"type": "json_object"}},
+        )
+        assert len(filtered_deployments) == 1
+    except Exception as e:
+        pytest.fail(f"Got unexpected exception on router! - {str(e)}")
+
+
 @pytest.mark.parametrize("allowed_model_region", ["eu", None])
 def test_router_region_pre_call_check(allowed_model_region):
     """
@@ -5811,7 +5811,7 @@ def get_optional_params(
         "mistralai/Mistral-7B-Instruct-v0.1",
         "mistralai/Mixtral-8x7B-Instruct-v0.1",
     ]:
-        supported_params += [
+        supported_params += [  # type: ignore
             "functions",
             "function_call",
             "tools",
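The added `# type: ignore` is likely a knock-on effect of the signature change further down in this commit: `supported_params` is derived from `get_supported_openai_params`, which is now annotated as returning `Optional[list]`, so `+=` on it no longer passes the type checker.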
@@ -6061,6 +6061,47 @@ def get_optional_params(
     return optional_params


+def get_non_default_params(passed_params: dict) -> dict:
+    default_params = {
+        "functions": None,
+        "function_call": None,
+        "temperature": None,
+        "top_p": None,
+        "n": None,
+        "stream": None,
+        "stream_options": None,
+        "stop": None,
+        "max_tokens": None,
+        "presence_penalty": None,
+        "frequency_penalty": None,
+        "logit_bias": None,
+        "user": None,
+        "model": None,
+        "custom_llm_provider": "",
+        "response_format": None,
+        "seed": None,
+        "tools": None,
+        "tool_choice": None,
+        "max_retries": None,
+        "logprobs": None,
+        "top_logprobs": None,
+        "extra_headers": None,
+    }
+    # filter out those parameters that were passed with non-default values
+    non_default_params = {
+        k: v
+        for k, v in passed_params.items()
+        if (
+            k != "model"
+            and k != "custom_llm_provider"
+            and k in default_params
+            and v != default_params[k]
+        )
+    }
+
+    return non_default_params
+
+
 def calculate_max_parallel_requests(
     max_parallel_requests: Optional[int],
     rpm: Optional[int],
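A quick usage sketch for the new helper: only keys that are known OpenAI params and differ from their default survive, and `model` / `custom_llm_provider` are always excluded.

```python
from litellm.utils import get_non_default_params

passed = {
    "model": "gpt-3.5-turbo",                    # always excluded
    "temperature": None,                         # equals default -> excluded
    "max_tokens": 256,                           # non-default -> kept
    "response_format": {"type": "json_object"},  # non-default -> kept
    "my_custom_flag": True,                      # unknown key -> excluded
}

print(get_non_default_params(passed_params=passed))
# {'max_tokens': 256, 'response_format': {'type': 'json_object'}}
```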
@@ -6287,7 +6328,7 @@ def get_first_chars_messages(kwargs: dict) -> str:
     return ""


-def get_supported_openai_params(model: str, custom_llm_provider: str):
+def get_supported_openai_params(model: str, custom_llm_provider: str) -> Optional[list]:
     """
     Returns the supported openai params for a given model + provider
@@ -6295,6 +6336,10 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
     ```
     get_supported_openai_params(model="anthropic.claude-3", custom_llm_provider="bedrock")
     ```
+
+    Returns:
+    - List if custom_llm_provider is mapped
+    - None if unmapped
     """
     if custom_llm_provider == "bedrock":
         if model.startswith("anthropic.claude-3"):
@@ -6534,6 +6579,8 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
     elif custom_llm_provider == "watsonx":
         return litellm.IBMWatsonXAIConfig().get_supported_openai_params()

+    return None
+

 def get_formatted_prompt(
     data: dict,
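The explicit `return None` matches the docstring contract added above; previously the function fell off the end and returned None implicitly, which the new `Optional[list]` annotation now makes official. For example (the unmapped provider string is hypothetical):

```python
import litellm

# Mapped provider -> a list of supported openai params
params = litellm.get_supported_openai_params(
    model="anthropic.claude-3", custom_llm_provider="bedrock"
)
assert isinstance(params, list)

# Unmapped provider -> None
assert (
    litellm.get_supported_openai_params(
        model="some-model", custom_llm_provider="some-unmapped-provider"
    )
    is None
)
```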