From a4a8129a135e08b98e8f2f90ef2c9ab2bf41ecb1 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 3 Apr 2024 21:57:19 -0700 Subject: [PATCH] fix(router.py): fix pydantic object logic --- litellm/router.py | 20 ++++++++++++++------ litellm/tests/test_router_get_deployments.py | 4 ++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 0d161f1de2..80066aab2b 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -2143,12 +2143,17 @@ class Router: import os for model in original_model_list: + _model_name = model.pop("model_name") + _litellm_params = model.pop("litellm_params") + _model_info = model.pop("model_info", {}) deployment = Deployment( - model_name=model["model_name"], - litellm_params=model["litellm_params"], - model_info=model.get("model_info", {}), + **model, + model_name=_model_name, + litellm_params=_litellm_params, + model_info=_model_info, ) - self._add_deployment(deployment=deployment) + + deployment = self._add_deployment(deployment=deployment) model = deployment.to_json(exclude_none=True) @@ -2157,7 +2162,7 @@ class Router: verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}") self.model_names = [m["model_name"] for m in model_list] - def _add_deployment(self, deployment: Deployment): + def _add_deployment(self, deployment: Deployment) -> Deployment: import os #### DEPLOYMENT NAMES INIT ######## @@ -2194,7 +2199,7 @@ class Router: # Check if user is trying to use model_name == "*" # this is a catch all model for their specific api key if deployment.model_name == "*": - self.default_deployment = deployment + self.default_deployment = deployment.to_json(exclude_none=True) # Azure GPT-Vision Enhancements, users can pass os.environ/ data_sources = deployment.litellm_params.get("dataSources", []) @@ -2214,6 +2219,8 @@ class Router: # init OpenAI, Azure clients self.set_client(model=deployment.to_json(exclude_none=True)) + return deployment + def add_deployment(self, deployment: Deployment): # check if deployment already exists @@ -2450,6 +2457,7 @@ class Router: model = litellm.model_alias_map[ model ] # update the model to the actual value if an alias has been passed in + if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None: deployment = self.leastbusy_logger.get_available_deployments( model_group=model, healthy_deployments=healthy_deployments diff --git a/litellm/tests/test_router_get_deployments.py b/litellm/tests/test_router_get_deployments.py index 53299575a5..44afea12d1 100644 --- a/litellm/tests/test_router_get_deployments.py +++ b/litellm/tests/test_router_get_deployments.py @@ -181,7 +181,7 @@ def test_weighted_selection_router_tpm_as_router_param(): pytest.fail(f"Error occurred: {e}") -test_weighted_selection_router_tpm_as_router_param() +# test_weighted_selection_router_tpm_as_router_param() def test_weighted_selection_router_rpm_as_router_param(): @@ -433,7 +433,7 @@ def test_usage_based_routing(): selection_counts[response["model"]] += 1 - # print("selection counts", selection_counts) + print("selection counts", selection_counts) total_requests = sum(selection_counts.values())