(feat) router - remove confusing -ModelID-XXXX

commit b6a78cd349
parent 18c8b298bd
Author: ishaan-jaff
Date: 2023-12-15 07:06:02 +05:30


@@ -205,17 +205,6 @@ class Router:
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"]
# Find the index of "ModelID" in the string
self.print_verbose(f"completion model: {original_model_string}")
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
except Exception as e:
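For context, the deleted block re-derived the clean model name at request time by stripping the "-ModelID-XXXX" suffix. A minimal sketch of that logic, with a hypothetical suffixed value, is below; the same block is removed from text_completion, embedding, and aembedding in the later hunks, since model names are no longer suffixed at init time.

original_model_string = "azure/chat-gpt-v-2-ModelID-12345"  # hypothetical suffixed name
index_of_model_id = original_model_string.find("-ModelID")
# drop everything from "-ModelID" onward, if present
model = original_model_string[:index_of_model_id] if index_of_model_id != -1 else original_model_string
assert model == "azure/chat-gpt-v-2"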
@@ -246,31 +235,23 @@ class Router:
**kwargs):
try:
self.print_verbose(f"Inside _acompletion()- model: {model}; kwargs: {kwargs}")
original_model_string = None # set a default for this variable
deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"]
# Find the index of "ModelID" in the string
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
self.total_calls[original_model_string] +=1
self.total_calls[model_name] +=1
response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
self.success_calls[original_model_string] +=1
self.success_calls[model_name] +=1
return response
except Exception as e:
if original_model_string is not None:
self.fail_calls[original_model_string] +=1
if model_name is not None:
self.fail_calls[model_name] +=1
raise e
def text_completion(self,
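In _acompletion, the per-deployment call counters are now keyed by the deployment's plain model name (model_name = data["model"]) rather than the suffixed original_model_string, which also lets the stripping block and the original_model_string default go away. A minimal sketch of the counter flow, outside the Router class and with a hypothetical deployment value:

from collections import defaultdict

total_calls, success_calls, fail_calls = defaultdict(int), defaultdict(int), defaultdict(int)
model_name = "azure/chat-gpt-v-2"  # hypothetical litellm_params["model"] value
try:
    total_calls[model_name] += 1
    response = "..."  # stand-in for `await litellm.acompletion(**data, ...)`
    success_calls[model_name] += 1
except Exception:
    fail_calls[model_name] += 1
    raise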
@@ -290,15 +271,6 @@ class Router:
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"]
# Find the index of "ModelID" in the string
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
# call via litellm.completion()
return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
except Exception as e:
@@ -363,15 +335,6 @@ class Router:
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"]
# Find the index of "ModelID" in the string
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
# call via litellm.embedding()
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
@@ -389,15 +352,7 @@ class Router:
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"]
# Find the index of "ModelID" in the string
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
@@ -681,9 +636,10 @@ class Router:
model_name = kwargs.get('model', None) # i.e. gpt35turbo
custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure
metadata = kwargs.get("litellm_params", {}).get('metadata', None)
deployment_id = kwargs.get("litellm_params", {}).get("model_info").get("id")
self._set_cooldown_deployments(deployment_id) # setting deployment_id in cooldown deployments
if metadata:
deployment = metadata.get("deployment", None)
self._set_cooldown_deployments(deployment)
deployment_exceptions = self.model_exception_map.get(deployment, [])
deployment_exceptions.append(exception_str)
self.model_exception_map[deployment] = deployment_exceptions
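The failure callback now cools a deployment down by its model_info["id"] instead of by the deployment model string pulled from metadata. A sketch of the new keying, assuming each model_list entry carries a stable model_info["id"]:

kwargs = {"litellm_params": {"model_info": {"id": "16700539-b3cd-42f4-b426-6a12a1bb706a"}}}  # illustrative
deployment_id = kwargs.get("litellm_params", {}).get("model_info", {}).get("id")
# router._set_cooldown_deployments(deployment_id)  # cooldown is tracked per id, not per model string
print(deployment_id)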
@@ -1094,12 +1050,6 @@ class Router:
############ End of initializing Clients for OpenAI/Azure ###################
self.deployment_names.append(model["litellm_params"]["model"])
model_id = ""
for key in model["litellm_params"]:
if key != "api_key" and key != "metadata":
model_id+= str(model["litellm_params"][key])
model["litellm_params"]["model"] += "-ModelID-" + model_id
self.print_verbose(f"\n Initialized Model List {self.model_list}")
############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
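The deleted initialization loop is what produced the "-ModelID-XXXX" names in the first place: it concatenated every non-secret litellm_param into an id and appended it to the model string. Roughly what it did, with illustrative params:

litellm_params = {"model": "azure/chat-gpt-v-2", "api_base": "https://example-endpoint", "api_version": "2023-07-01-preview"}  # illustrative
model_id = ""
for key in litellm_params:
    if key != "api_key" and key != "metadata":
        model_id += str(litellm_params[key])
suffixed = litellm_params["model"] + "-ModelID-" + model_id
# e.g. "azure/chat-gpt-v-2-ModelID-azure/chat-gpt-v-2https://example-endpoint2023-07-01-preview"
# which is the confusing name this commit stops generating.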
@@ -1160,8 +1110,8 @@ class Router:
if specific_deployment == True:
# users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
for deployment in self.model_list:
cleaned_model = litellm.utils.remove_model_id(deployment.get("litellm_params").get("model"))
if cleaned_model == model:
deployment_model = deployment.get("litellm_params").get("model")
if deployment_model == model:
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
# return the first deployment where the `model` matches the specified deployment name
return deployment
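With the suffix gone, the specific_deployment lookup no longer needs litellm.utils.remove_model_id; it can compare the stored litellm_params["model"] directly against the requested name. A minimal sketch with an assumed model_list shape:

model_list = [  # assumed shape, illustrative values
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "azure/chat-gpt-v-2"}},
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "azure/chat-gpt-v-3"}},
]
requested = "azure/chat-gpt-v-2"
deployment = next((d for d in model_list if d["litellm_params"]["model"] == requested), None)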
@@ -1174,24 +1124,27 @@ class Router:
## get healthy deployments
### get all deployments
### filter out the deployments currently cooling down
healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
if len(healthy_deployments) == 0:
# check if the user sent in a deployment name instead
healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model]
self.print_verbose(f"initial list of deployments: {healthy_deployments}")
# filter out the deployments currently cooling down
deployments_to_remove = []
cooldown_deployments = self._get_cooldown_deployments()
# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
cooldown_deployments = self._get_cooldown_deployments()
self.print_verbose(f"cooldown deployments: {cooldown_deployments}")
### FIND UNHEALTHY DEPLOYMENTS
# Find deployments in model_list whose model_id is cooling down
for deployment in healthy_deployments:
deployment_name = deployment["litellm_params"]["model"]
if deployment_name in cooldown_deployments:
deployment_id = deployment["model_info"]["id"]
if deployment_id in cooldown_deployments:
deployments_to_remove.append(deployment)
### FILTER OUT UNHEALTHY DEPLOYMENTS
# remove unhealthy deployments from healthy deployments
for deployment in deployments_to_remove:
healthy_deployments.remove(deployment)
self.print_verbose(f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}")
if len(healthy_deployments) == 0:
raise ValueError("No models available")
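Cooldown filtering in get_available_deployment is likewise keyed by model_info["id"]: _get_cooldown_deployments returns a list of ids, and any healthy deployment whose id appears in it is dropped. A condensed sketch under the same model_info assumption:

healthy_deployments = [  # illustrative entries; each deployment is assumed to carry model_info["id"]
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "azure/chat-gpt-v-2"}, "model_info": {"id": "abc-1"}},
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "azure/chat-gpt-v-3"}, "model_info": {"id": "abc-2"}},
]
cooldown_deployments = ["abc-2"]  # stand-in for self._get_cooldown_deployments()
healthy_deployments = [d for d in healthy_deployments if d["model_info"]["id"] not in cooldown_deployments]
if len(healthy_deployments) == 0:
    raise ValueError("No models available")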