(feat) use api_base, api_key as model

ishaan-jaff 2023-11-27 18:08:07 -08:00
parent 9cef551623
commit 50733363ee


@@ -87,6 +87,7 @@ class Router:
self.routing_strategy = routing_strategy
self.fallbacks = fallbacks or litellm.fallbacks
self.context_window_fallbacks = context_window_fallbacks or litellm.context_window_fallbacks
self.model_exception_map = {} # dict mapping model -> list of exceptions, e.g. {"gpt-3.5": ["API KEY Error", "Rate Limit Error", "good morning error"]}
# make Router.chat.completions.create compatible for openai.chat.completions.create
self.chat = litellm.Chat(params=default_litellm_params)
@@ -171,7 +172,16 @@ class Router:
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
data["model"] = data["model"][:-14]
########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"]
# Find the index of "ModelID" in the string
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
self.print_verbose(f"completion model: {data['model']}")
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
except Exception as e:
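For reference, the suffix-stripping logic added above (and repeated in acompletion below) amounts to a small helper. A minimal sketch; the strip_model_id name is illustrative, not part of the commit:

```python
def strip_model_id(model: str) -> str:
    """Drop the '-ModelID-<id>' suffix appended by set_model_list, if present."""
    index_of_model_id = model.find("-ModelID")
    return model[:index_of_model_id] if index_of_model_id != -1 else model

# The previous code sliced a fixed 14 characters off the end ("-ModelID-" plus
# a 5-digit id); find() handles variable-length IDs and leaves untagged names untouched.
assert strip_model_id("azure/gpt-35-turbo-ModelID-12345") == "azure/gpt-35-turbo"
assert strip_model_id("gpt-3.5-turbo") == "gpt-3.5-turbo"
```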
@@ -207,9 +217,15 @@ class Router:
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
data["model"] = data["model"][:-14]
self.print_verbose(f"acompletion model: {data['model']}")
########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"]
# Find the index of "ModelID" in the string
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
return response
except Exception as e:
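Callers are unaffected by the tagging: they keep passing the group model_name, and the router strips the internal suffix before dispatch. A hedged usage sketch (the key and endpoint values are placeholders):

```python
import asyncio
from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",  # the name callers use
        "litellm_params": {
            "model": "azure/gpt-35-turbo",
            "api_key": "sk-...",  # placeholder
            "api_base": "https://endpoint-1.openai.azure.com",
        },
    },
]

router = Router(model_list=model_list)

async def main():
    # Internally the chosen deployment's model is tagged "azure/gpt-35-turbo-ModelID-...";
    # acompletion() strips the tag before calling litellm.acompletion.
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(response)

asyncio.run(main())
```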
@@ -514,12 +530,26 @@ class Router:
start_time, end_time # start/end time
):
try:
exception = kwargs.get("exception", None)
exception_type = type(exception)
exception_status = getattr(exception, 'status_code', "")
exception_cause = getattr(exception, '__cause__', "")
exception_message = getattr(exception, 'message', "")
exception_str = str(exception_type) + " Status: " + str(exception_status) + " Message: " + str(exception_cause) + str(exception_message) + " Full exception: " + str(exception)
model_name = kwargs.get('model', None) # e.g. gpt35turbo
custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # e.g. azure
metadata = kwargs.get("litellm_params", {}).get('metadata', None)
if metadata:
deployment = metadata.get("deployment", None)
self._set_cooldown_deployments(deployment)
deployment_exceptions = self.model_exception_map.get(deployment, [])
deployment_exceptions.append(exception_str)
self.model_exception_map[deployment] = deployment_exceptions
self.print_verbose("\nEXCEPTION FOR DEPLOYMENTS\n")
self.print_verbose(self.model_exception_map)
for model in self.model_exception_map:
self.print_verbose(f"Model {model} had {len(self.model_exception_map[model])} exceptions")
self.print_verbose("")
if custom_llm_provider:
model_name = f"{custom_llm_provider}/{model_name}"
@@ -539,7 +569,7 @@ class Router:
# cooldown deployment
current_fails = self.failed_calls.get_cache(key=deployment) or 0
updated_fails = current_fails + 1
self.print_verbose(f"updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}")
self.print_verbose(f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}")
if updated_fails > self.allowed_fails:
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
@@ -731,9 +761,12 @@ class Router:
def set_model_list(self, model_list: list):
self.model_list = model_list
# we add a 5 digit uuid to each model so load balancing between azure/gpt on api_base1 and api_base2 works
# we add api_base/api_key to each model so load balancing between azure/gpt on api_base1 and api_base2 works
for model in self.model_list:
model["litellm_params"]["model"] += "-ModelID-" + str(random.randint(10000, 99999))[:5]
model_id = ""
for key in model["litellm_params"]:
model_id += str(model["litellm_params"][key])
model["litellm_params"]["model"] += "-ModelID-" + model_id
self.model_names = [m["model_name"] for m in model_list]
def get_model_names(self):
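With this change the ID is derived from the deployment's own params rather than a random 5-digit suffix, so two deployments that differ only in api_base (or api_key) get distinct, stable IDs, and retagging the same model_list is deterministic. A sketch of the round trip under that scheme:

```python
def make_model_id(litellm_params: dict) -> str:
    # Concatenate every param value, as set_model_list now does
    return "".join(str(v) for v in litellm_params.values())

params_1 = {"model": "azure/gpt-35-turbo", "api_key": "key-1", "api_base": "https://endpoint-1"}
params_2 = {"model": "azure/gpt-35-turbo", "api_key": "key-1", "api_base": "https://endpoint-2"}

assert make_model_id(params_1) != make_model_id(params_2)  # api_base disambiguates

tagged = params_1["model"] + "-ModelID-" + make_model_id(params_1)
# completion()/acompletion() strip everything from the first "-ModelID" onward:
assert tagged[:tagged.find("-ModelID")] == "azure/gpt-35-turbo"
```

One trade-off to note: because the concatenation includes every param value, the tagged model string can embed the raw api_key, so verbose logs that print the untagged-before-stripping name may leak credentials.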