forked from phoenix/litellm-mirror
fix(router.py): removing model id before making call
This commit is contained in:
parent
5ed957ebbe
commit
094144de58
2 changed files with 33 additions and 4 deletions
|
@ -244,10 +244,19 @@ class Router:
|
||||||
# pick the one that is available (lowest TPM/RPM)
|
# pick the one that is available (lowest TPM/RPM)
|
||||||
deployment = self.get_available_deployment(model=model, messages=messages)
|
deployment = self.get_available_deployment(model=model, messages=messages)
|
||||||
|
|
||||||
data = deployment["litellm_params"]
|
data = deployment["litellm_params"].copy()
|
||||||
for k, v in self.default_litellm_params.items():
|
for k, v in self.default_litellm_params.items():
|
||||||
if k not in data: # prioritize model-specific params > default router params
|
if k not in data: # prioritize model-specific params > default router params
|
||||||
data[k] = v
|
data[k] = v
|
||||||
|
########## remove -ModelID-XXXX from model ##############
|
||||||
|
original_model_string = data["model"]
|
||||||
|
# Find the index of "ModelID" in the string
|
||||||
|
index_of_model_id = original_model_string.find("-ModelID")
|
||||||
|
# Remove everything after "-ModelID" if it exists
|
||||||
|
if index_of_model_id != -1:
|
||||||
|
data["model"] = original_model_string[:index_of_model_id]
|
||||||
|
else:
|
||||||
|
data["model"] = original_model_string
|
||||||
# call via litellm.completion()
|
# call via litellm.completion()
|
||||||
return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
|
return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -268,10 +277,19 @@ class Router:
|
||||||
# pick the one that is available (lowest TPM/RPM)
|
# pick the one that is available (lowest TPM/RPM)
|
||||||
deployment = self.get_available_deployment(model=model, input=input)
|
deployment = self.get_available_deployment(model=model, input=input)
|
||||||
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
|
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
|
||||||
data = deployment["litellm_params"]
|
data = deployment["litellm_params"].copy()
|
||||||
for k, v in self.default_litellm_params.items():
|
for k, v in self.default_litellm_params.items():
|
||||||
if k not in data: # prioritize model-specific params > default router params
|
if k not in data: # prioritize model-specific params > default router params
|
||||||
data[k] = v
|
data[k] = v
|
||||||
|
########## remove -ModelID-XXXX from model ##############
|
||||||
|
original_model_string = data["model"]
|
||||||
|
# Find the index of "ModelID" in the string
|
||||||
|
index_of_model_id = original_model_string.find("-ModelID")
|
||||||
|
# Remove everything after "-ModelID" if it exists
|
||||||
|
if index_of_model_id != -1:
|
||||||
|
data["model"] = original_model_string[:index_of_model_id]
|
||||||
|
else:
|
||||||
|
data["model"] = original_model_string
|
||||||
# call via litellm.embedding()
|
# call via litellm.embedding()
|
||||||
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
|
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
|
||||||
|
|
||||||
|
@ -283,10 +301,19 @@ class Router:
|
||||||
# pick the one that is available (lowest TPM/RPM)
|
# pick the one that is available (lowest TPM/RPM)
|
||||||
deployment = self.get_available_deployment(model=model, input=input)
|
deployment = self.get_available_deployment(model=model, input=input)
|
||||||
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
|
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
|
||||||
data = deployment["litellm_params"]
|
data = deployment["litellm_params"].copy()
|
||||||
for k, v in self.default_litellm_params.items():
|
for k, v in self.default_litellm_params.items():
|
||||||
if k not in data: # prioritize model-specific params > default router params
|
if k not in data: # prioritize model-specific params > default router params
|
||||||
data[k] = v
|
data[k] = v
|
||||||
|
########## remove -ModelID-XXXX from model ##############
|
||||||
|
original_model_string = data["model"]
|
||||||
|
# Find the index of "ModelID" in the string
|
||||||
|
index_of_model_id = original_model_string.find("-ModelID")
|
||||||
|
# Remove everything after "-ModelID" if it exists
|
||||||
|
if index_of_model_id != -1:
|
||||||
|
data["model"] = original_model_string[:index_of_model_id]
|
||||||
|
else:
|
||||||
|
data["model"] = original_model_string
|
||||||
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
|
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
|
||||||
|
|
||||||
async def async_function_with_fallbacks(self, *args, **kwargs):
|
async def async_function_with_fallbacks(self, *args, **kwargs):
|
||||||
|
|
|
@ -219,7 +219,7 @@ def test_acompletion_on_router():
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
test_acompletion_on_router()
|
# test_acompletion_on_router()
|
||||||
|
|
||||||
def test_function_calling_on_router():
|
def test_function_calling_on_router():
|
||||||
try:
|
try:
|
||||||
|
@ -272,6 +272,7 @@ def test_function_calling_on_router():
|
||||||
# test_function_calling_on_router()
|
# test_function_calling_on_router()
|
||||||
|
|
||||||
def test_aembedding_on_router():
|
def test_aembedding_on_router():
|
||||||
|
litellm.set_verbose = True
|
||||||
try:
|
try:
|
||||||
model_list = [
|
model_list = [
|
||||||
{
|
{
|
||||||
|
@ -296,3 +297,4 @@ def test_aembedding_on_router():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
test_aembedding_on_router()
|
Loading…
Add table
Add a link
Reference in a new issue