forked from phoenix/litellm-mirror
(refactor router.py ) - PR 3 - Ensure all functions under 100 lines (#6181)
* add flake 8 check * split up litellm _acompletion * fix get model client * refactor use common func to add metadata to kwargs * use common func to get timeout * re-use helper to _get_async_model_client * use _handle_mock_testing_rate_limit_error * fix docstring for _handle_mock_testing_rate_limit_error * fix function_with_retries * use helper for mock testing fallbacks * router - use 1 func for simple_shuffle * add doc string for simple_shuffle * use 1 function for filtering cooldown deployments * fix use common helper to _get_fallback_model_group_from_fallbacks
This commit is contained in:
parent
0761a03d05
commit
d0a3052937
5 changed files with 422 additions and 598 deletions
|
@ -25,6 +25,11 @@ repos:
|
|||
exclude: ^litellm/tests/|^litellm/proxy/tests/
|
||||
additional_dependencies: [flake8-print]
|
||||
files: litellm/.*\.py
|
||||
# - id: flake8
|
||||
# name: flake8 (router.py function length)
|
||||
# files: ^litellm/router\.py$
|
||||
# args: [--max-function-length=40]
|
||||
# # additional_dependencies: [flake8-functions]
|
||||
- repo: https://github.com/python-poetry/poetry
|
||||
rev: 1.8.0
|
||||
hooks:
|
||||
|
|
File diff suppressed because it is too large
Load diff
96
litellm/router_strategy/simple_shuffle.py
Normal file
96
litellm/router_strategy/simple_shuffle.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
"""
|
||||
Returns a random deployment from the list of healthy deployments.
|
||||
|
||||
If weights are provided, it will return a deployment based on the weights.
|
||||
|
||||
"""
|
||||
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.router import Router as _Router
|
||||
|
||||
LitellmRouter = _Router
|
||||
else:
|
||||
LitellmRouter = Any
|
||||
|
||||
|
||||
def _weighted_pick(
    llm_router_instance: "LitellmRouter",
    healthy_deployments: Union[List[Any], Dict[Any, Any]],
    model: str,
    values: List[float],
    label: str,
) -> Dict:
    """
    Pick one deployment at random, weighted proportionally to `values`.

    Shared helper for the `weight` / `rpm` / `tpm` weighted picks (the original
    implementation repeated this block three times verbatim).

    Args:
        llm_router_instance: LitellmRouter instance (used only for logging the
            selected deployment).
        healthy_deployments: List of healthy deployments, parallel to `values`.
        model: Model name (used only in log messages).
        values: Raw per-deployment weights (e.g. the `rpm` of each deployment).
        label: Name used in the debug log line (e.g. "rpms").

    Returns:
        Dict: The selected deployment.
    """
    verbose_router_logger.debug(f"\n{label} {values}")
    total = sum(values)
    if total <= 0:
        # Robustness fix: if every value is 0 (e.g. only one deployment declares
        # a `weight`, the rest default to 0 via .get(key, 0)), the original code
        # raised ZeroDivisionError. Fall back to a uniform pick instead.
        weights = [1.0 / len(values)] * len(values)
    else:
        weights = [value / total for value in values]
    verbose_router_logger.debug(f"\n weights {weights}")
    # Perform weighted random pick
    selected_index = random.choices(range(len(values)), weights=weights)[0]
    verbose_router_logger.debug(f"\n selected index, {selected_index}")
    deployment = healthy_deployments[selected_index]
    verbose_router_logger.info(
        f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}"
    )
    # The original returned `deployment or deployment[0]`; a selected deployment
    # is a non-empty dict, so the `or` arm was unreachable for valid entries
    # (and would have raised KeyError(0) otherwise).
    return deployment


def simple_shuffle(
    llm_router_instance: "LitellmRouter",
    healthy_deployments: Union[List[Any], Dict[Any, Any]],
    model: str,
) -> Dict:
    """
    Returns a random deployment from the list of healthy deployments.

    If the first deployment sets a `weight` litellm_param, a weighted random
    pick is done on `weight`. Otherwise, if `rpm` (then `tpm`) is set, a
    weighted pick is done on that value. With none of these set, a uniform
    random pick is made.

    Args:
        llm_router_instance: LitellmRouter instance (used only for logging).
        healthy_deployments: List of healthy deployments.
        model: Model name (used only in log messages).

    Returns:
        Dict: A single healthy deployment.
    """
    first_params = healthy_deployments[0].get("litellm_params")

    ############## weight / RPM / TPM based weighted pick, in priority order #################
    # Only the first key present on the first deployment is used, matching the
    # original branch order: weight, then rpm, then tpm.
    for key, label in (("weight", "weight"), ("rpm", "rpms"), ("tpm", "tpms")):
        if first_params.get(key, None) is not None:
            values = [m["litellm_params"].get(key, 0) for m in healthy_deployments]
            return _weighted_pick(
                llm_router_instance=llm_router_instance,
                healthy_deployments=healthy_deployments,
                model=model,
                values=values,
                label=label,
            )

    ############## No weight/RPM/TPM passed, we do a uniform random pick #################
    return random.choice(healthy_deployments)
|
|
@ -66,7 +66,7 @@ def run_sync_fallback(
|
|||
**kwargs,
|
||||
) -> Any:
|
||||
"""
|
||||
Iterate through the model groups and try calling that deployment.
|
||||
Iterate through the fallback model groups and try calling each fallback deployment.
|
||||
"""
|
||||
error_from_fallbacks = original_exception
|
||||
for mg in fallback_model_group:
|
||||
|
|
|
@ -73,6 +73,7 @@ async def test_azure_tenant_id_auth(respx_mock: MockRouter):
|
|||
],
|
||||
created=int(datetime.now().timestamp()),
|
||||
)
|
||||
litellm.set_verbose = True
|
||||
mock_request = respx_mock.post(url__regex=r".*/chat/completions.*").mock(
|
||||
return_value=httpx.Response(200, json=obj.model_dump(mode="json"))
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue