fix(router.py): removing model id before making call

This commit is contained in:
Krrish Dholakia 2023-11-28 10:09:45 -08:00
parent 5ed957ebbe
commit 094144de58
2 changed files with 33 additions and 4 deletions

View file

@ -244,10 +244,19 @@ class Router:
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, messages=messages)
data = deployment["litellm_params"]
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"]
# Find the index of "ModelID" in the string
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
# call via litellm.completion()
return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
except Exception as e:
@ -268,10 +277,19 @@ class Router:
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, input=input)
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
data = deployment["litellm_params"]
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"]
# Find the index of "ModelID" in the string
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
# call via litellm.embedding()
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
@ -283,10 +301,19 @@ class Router:
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, input=input)
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
data = deployment["litellm_params"]
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"]
# Find the index of "ModelID" in the string
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
async def async_function_with_fallbacks(self, *args, **kwargs):