From 48a5948081d722e10188bb84d0e7f51e460ecfce Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 4 Apr 2024 14:22:50 -0700 Subject: [PATCH] fix(router.py): handle id being passed in as int --- litellm/router.py | 4 +++- litellm/router_strategy/least_busy.py | 10 ++++++++++ litellm/router_strategy/lowest_latency.py | 4 ++++ litellm/router_strategy/lowest_tpm_rpm.py | 4 ++++ litellm/tests/test_least_busy_routing.py | 2 +- litellm/tests/test_lowest_latency_routing.py | 4 +++- litellm/tests/test_tpm_rpm_routing.py | 5 ++++- 7 files changed, 29 insertions(+), 4 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index facd9b050b..a34d90e8aa 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -37,9 +37,11 @@ class ModelInfo(BaseModel): str ] # Allow id to be optional on input, but it will always be present as a str in the model instance - def __init__(self, id: Optional[str] = None, **params): + def __init__(self, id: Optional[Union[str, int]] = None, **params): if id is None: id = str(uuid.uuid4()) # Generate a UUID if id is None or not provided + elif isinstance(id, int): + id = str(id) super().__init__(id=id, **params) class Config: diff --git a/litellm/router_strategy/least_busy.py b/litellm/router_strategy/least_busy.py index b2d2983b07..68874cec4a 100644 --- a/litellm/router_strategy/least_busy.py +++ b/litellm/router_strategy/least_busy.py @@ -41,6 +41,8 @@ class LeastBusyLoggingHandler(CustomLogger): id = kwargs["litellm_params"].get("model_info", {}).get("id", None) if model_group is None or id is None: return + elif isinstance(id, int): + id = str(id) request_count_api_key = f"{model_group}_request_count" # update cache @@ -67,6 +69,8 @@ class LeastBusyLoggingHandler(CustomLogger): id = kwargs["litellm_params"].get("model_info", {}).get("id", None) if model_group is None or id is None: return + elif isinstance(id, int): + id = str(id) request_count_api_key = f"{model_group}_request_count" # decrement count in cache @@ -95,6 +99,8 @@ class LeastBusyLoggingHandler(CustomLogger): id = kwargs["litellm_params"].get("model_info", {}).get("id", None) if model_group is None or id is None: return + elif isinstance(id, int): + id = str(id) request_count_api_key = f"{model_group}_request_count" # decrement count in cache @@ -124,6 +130,8 @@ class LeastBusyLoggingHandler(CustomLogger): id = kwargs["litellm_params"].get("model_info", {}).get("id", None) if model_group is None or id is None: return + elif isinstance(id, int): + id = str(id) request_count_api_key = f"{model_group}_request_count" # decrement count in cache @@ -152,6 +160,8 @@ class LeastBusyLoggingHandler(CustomLogger): id = kwargs["litellm_params"].get("model_info", {}).get("id", None) if model_group is None or id is None: return + elif isinstance(id, int): + id = str(id) request_count_api_key = f"{model_group}_request_count" # decrement count in cache diff --git a/litellm/router_strategy/lowest_latency.py b/litellm/router_strategy/lowest_latency.py index 57b56e87f8..f5b4e270da 100644 --- a/litellm/router_strategy/lowest_latency.py +++ b/litellm/router_strategy/lowest_latency.py @@ -57,6 +57,8 @@ class LowestLatencyLoggingHandler(CustomLogger): id = kwargs["litellm_params"].get("model_info", {}).get("id", None) if model_group is None or id is None: return + elif isinstance(id, int): + id = str(id) # ------------ # Setup values @@ -139,6 +141,8 @@ class LowestLatencyLoggingHandler(CustomLogger): id = kwargs["litellm_params"].get("model_info", {}).get("id", None) if model_group is None or id is None: return + elif isinstance(id, int): + id = str(id) # ------------ # Setup values diff --git a/litellm/router_strategy/lowest_tpm_rpm.py b/litellm/router_strategy/lowest_tpm_rpm.py index 1e1e6df98e..0437c2affc 100644 --- a/litellm/router_strategy/lowest_tpm_rpm.py +++ b/litellm/router_strategy/lowest_tpm_rpm.py @@ -39,6 +39,8 @@ class LowestTPMLoggingHandler(CustomLogger): id = kwargs["litellm_params"].get("model_info", {}).get("id", None) if model_group is None or id is None: return + elif isinstance(id, int): + id = str(id) total_tokens = response_obj["usage"]["total_tokens"] @@ -87,6 +89,8 @@ class LowestTPMLoggingHandler(CustomLogger): id = kwargs["litellm_params"].get("model_info", {}).get("id", None) if model_group is None or id is None: return + elif isinstance(id, int): + id = str(id) total_tokens = response_obj["usage"]["total_tokens"] diff --git a/litellm/tests/test_least_busy_routing.py b/litellm/tests/test_least_busy_routing.py index dcb72898f5..782d5b343a 100644 --- a/litellm/tests/test_least_busy_routing.py +++ b/litellm/tests/test_least_busy_routing.py @@ -112,7 +112,7 @@ def test_router_get_available_deployments(): deployment = router.get_available_deployment(model=model_group, messages=None) print(f"deployment: {deployment}") - assert deployment["model_info"]["id"] == 1 + assert deployment["model_info"]["id"] == "1" ## run router completion - assert completion event, no change in 'busy'ness once calls are complete diff --git a/litellm/tests/test_lowest_latency_routing.py b/litellm/tests/test_lowest_latency_routing.py index 7d93f0e608..7ef5e5d7a2 100644 --- a/litellm/tests/test_lowest_latency_routing.py +++ b/litellm/tests/test_lowest_latency_routing.py @@ -337,7 +337,9 @@ def test_router_get_available_deployments(): ## CHECK WHAT'S SELECTED ## # print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model")) print(router.get_available_deployment(model="azure-model")) - assert router.get_available_deployment(model="azure-model")["model_info"]["id"] == 2 + assert ( + router.get_available_deployment(model="azure-model")["model_info"]["id"] == "2" + ) # test_router_get_available_deployments() diff --git a/litellm/tests/test_tpm_rpm_routing.py b/litellm/tests/test_tpm_rpm_routing.py index e7ada5eb84..8fe30cfcc0 100644 --- a/litellm/tests/test_tpm_rpm_routing.py +++ b/litellm/tests/test_tpm_rpm_routing.py @@ -173,6 +173,7 @@ def test_router_get_available_deployments(): num_retries=3, ) # type: ignore + print(f"router id's: {router.get_model_ids()}") ## DEPLOYMENT 1 ## deployment_id = 1 kwargs = { @@ -214,7 +215,9 @@ def test_router_get_available_deployments(): ## CHECK WHAT'S SELECTED ## # print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model")) - assert router.get_available_deployment(model="azure-model")["model_info"]["id"] == 2 + assert ( + router.get_available_deployment(model="azure-model")["model_info"]["id"] == "2" + ) # test_get_available_deployments()