diff --git a/litellm/router.py b/litellm/router.py
index 5d1adf2b5..b100cfbb0 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -1564,6 +1564,23 @@ class Router:
             ):
                 model["litellm_params"]["tpm"] = model.get("tpm")
 
+            #### VALIDATE MODEL ########
+            # check if model provider in supported providers
+            (
+                _model,  # unused here; don't overwrite the deployment dict `model`
+                custom_llm_provider,
+                dynamic_api_key,
+                api_base,
+            ) = litellm.get_llm_provider(
+                model=model["litellm_params"]["model"],
+                custom_llm_provider=model["litellm_params"].get(
+                    "custom_llm_provider", None
+                ),
+            )
+
+            if custom_llm_provider not in litellm.provider_list:
+                raise Exception(f"Unsupported provider - {custom_llm_provider}")
+
             self.set_client(model=model)
 
         self.print_verbose(f"\nInitialized Model List {self.model_list}")
diff --git a/litellm/router_strategy/lowest_tpm_rpm.py b/litellm/router_strategy/lowest_tpm_rpm.py
index 4b492cded..9ed458804 100644
--- a/litellm/router_strategy/lowest_tpm_rpm.py
+++ b/litellm/router_strategy/lowest_tpm_rpm.py
@@ -182,7 +182,7 @@ class LowestTPMLoggingHandler(CustomLogger):
                 break
             elif (
                 item_tpm + input_tokens > _deployment_tpm
-                or rpm_dict[item] + 1 >= _deployment_rpm
+                or rpm_dict[item] + 1 > _deployment_rpm
             ):  # if user passed in tpm / rpm in the model_list
                 continue
             elif item_tpm < lowest_tpm:
diff --git a/litellm/tests/test_router_caching.py b/litellm/tests/test_router_caching.py
index d93288fce..74a572c46 100644
--- a/litellm/tests/test_router_caching.py
+++ b/litellm/tests/test_router_caching.py
@@ -209,47 +209,3 @@ async def test_acompletion_caching_on_router_caching_groups():
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
-
-
-def test_usage_based_routing_completion():
-    litellm.set_verbose = True
-    model_list = [
-        {
-            "model_name": "gpt-3.5-turbo",
-            "litellm_params": {
-                "model": "gpt-3.5-turbo-0301",
-                "api_key": os.getenv("OPENAI_API_KEY"),
-                "custom_llm_provider": "Custom-LLM",
-            },
-            "tpm": 10000,
-            "rpm": 5,
-        },
-        {
-            "model_name": "gpt-3.5-turbo",
-            "litellm_params": {
-                "model": "gpt-3.5-turbo-0301",
-                "api_key": os.getenv("OPENAI_API_KEY"),
-            },
-            "tpm": 10000,
-            "rpm": 5,
-        },
-    ]
-    router = Router(
-        model_list=model_list, routing_strategy="usage-based-routing", set_verbose=False
-    )
-    max_requests = 5
-    while max_requests > 0:
-        try:
-            router.completion(
-                model="gpt-3.5-turbo",
-                messages=[{"content": "write a one sentence poem.", "role": "user"}],
-            )
-        except ValueError as e:
-            traceback.print_exc()
-            pytest.fail(f"Error occurred: {e}")
-        finally:
-            max_requests -= 1
-    router.reset()
-
-
-test_usage_based_routing_completion()
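
Reviewer note: a minimal sketch of the behavior the router.py validation is meant to
enforce, assuming "Custom-LLM" is not in litellm.provider_list (the model name and
API key below are placeholders, not taken from this patch):

    from litellm import Router

    # A deployment with an unknown custom_llm_provider should now fail fast
    # at Router construction instead of erroring later at request time.
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo-0301",
                "api_key": "sk-placeholder",          # hypothetical key
                "custom_llm_provider": "Custom-LLM",  # not a supported provider
            },
        }
    ]

    try:
        Router(model_list=model_list)
    except Exception as e:
        print(f"rejected at init: {e}")

This fail-fast behavior is presumably why test_usage_based_routing_completion is
deleted in this patch: its first deployment passed "custom_llm_provider": "Custom-LLM",
which the router now rejects before any completion call is made.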
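
On the lowest_tpm_rpm.py change: with the old ">=" comparison, a deployment with
"rpm": 5 was skipped once 4 requests had been routed to it, because 4 + 1 >= 5, so
the configured limit could never actually be reached; the relaxed ">" comparison
admits the 5th request and only skips the deployment once the limit would be
exceeded. A worked example with illustrative values:

    # Boundary behavior for a deployment with an rpm limit of 5.
    _deployment_rpm = 5
    current_rpm = 4  # requests already routed to this deployment in the window

    old_skip = current_rpm + 1 >= _deployment_rpm  # True  -> 5th request was blocked
    new_skip = current_rpm + 1 > _deployment_rpm   # False -> 5th request is allowed
    print(old_skip, new_skip)  # prints: True False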