From 50733363eef74debeeb0d7433450ef4a193523a5 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Mon, 27 Nov 2023 18:08:07 -0800
Subject: [PATCH] (feat) use api_base, api_key as model

---
 litellm/router.py | 47 ++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 40 insertions(+), 7 deletions(-)

diff --git a/litellm/router.py b/litellm/router.py
index 7710f77875..24e2a50f0a 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -87,6 +87,7 @@ class Router:
         self.routing_strategy = routing_strategy
         self.fallbacks = fallbacks or litellm.fallbacks
         self.context_window_fallbacks = context_window_fallbacks or litellm.context_window_fallbacks
+        self.model_exception_map = {} # maps model -> list of exception strings, e.g. {"gpt-3.5": ["API KEY Error", "Rate Limit Error", "good morning error"]}
         # make Router.chat.completions.create compatible for openai.chat.completions.create
         self.chat = litellm.Chat(params=default_litellm_params)
 
@@ -171,7 +172,16 @@ class Router:
             for k, v in self.default_litellm_params.items():
                 if k not in data: # prioritize model-specific params > default router params
                     data[k] = v
-            data["model"] = data["model"][:-14]
+
+            ########## remove -ModelID-XXXX from model ##############
+            original_model_string = data["model"]
+            # Find the index of "-ModelID" in the string
+            index_of_model_id = original_model_string.find("-ModelID")
+            # Remove everything from "-ModelID" onwards, if present
+            if index_of_model_id != -1:
+                data["model"] = original_model_string[:index_of_model_id]
+            else:
+                data["model"] = original_model_string
             self.print_verbose(f"completion model: {data['model']}")
             return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
         except Exception as e:
@@ -207,9 +217,15 @@ class Router:
             for k, v in self.default_litellm_params.items():
                 if k not in data: # prioritize model-specific params > default router params
                     data[k] = v
-            data["model"] = data["model"][:-14]
-            self.print_verbose(f"acompletion model: {data['model']}")
-
+            ########## remove -ModelID-XXXX from model ##############
+            original_model_string = data["model"]
+            # Find the index of "-ModelID" in the string
+            index_of_model_id = original_model_string.find("-ModelID")
+            # Remove everything from "-ModelID" onwards, if present
+            if index_of_model_id != -1:
+                data["model"] = original_model_string[:index_of_model_id]
+            else:
+                data["model"] = original_model_string
             response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
             return response
         except Exception as e:
@@ -514,12 +530,26 @@ class Router:
         start_time, end_time # start/end time
     ):
         try:
+            exception = kwargs.get("exception", None)
+            exception_type = type(exception)
+            exception_status = getattr(exception, 'status_code', "")
+            exception_cause = getattr(exception, '__cause__', "")
+            exception_message = getattr(exception, 'message', "")
+            exception_str = str(exception_type) + " Status: " + str(exception_status) + " Message: " + str(exception_cause) + str(exception_message) + " Full exception: " + str(exception)
             model_name = kwargs.get('model', None) # i.e. gpt35turbo
             custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure
             metadata = kwargs.get("litellm_params", {}).get('metadata', None)
             if metadata:
                 deployment = metadata.get("deployment", None)
                 self._set_cooldown_deployments(deployment)
+                deployment_exceptions = self.model_exception_map.get(deployment, [])
+                deployment_exceptions.append(exception_str)
+                self.model_exception_map[deployment] = deployment_exceptions
+                self.print_verbose("\nEXCEPTION FOR DEPLOYMENTS\n")
+                self.print_verbose(self.model_exception_map)
+                for model in self.model_exception_map:
+                    self.print_verbose(f"Model {model} had {len(self.model_exception_map[model])} exceptions")
+                self.print_verbose("\n")
             if custom_llm_provider:
                 model_name = f"{custom_llm_provider}/{model_name}"
 
@@ -539,7 +569,7 @@ class Router:
         # cooldown deployment
         current_fails = self.failed_calls.get_cache(key=deployment) or 0
         updated_fails = current_fails + 1
-        self.print_verbose(f"updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}")
+        self.print_verbose(f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}")
         if updated_fails > self.allowed_fails:
             # get the current cooldown list for that minute
             cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
@@ -731,9 +761,12 @@ class Router:
     def set_model_list(self, model_list: list):
         self.model_list = model_list
-        # we add a 5 digit uuid to each model so load balancing between azure/gpt on api_base1 and api_base2 works
+        # we add api_base/api_key to each model so load balancing between azure/gpt on api_base1 and api_base2 works
        for model in self.model_list:
-            model["litellm_params"]["model"] += "-ModelID-" + str(random.randint(10000, 99999))[:5]
+            model_id = ""
+            for key in model["litellm_params"]:
+                model_id += str(model["litellm_params"][key])
+            model["litellm_params"]["model"] += "-ModelID-" + model_id
         self.model_names = [m["model_name"] for m in model_list]
 
     def get_model_names(self):
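
Note: a minimal standalone sketch of the -ModelID- round-trip this patch implements. The helper names and the sample litellm_params values below are made up for illustration; only the "-ModelID-" suffix convention, the per-param concatenation, and the find("-ModelID") stripping logic come from the patch itself.

    # tag a deployment's model name with its params, then strip the tag again
    def make_model_id(litellm_params: dict) -> str:
        # concatenate every param value, mirroring set_model_list above
        return "".join(str(v) for v in litellm_params.values())

    def strip_model_id(model: str) -> str:
        # drop everything from "-ModelID" onwards, mirroring completion()/acompletion()
        index_of_model_id = model.find("-ModelID")
        return model[:index_of_model_id] if index_of_model_id != -1 else model

    params = {"model": "azure/gpt-35-turbo", "api_key": "sk-example", "api_base": "https://example.openai.azure.com"}
    tagged = params["model"] + "-ModelID-" + make_model_id(params)
    assert strip_model_id(tagged) == "azure/gpt-35-turbo"

Unlike the previous random 5-digit suffix, this id is deterministic for a given (api_base, api_key) pair, so two deployments of the same underlying model stay distinguishable across restarts.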