From 094144de58de7604deb47c525bd3ae72576560d4 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 28 Nov 2023 10:09:45 -0800 Subject: [PATCH] fix(router.py): removing model id before making call --- litellm/router.py | 33 ++++++++++++++++++++++++++++++--- litellm/tests/test_router.py | 4 +++- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index c725b0b326..e9632d686c 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -244,10 +244,19 @@ class Router: # pick the one that is available (lowest TPM/RPM) deployment = self.get_available_deployment(model=model, messages=messages) - data = deployment["litellm_params"] + data = deployment["litellm_params"].copy() for k, v in self.default_litellm_params.items(): if k not in data: # prioritize model-specific params > default router params data[k] = v + ########## remove -ModelID-XXXX from model ############## + original_model_string = data["model"] + # Find the index of "ModelID" in the string + index_of_model_id = original_model_string.find("-ModelID") + # Remove everything after "-ModelID" if it exists + if index_of_model_id != -1: + data["model"] = original_model_string[:index_of_model_id] + else: + data["model"] = original_model_string # call via litellm.completion() return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore except Exception as e: @@ -268,10 +277,19 @@ class Router: # pick the one that is available (lowest TPM/RPM) deployment = self.get_available_deployment(model=model, input=input) kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) - data = deployment["litellm_params"] + data = deployment["litellm_params"].copy() for k, v in self.default_litellm_params.items(): if k not in data: # prioritize model-specific params > default router params data[k] = v + ########## remove -ModelID-XXXX from model ############## + original_model_string = data["model"] + # Find the index of "ModelID" in the string + index_of_model_id = original_model_string.find("-ModelID") + # Remove everything after "-ModelID" if it exists + if index_of_model_id != -1: + data["model"] = original_model_string[:index_of_model_id] + else: + data["model"] = original_model_string # call via litellm.embedding() return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs}) @@ -283,10 +301,19 @@ class Router: # pick the one that is available (lowest TPM/RPM) deployment = self.get_available_deployment(model=model, input=input) kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) - data = deployment["litellm_params"] + data = deployment["litellm_params"].copy() for k, v in self.default_litellm_params.items(): if k not in data: # prioritize model-specific params > default router params data[k] = v + ########## remove -ModelID-XXXX from model ############## + original_model_string = data["model"] + # Find the index of "ModelID" in the string + index_of_model_id = original_model_string.find("-ModelID") + # Remove everything after "-ModelID" if it exists + if index_of_model_id != -1: + data["model"] = original_model_string[:index_of_model_id] + else: + data["model"] = original_model_string return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs}) async def async_function_with_fallbacks(self, *args, **kwargs): diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 92380034af..46c9ae97c2 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -219,7 +219,7 @@ def test_acompletion_on_router(): traceback.print_exc() pytest.fail(f"Error occurred: {e}") -test_acompletion_on_router() +# test_acompletion_on_router() def test_function_calling_on_router(): try: @@ -272,6 +272,7 @@ def test_function_calling_on_router(): # test_function_calling_on_router() def test_aembedding_on_router(): + litellm.set_verbose = True try: model_list = [ { @@ -296,3 +297,4 @@ def test_aembedding_on_router(): except Exception as e: traceback.print_exc() pytest.fail(f"Error occurred: {e}") +test_aembedding_on_router() \ No newline at end of file