diff --git a/litellm/router.py b/litellm/router.py
index 39aa0f41c..9da7488ca 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -686,48 +686,72 @@ class Router:
         is_async: Optional[bool] = True,
         **kwargs,
     ) -> Union[List[float], None]:
-        # pick the one that is available (lowest TPM/RPM)
-        deployment = self.get_available_deployment(
-            model=model,
-            input=input,
-            specific_deployment=kwargs.pop("specific_deployment", None),
-        )
-        kwargs.setdefault("metadata", {}).update(
-            {"model_group": model, "deployment": deployment["litellm_params"]["model"]}
-        )
-        data = deployment["litellm_params"].copy()
-        kwargs["model_info"] = deployment.get("model_info", {})
-        for k, v in self.default_litellm_params.items():
+        try:
+            kwargs["model"] = model
+            kwargs["input"] = input
+            kwargs["original_function"] = self._aembedding
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            response = await self.async_function_with_fallbacks(**kwargs)
+            return response
+        except Exception as e:
+            raise e
+
+    async def _aembedding(self, input: Union[str, List], model: str, **kwargs):
+        try:
+            self.print_verbose(
+                f"Inside _aembedding()- model: {model}; kwargs: {kwargs}"
+            )
+            deployment = self.get_available_deployment(
+                model=model,
+                input=input,
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            kwargs.setdefault("metadata", {}).update(
+                {"deployment": deployment["litellm_params"]["model"]}
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="async"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
             if (
-                k not in kwargs
-            ):  # prioritize model-specific params > default router params
-                kwargs[k] = v
-            elif k == "metadata":
-                kwargs[k].update(v)
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
+            else:
+                model_client = potential_model_client
 
-        potential_model_client = self._get_client(
-            deployment=deployment, kwargs=kwargs, client_type="async"
-        )
-        # check if provided keys == client keys #
-        dynamic_api_key = kwargs.get("api_key", None)
-        if (
-            dynamic_api_key is not None
-            and potential_model_client is not None
-            and dynamic_api_key != potential_model_client.api_key
-        ):
-            model_client = None
-        else:
-            model_client = potential_model_client
-
-        return await litellm.aembedding(
-            **{
-                **data,
-                "input": input,
-                "caching": self.cache_responses,
-                "client": model_client,
-                **kwargs,
-            }
-        )
+            self.total_calls[model_name] += 1
+            response = await litellm.aembedding(
+                **{
+                    **data,
+                    "input": input,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    **kwargs,
+                }
+            )
+            self.success_calls[model_name] += 1
+            return response
+        except Exception as e:
+            if model_name is not None:
+                self.fail_calls[model_name] += 1
+            raise e
 
     async def async_function_with_fallbacks(self, *args, **kwargs):
         """
@@ -1200,65 +1224,6 @@ class Router:
         self.print_verbose(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models
 
-    def _start_health_check_thread(self):
-        """
-        Starts a separate thread to perform health checks periodically.
-        """
-        health_check_thread = threading.Thread(
-            target=self._perform_health_checks, daemon=True
-        )
-        health_check_thread.start()
-
-    def _perform_health_checks(self):
-        """
-        Periodically performs health checks on the servers.
-        Updates the list of healthy servers accordingly.
-        """
-        while True:
-            self.healthy_deployments = self._health_check()
-            # Adjust the time interval based on your needs
-            time.sleep(15)
-
-    def _health_check(self):
-        """
-        Performs a health check on the deployments
-        Returns the list of healthy deployments
-        """
-        healthy_deployments = []
-        for deployment in self.model_list:
-            litellm_args = deployment["litellm_params"]
-            try:
-                start_time = time.time()
-                litellm.completion(
-                    messages=[{"role": "user", "content": ""}],
-                    max_tokens=1,
-                    **litellm_args,
-                )  # hit the server with a blank message to see how long it takes to respond
-                end_time = time.time()
-                response_time = end_time - start_time
-                logging.debug(f"response_time: {response_time}")
-                healthy_deployments.append((deployment, response_time))
-                healthy_deployments.sort(key=lambda x: x[1])
-            except Exception as e:
-                pass
-        return healthy_deployments
-
-    def weighted_shuffle_by_latency(self, items):
-        # Sort the items by latency
-        sorted_items = sorted(items, key=lambda x: x[1])
-        # Get only the latencies
-        latencies = [i[1] for i in sorted_items]
-        # Calculate the sum of all latencies
-        total_latency = sum(latencies)
-        # Calculate the weight for each latency (lower latency = higher weight)
-        weights = [total_latency - latency for latency in latencies]
-        # Get a weighted random item
-        if sum(weights) == 0:
-            chosen_item = random.choice(sorted_items)[0]
-        else:
-            chosen_item = random.choices(sorted_items, weights=weights, k=1)[0][0]
-        return chosen_item
-
     def set_client(self, model: dict):
         """
         Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
diff --git a/litellm/router_strategy/lowest_tpm_rpm.py b/litellm/router_strategy/lowest_tpm_rpm.py
index 9ed458804..94c6de983 100644
--- a/litellm/router_strategy/lowest_tpm_rpm.py
+++ b/litellm/router_strategy/lowest_tpm_rpm.py
@@ -151,7 +151,6 @@ class LowestTPMLoggingHandler(CustomLogger):
             ## if healthy deployment not yet used
             if d["model_info"]["id"] not in all_deployments:
                 all_deployments[d["model_info"]["id"]] = 0
-        input_tokens = token_counter(messages=messages, text=input)
 
         for item, item_tpm in all_deployments.items():
             ## get the item from model list
diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py
index d7ffd446d..6d3cd6e43 100644
--- a/litellm/tests/test_router_fallbacks.py
+++ b/litellm/tests/test_router_fallbacks.py
@@ -227,6 +227,60 @@ async def test_async_fallbacks():
 # test_async_fallbacks()
 
 
+@pytest.mark.asyncio
+async def test_async_fallbacks_embeddings():
+    litellm.set_verbose = False
+    model_list = [
+        {  # list of model deployments
+            "model_name": "bad-azure-embedding-model",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/azure-embedding-model",
+                "api_key": "bad-key",
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+        {  # list of model deployments
+            "model_name": "good-azure-embedding-model",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/azure-embedding-model",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+    ]
+
+    router = Router(
+        model_list=model_list,
+        fallbacks=[{"bad-azure-embedding-model": ["good-azure-embedding-model"]}],
+        set_verbose=False,
+    )
+    customHandler = MyCustomHandler()
+    litellm.callbacks = [customHandler]
+    user_message = "Hello, how are you?"
+    input = [user_message]
+    try:
+        kwargs = {"model": "bad-azure-embedding-model", "input": input}
+        response = await router.aembedding(**kwargs)
+        print(f"customHandler.previous_models: {customHandler.previous_models}")
+        await asyncio.sleep(
+            0.05
+        )  # allow a delay as success_callbacks are on a separate thread
+        assert customHandler.previous_models == 1  # 0 retries, 1 fallback
+        router.reset()
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {e}")
+    finally:
+        router.reset()
+
+
 def test_dynamic_fallbacks_sync():
     """
     Allow setting the fallback in the router.completion() call.
diff --git a/litellm/tests/test_tpm_rpm_routing.py b/litellm/tests/test_tpm_rpm_routing.py
index 6f45d1658..3ce43f66e 100644
--- a/litellm/tests/test_tpm_rpm_routing.py
+++ b/litellm/tests/test_tpm_rpm_routing.py
@@ -122,12 +122,16 @@ def test_get_available_deployments():
     ## CHECK WHAT'S SELECTED ##
     print(
         lowest_tpm_logger.get_available_deployments(
-            model_group=model_group, healthy_deployments=model_list
+            model_group=model_group,
+            healthy_deployments=model_list,
+            input=["Hello world"],
         )
     )
     assert (
         lowest_tpm_logger.get_available_deployments(
-            model_group=model_group, healthy_deployments=model_list
+            model_group=model_group,
+            healthy_deployments=model_list,
+            input=["Hello world"],
         )["model_info"]["id"]
         == "5678"
     )
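
For reference, the caller-facing flow that the new test exercises looks roughly like this. This is a minimal usage sketch, not part of the patch; it reuses the deployment names, env vars, and fallback mapping from test_async_fallbacks_embeddings above.

import asyncio
import os

from litellm import Router

# Two deployments of the same Azure embedding model: one with a bad key, one good.
model_list = [
    {
        "model_name": "bad-azure-embedding-model",
        "litellm_params": {
            "model": "azure/azure-embedding-model",
            "api_key": "bad-key",
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
    {
        "model_name": "good-azure-embedding-model",
        "litellm_params": {
            "model": "azure/azure-embedding-model",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
]

router = Router(
    model_list=model_list,
    fallbacks=[{"bad-azure-embedding-model": ["good-azure-embedding-model"]}],
)


async def main():
    # aembedding() now routes through async_function_with_fallbacks(), so the
    # bad deployment's auth failure triggers the fallback instead of raising.
    response = await router.aembedding(
        model="bad-azure-embedding-model", input=["Hello, how are you?"]
    )
    print(response)


asyncio.run(main())

Routing embeddings through async_function_with_fallbacks() is what lets the router's fallbacks and num_retries configuration apply to router.aembedding() the same way it already does for completion calls.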