mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 10:14:26 +00:00
(fix) router: allow same model/name
This commit is contained in:
parent
ba228a9e0a
commit
4265f9b2ef
1 changed files with 8 additions and 4 deletions
|
@ -166,11 +166,11 @@ class Router:
|
|||
# pick the one that is available (lowest TPM/RPM)
|
||||
deployment = self.get_available_deployment(model=model, messages=messages)
|
||||
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
|
||||
data = deployment["litellm_params"]
|
||||
data = deployment["litellm_params"].copy()
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if k not in data: # prioritize model-specific params > default router params
|
||||
data[k] = v
|
||||
|
||||
data["model"] = data["model"][:-14]
|
||||
self.print_verbose(f"completion model: {data['model']}")
|
||||
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
|
||||
except Exception as e:
|
||||
|
@ -202,10 +202,11 @@ class Router:
|
|||
self.print_verbose(f"Inside _acompletion()- model: {model}; kwargs: {kwargs}")
|
||||
deployment = self.get_available_deployment(model=model, messages=messages)
|
||||
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
|
||||
data = deployment["litellm_params"]
|
||||
data = deployment["litellm_params"].copy()
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if k not in data: # prioritize model-specific params > default router params
|
||||
data[k] = v
|
||||
data["model"] = data["model"][:-14]
|
||||
self.print_verbose(f"acompletion model: {data['model']}")
|
||||
|
||||
response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
|
||||
|
@ -722,6 +723,9 @@ class Router:
|
|||
|
||||
def set_model_list(self, model_list: list):
|
||||
self.model_list = model_list
|
||||
# we add a 5 digit uuid to each model so load balancing between azure/gpt on api_base1 and api_base2 works
|
||||
for model in self.model_list:
|
||||
model["litellm_params"]["model"] += "-ModelID-" + str(random.randint(10000, 99999))[:5]
|
||||
self.model_names = [m["model_name"] for m in model_list]
|
||||
|
||||
def get_model_names(self):
|
||||
|
@ -757,7 +761,7 @@ class Router:
|
|||
### FILTER OUT UNHEALTHY DEPLOYMENTS
|
||||
for deployment in deployments_to_remove:
|
||||
healthy_deployments.remove(deployment)
|
||||
self.print_verbose(f"healthy deployments: {healthy_deployments}")
|
||||
self.print_verbose(f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}")
|
||||
if len(healthy_deployments) == 0:
|
||||
raise ValueError("No models available")
|
||||
if litellm.model_alias_map and model in litellm.model_alias_map:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue