Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-24 18:24:20 +00:00
(feat) use api_base, api_key as model

This commit is contained in:
parent 9cef551623
commit 50733363ee

1 changed file with 40 additions and 7 deletions
@@ -87,6 +87,7 @@ class Router:
        self.routing_strategy = routing_strategy
        self.fallbacks = fallbacks or litellm.fallbacks
        self.context_window_fallbacks = context_window_fallbacks or litellm.context_window_fallbacks
        self.model_exception_map = {}  # dict to store model: list exceptions. self.exceptions = {"gpt-3.5": ["API KEY Error", "Rate Limit Error", "good morning error"]}

        # make Router.chat.completions.create compatible for openai.chat.completions.create
        self.chat = litellm.Chat(params=default_litellm_params)
@@ -171,7 +172,16 @@ class Router:
            for k, v in self.default_litellm_params.items():
                if k not in data:  # prioritize model-specific params > default router params
                    data[k] = v
            data["model"] = data["model"][:-14]

            ########## remove -ModelID-XXXX from model ##############
            original_model_string = data["model"]
            # Find the index of "ModelID" in the string
            index_of_model_id = original_model_string.find("-ModelID")
            # Remove everything after "-ModelID" if it exists
            if index_of_model_id != -1:
                data["model"] = original_model_string[:index_of_model_id]
            else:
                data["model"] = original_model_string
            self.print_verbose(f"completion model: {data['model']}")
            return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
        except Exception as e:
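The hunk above swaps the old fixed-width slice, data["model"][:-14] (which only worked while the suffix was the 9-character "-ModelID-" tag plus a 5-digit random ID, i.e. exactly 14 characters), for a find()-based strip that also handles the new variable-length IDs. A minimal standalone sketch of that behaviour, using a hypothetical strip_model_id helper that is not part of the Router:

    def strip_model_id(model: str) -> str:
        # Drop everything from "-ModelID" onward; names without the tag pass through unchanged
        index_of_model_id = model.find("-ModelID")
        if index_of_model_id != -1:
            return model[:index_of_model_id]
        return model

    # Old-style 5-digit suffix and new-style long suffix both strip cleanly
    assert strip_model_id("azure/chatgpt-v-2-ModelID-67890") == "azure/chatgpt-v-2"
    assert strip_model_id("azure/chatgpt-v-2-ModelID-azure/chatgpt-v-2mykeymybase") == "azure/chatgpt-v-2"
    assert strip_model_id("gpt-3.5-turbo") == "gpt-3.5-turbo"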
@@ -207,9 +217,15 @@ class Router:
            for k, v in self.default_litellm_params.items():
                if k not in data:  # prioritize model-specific params > default router params
                    data[k] = v
            data["model"] = data["model"][:-14]
            self.print_verbose(f"acompletion model: {data['model']}")

            ########## remove -ModelID-XXXX from model ##############
            original_model_string = data["model"]
            # Find the index of "ModelID" in the string
            index_of_model_id = original_model_string.find("-ModelID")
            # Remove everything after "-ModelID" if it exists
            if index_of_model_id != -1:
                data["model"] = original_model_string[:index_of_model_id]
            else:
                data["model"] = original_model_string
            response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
            return response
        except Exception as e:
@@ -514,12 +530,26 @@ class Router:
        start_time, end_time  # start/end time
    ):
        try:
            exception = kwargs.get("exception", None)
            exception_type = type(exception)
            exception_status = getattr(exception, 'status_code', "")
            exception_cause = getattr(exception, '__cause__', "")
            exception_message = getattr(exception, 'message', "")
            exception_str = str(exception_type) + "Status: " + str(exception_status) + "Message: " + str(exception_cause) + str(exception_message) + "Full exception" + str(exception)
            model_name = kwargs.get('model', None)  # i.e. gpt35turbo
            custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None)  # i.e. azure
            metadata = kwargs.get("litellm_params", {}).get('metadata', None)
            if metadata:
                deployment = metadata.get("deployment", None)
                self._set_cooldown_deployments(deployment)
                deployment_exceptions = self.model_exception_map.get(deployment, [])
                deployment_exceptions.append(exception_str)
                self.model_exception_map[deployment] = deployment_exceptions
                self.print_verbose("\nEXCEPTION FOR DEPLOYMENTS\n")
                self.print_verbose(self.model_exception_map)
                for model in self.model_exception_map:
                    self.print_verbose(f"Model {model} had {len(self.model_exception_map[model])} exception")
                self.print_verbose()
            if custom_llm_provider:
                model_name = f"{custom_llm_provider}/{model_name}"
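The if metadata: block above keeps a running list of exception summaries per deployment in self.model_exception_map. A rough standalone sketch of that bookkeeping, using a plain dict and a hypothetical record_failure helper (the exception_str format is simplified from the code above):

    model_exception_map = {}  # deployment -> list of exception summaries

    def record_failure(deployment, exception):
        # Build a readable summary and append it to the deployment's list
        exception_str = f"{type(exception)} Status: {getattr(exception, 'status_code', '')} Message: {exception}"
        deployment_exceptions = model_exception_map.get(deployment, [])
        deployment_exceptions.append(exception_str)
        model_exception_map[deployment] = deployment_exceptions

    record_failure("azure/chatgpt-v-2-ModelID-1", Exception("Rate Limit Error"))
    record_failure("azure/chatgpt-v-2-ModelID-1", Exception("API KEY Error"))
    for model in model_exception_map:
        print(f"Model {model} had {len(model_exception_map[model])} exceptions")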
@@ -539,7 +569,7 @@ class Router:
                # cooldown deployment
                current_fails = self.failed_calls.get_cache(key=deployment) or 0
                updated_fails = current_fails + 1
                self.print_verbose(f"updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}")
                self.print_verbose(f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}")
                if updated_fails > self.allowed_fails:
                    # get the current cooldown list for that minute
                    cooldown_key = f"{current_minute}:cooldown_models"  # group cooldown models by minute to reduce number of redis calls
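A simplified sketch of the cooldown accounting in that hunk, with in-memory dicts standing in for the router's cache client; the per-minute key format and the allowed_fails threshold follow the lines above, while the minute-string format is an assumption:

    import datetime

    failed_calls = {}      # stand-in for the self.failed_calls cache
    cooldown_models = {}   # stand-in for the cached cooldown lists
    allowed_fails = 1

    def cooldown_deployment(deployment):
        # Count failures for this deployment
        current_fails = failed_calls.get(deployment) or 0
        updated_fails = current_fails + 1
        failed_calls[deployment] = updated_fails
        if updated_fails > allowed_fails:
            # group cooldown models by minute to reduce the number of cache/redis calls
            current_minute = datetime.datetime.now().strftime("%H-%M")  # assumed format
            cooldown_key = f"{current_minute}:cooldown_models"
            cooldown_models.setdefault(cooldown_key, []).append(deployment)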
@@ -731,9 +761,12 @@ class Router:

    def set_model_list(self, model_list: list):
        self.model_list = model_list
        # we add a 5 digit uuid to each model so load balancing between azure/gpt on api_base1 and api_base2 works
        # we add api_base/api_key to each model so load balancing between azure/gpt on api_base1 and api_base2 works
        for model in self.model_list:
            model["litellm_params"]["model"] += "-ModelID-" + str(random.randint(10000, 99999))[:5]
            model_id = ""
            for key in model["litellm_params"]:
                model_id += str(model["litellm_params"][key])
            model["litellm_params"]["model"] += "-ModelID-" + model_id
        self.model_names = [m["model_name"] for m in model_list]

    def get_model_names(self):
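This last hunk is the core of the commit: instead of tagging each deployment's model with a random 5-digit ID, the tag is now built from all of its litellm_params (model, api_key, api_base, ...), so two deployments of the same azure model on different api_bases get distinct, stable "-ModelID-" suffixes. A small sketch under an assumed router-config shape (the endpoints and keys are placeholders):

    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": "key-1",  # placeholder
                "api_base": "https://endpoint-1.openai.azure.com",
            },
        },
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": "key-2",  # placeholder
                "api_base": "https://endpoint-2.openai.azure.com",
            },
        },
    ]

    for model in model_list:
        model_id = ""
        for key in model["litellm_params"]:
            model_id += str(model["litellm_params"][key])
        model["litellm_params"]["model"] += "-ModelID-" + model_id

    # Both entries keep model_name "gpt-3.5-turbo", but the tagged
    # litellm_params["model"] strings now differ per api_base/api_key,
    # so the router can tell the two deployments apart when load balancing.
    print(model_list[0]["litellm_params"]["model"])
    print(model_list[1]["litellm_params"]["model"])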