forked from phoenix/litellm-mirror

fix(router.py): make get_cooldown_deployment logic async

parent a47a719caa · commit 2531701a2a
2 changed files with 228 additions and 59 deletions

litellm/router.py
@@ -13,7 +13,7 @@ from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO
 import random, threading, time, traceback, uuid
 import litellm, openai, hashlib, json
 from litellm.caching import RedisCache, InMemoryCache, DualCache
-
+import datetime as datetime_og
 import logging, asyncio
 import inspect, concurrent
 from openai import AsyncOpenAI
@@ -414,7 +414,7 @@ class Router:
             verbose_router_logger.debug(
                 f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
             )
-            deployment = self.get_available_deployment(
+            deployment = await self.async_get_available_deployment(
                 model=model,
                 messages=messages,
                 specific_deployment=kwargs.pop("specific_deployment", None),
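The hunk above is the core of the fix: deployment selection inside the async completion path is now awaited instead of called synchronously. A minimal sketch (illustrative names and timings, not from this commit) of why a blocking lookup inside a coroutine caps throughput while an awaited one does not:

```python
import asyncio
import time

async def blocking_pick():
    time.sleep(0.01)  # stands in for a sync Redis round-trip: blocks the whole event loop

async def awaited_pick():
    await asyncio.sleep(0.01)  # stands in for an awaited cache call: yields to other tasks

async def bench(picker, n=100):
    start = time.perf_counter()
    await asyncio.gather(*(picker() for _ in range(n)))
    return time.perf_counter() - start

async def main():
    print(f"blocking: {await bench(blocking_pick):.2f}s")  # ~1s: the 100 picks serialize
    print(f"awaited:  {await bench(awaited_pick):.2f}s")   # ~0.01s: the 100 picks overlap

asyncio.run(main())
```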
@@ -1605,7 +1605,7 @@ class Router:
         if deployment is None:
             return

-        current_minute = datetime.now().strftime("%H-%M")
+        current_minute = datetime.now(datetime_og.UTC).strftime("%H-%M")
         # get current fails for deployment
         # update the number of failed calls
         # if it's > allowed fails
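The timezone change above matters because the cooldown key is bucketed by minute: a writer using local time and a reader using UTC can compute different buckets for the same instant and silently miss each other's entries. A small illustration (written with `timezone.utc`, which predates the `datetime.UTC` alias added in Python 3.11):

```python
from datetime import datetime, timezone

# On any machine whose local offset from UTC is non-zero, these two "%H-%M"
# buckets differ, so a local-time writer and a UTC reader would use
# different cooldown keys for the same minute.
local_key = f'{datetime.now().strftime("%H-%M")}:cooldown_models'
utc_key = f'{datetime.now(timezone.utc).strftime("%H-%M")}:cooldown_models'
print(local_key, utc_key)
```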
@@ -1643,6 +1643,22 @@ class Router:
                 key=deployment, value=updated_fails, ttl=cooldown_time
             )

+    async def _async_get_cooldown_deployments(self):
+        """
+        Async implementation of '_get_cooldown_deployments'
+        """
+        current_minute = datetime.now(datetime_og.UTC).strftime("%H-%M")
+        # get the current cooldown list for that minute
+        cooldown_key = f"{current_minute}:cooldown_models"
+
+        # ----------------------
+        # Return cooldown models
+        # ----------------------
+        cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
+
+        verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
+        return cooldown_models
+
     def _get_cooldown_deployments(self):
         """
         Get the list of models being cooled down for this minute
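A self-contained sketch of the pattern the new method uses: compute a minute-bucketed key, then read it with an awaited cache call. `ToyAsyncCache` below is a hypothetical stand-in that only mimics the `async_get_cache(key=...)` shape of litellm's `DualCache`:

```python
import asyncio
from datetime import datetime, timezone

class ToyAsyncCache:
    """Hypothetical stand-in mimicking DualCache's async_get_cache(key=...) shape."""

    def __init__(self, store=None):
        self._store = store or {}

    async def async_get_cache(self, key):
        return self._store.get(key)

async def get_cooldown_deployments(cache):
    # same shape as _async_get_cooldown_deployments: minute bucket + awaited read
    current_minute = datetime.now(timezone.utc).strftime("%H-%M")
    return await cache.async_get_cache(key=f"{current_minute}:cooldown_models") or []

minute = datetime.now(timezone.utc).strftime("%H-%M")
cache = ToyAsyncCache({f"{minute}:cooldown_models": ["model-id-1"]})
print(asyncio.run(get_cooldown_deployments(cache)))  # ['model-id-1']
```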
@@ -2405,7 +2421,7 @@ class Router:

         return _returned_deployments

-    def get_available_deployment(
+    def _common_checks_available_deployment(
         self,
         model: str,
         messages: Optional[List[Dict[str, str]]] = None,
@@ -2413,10 +2429,8 @@ class Router:
         specific_deployment: Optional[bool] = False,
     ):
         """
-        Returns the deployment based on routing strategy
+        Common checks for 'get_available_deployment' across sync + async call.
         """
-        # users need to explicitly call a specific deployment, by setting `specific_deployment = True` as completion()/embedding() kwarg
-        # When this was no explicit we had several issues with fallbacks timing out
         if specific_deployment == True:
             # users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
             for deployment in self.model_list:
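The rename above is the first half of the refactor: the CPU-only validation that used to open `get_available_deployment` becomes a plain sync helper, so the sync and async selection paths can share it and differ only in their I/O. A hypothetical miniature of that shape:

```python
class PickerSketch:
    """Hypothetical miniature of the sync/async split around one shared helper."""

    def _common_checks(self, model: str) -> str:
        # pure validation, no I/O: safe to call from sync and async code alike
        if not model:
            raise ValueError("model is required")
        return model

    def pick(self, model: str) -> str:
        model = self._common_checks(model)
        return f"sync selection for {model}"

    async def async_pick(self, model: str) -> str:
        model = self._common_checks(model)  # only the steps after this are awaited
        return f"async selection for {model}"
```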
@@ -2456,6 +2470,111 @@ class Router:
             f"initial list of deployments: {healthy_deployments}"
         )

+        verbose_router_logger.debug(
+            f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}"
+        )
+        if len(healthy_deployments) == 0:
+            raise ValueError(f"No healthy deployment available, passed model={model}")
+        if litellm.model_alias_map and model in litellm.model_alias_map:
+            model = litellm.model_alias_map[
+                model
+            ]  # update the model to the actual value if an alias has been passed in
+
+        return model, healthy_deployments
+
+    async def async_get_available_deployment(
+        self,
+        model: str,
+        messages: Optional[List[Dict[str, str]]] = None,
+        input: Optional[Union[str, List]] = None,
+        specific_deployment: Optional[bool] = False,
+    ):
+        """
+        Async implementation of 'get_available_deployments'.
+
+        Allows all cache calls to be made async => 10x perf impact (8rps -> 100 rps).
+        """
+        if (
+            self.routing_strategy != "usage-based-routing-v2"
+        ):  # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
+            return self.get_available_deployment(
+                model=model,
+                messages=messages,
+                input=input,
+                specific_deployment=specific_deployment,
+            )
+        model, healthy_deployments = self._common_checks_available_deployment(
+            model=model,
+            messages=messages,
+            input=input,
+            specific_deployment=specific_deployment,
+        )
+
+        # filter out the deployments currently cooling down
+        deployments_to_remove = []
+        # cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
+        cooldown_deployments = await self._async_get_cooldown_deployments()
+        verbose_router_logger.debug(
+            f"async cooldown deployments: {cooldown_deployments}"
+        )
+        # Find deployments in model_list whose model_id is cooling down
+        for deployment in healthy_deployments:
+            deployment_id = deployment["model_info"]["id"]
+            if deployment_id in cooldown_deployments:
+                deployments_to_remove.append(deployment)
+        # remove unhealthy deployments from healthy deployments
+        for deployment in deployments_to_remove:
+            healthy_deployments.remove(deployment)
+
+        # filter pre-call checks
+        if self.enable_pre_call_checks and messages is not None:
+            healthy_deployments = self._pre_call_checks(
+                model=model, healthy_deployments=healthy_deployments, messages=messages
+            )
+
+        if (
+            self.routing_strategy == "usage-based-routing-v2"
+            and self.lowesttpm_logger_v2 is not None
+        ):
+            deployment = await self.lowesttpm_logger_v2.async_get_available_deployments(
+                model_group=model,
+                healthy_deployments=healthy_deployments,
+                messages=messages,
+                input=input,
+            )
+
+        if deployment is None:
+            verbose_router_logger.info(
+                f"get_available_deployment for model: {model}, No deployment available"
+            )
+            raise ValueError(
+                f"No deployments available for selected model, passed model={model}"
+            )
+        verbose_router_logger.info(
+            f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
+        )
+        return deployment
+
+    def get_available_deployment(
+        self,
+        model: str,
+        messages: Optional[List[Dict[str, str]]] = None,
+        input: Optional[Union[str, List]] = None,
+        specific_deployment: Optional[bool] = False,
+    ):
+        """
+        Returns the deployment based on routing strategy
+        """
+        # users need to explicitly call a specific deployment, by setting `specific_deployment = True` as completion()/embedding() kwarg
+        # When this was no explicit we had several issues with fallbacks timing out
+
+        model, healthy_deployments = self._common_checks_available_deployment(
+            model=model,
+            messages=messages,
+            input=input,
+            specific_deployment=specific_deployment,
+        )
+
         # filter out the deployments currently cooling down
         deployments_to_remove = []
         # cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
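A hedged usage sketch of the new entry point (the `model_list` below is illustrative, not from the commit): per the guard at the top of the method, any strategy other than "usage-based-routing-v2" falls back to the sync `get_available_deployment`.

```python
import asyncio
from litellm import Router

async def main():
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
        routing_strategy="usage-based-routing-v2",
    )
    # awaited end to end: cooldown lookup and TPM/RPM cache reads stay off the event loop
    deployment = await router.async_get_available_deployment(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(deployment["model_info"]["id"])

asyncio.run(main())
```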
@@ -2476,16 +2595,6 @@ class Router:
                 model=model, healthy_deployments=healthy_deployments, messages=messages
             )

-        verbose_router_logger.debug(
-            f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}"
-        )
-        if len(healthy_deployments) == 0:
-            raise ValueError(f"No healthy deployment available, passed model={model}")
-        if litellm.model_alias_map and model in litellm.model_alias_map:
-            model = litellm.model_alias_map[
-                model
-            ]  # update the model to the actual value if an alias has been passed in
-
         if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
             deployment = self.leastbusy_logger.get_available_deployments(
                 model_group=model, healthy_deployments=healthy_deployments