From 23af756531d0be7d77b72e94402170f7da467d2f Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Wed, 29 Nov 2023 17:54:06 -0800 Subject: [PATCH] (feat) router: random pick based on tpm/rpm --- litellm/router.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/litellm/router.py b/litellm/router.py index 43533df86..220c86373 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -877,6 +877,15 @@ class Router: if key != "api_key": model_id+= str(model["litellm_params"][key]) model["litellm_params"]["model"] += "-ModelID-" + model_id + + ############ Users can either pass tpm/rpm as a litellm_param or a router param ########### + # for get_available_deployment, we use the litellm_param["rpm"] + # in this snippet we also set rpm to be a litellm_param + if model["litellm_params"].get("rpm") is None and model.get("rpm") is not None: + model["litellm_params"]["rpm"] = model.get("rpm") + if model["litellm_params"].get("tpm") is None and model.get("tpm") is not None: + model["litellm_params"]["tpm"] = model.get("tpm") + self.model_names = [m["model_name"] for m in model_list] def get_model_names(self): @@ -927,6 +936,8 @@ class Router: else: raise ValueError("No models available.") elif self.routing_strategy == "simple-shuffle": + # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm + ############## Check if we can do a RPM/TPM based weighted pick ################# rpm = healthy_deployments[0].get("litellm_params").get("rpm", None) if rpm is not None: # use weight-random pick if rpms provided @@ -940,6 +951,22 @@ class Router: self.print_verbose(f"\n selected index, {selected_index}") deployment = healthy_deployments[selected_index] return deployment or deployment[0] + ############## Check if we can do a RPM/TPM based weighted pick ################# + tpm = healthy_deployments[0].get("litellm_params").get("tpm", None) + if tpm is not None: + # use weight-random pick if rpms provided + tpms = [m["litellm_params"].get("tpm") for m in healthy_deployments] + self.print_verbose(f"\ntpms {tpms}") + total_tpm = sum(tpms) + weights = [tpm / total_tpm for tpm in tpms] + self.print_verbose(f"\n weights {weights}") + # Perform weighted random pick + selected_index = random.choices(range(len(tpms)), weights=weights)[0] + self.print_verbose(f"\n selected index, {selected_index}") + deployment = healthy_deployments[selected_index] + return deployment or deployment[0] + + ############## No RPM/TPM passed, we do a random pick ################# item = random.choice(healthy_deployments) return item or item[0] elif self.routing_strategy == "latency-based-routing":