Merge pull request #3153 from BerriAI/litellm_usage_based_routing_v2_improvements
usage based routing v2 improvements - unit testing + *NEW* async + sync 'pre_call_checks'
Commit f1340b52dc
7 changed files with 723 additions and 43 deletions
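In short: this change threads deployment-level pre-call checks through the router. Sync endpoints (completion, image_generation, embedding) call routing_strategy_pre_call_checks() before dispatch, and async endpoints run async_routing_strategy_pre_call_checks() inside a per-deployment rpm semaphore, so a deployment already at its rpm limit is rejected before the provider call. A minimal sketch of the setup this targets, assuming a single OpenAI-style deployment; the api_key and rpm values are placeholders:

import asyncio
from litellm import Router

# One deployment with an rpm cap; "usage-based-routing-v2" is the strategy
# whose pre-call checks this commit wires in.
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": "sk-...",  # placeholder
                "rpm": 60,            # requests-per-minute cap for this deployment
            },
        }
    ],
    routing_strategy="usage-based-routing-v2",
)

async def main():
    # acompletion now runs the async pre-call checks before the underlying
    # call; a deployment over its rpm limit raises instead of being used.
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(response)

asyncio.run(main())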
|
litellm/router.py

@@ -31,6 +31,7 @@ import copy
 from litellm._logging import verbose_router_logger
 import logging
 from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors
+from litellm.integrations.custom_logger import CustomLogger


 class Router:
@@ -379,6 +380,9 @@ class Router:
             else:
                 model_client = potential_model_client

+            ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
+            self.routing_strategy_pre_call_checks(deployment=deployment)
+
             response = litellm.completion(
                 **{
                     **data,
@@ -391,6 +395,7 @@ class Router:
             verbose_router_logger.info(
                 f"litellm.completion(model={model_name})\033[32m 200 OK\033[0m"
             )
+
             return response
         except Exception as e:
             verbose_router_logger.info(
@@ -494,18 +499,20 @@ class Router:
                 deployment=deployment, kwargs=kwargs, client_type="rpm_client"
             )

-            if (
-                rpm_semaphore is not None
-                and isinstance(rpm_semaphore, asyncio.Semaphore)
-                and self.routing_strategy == "usage-based-routing-v2"
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
             ):
                 async with rpm_semaphore:
                     """
                     - Check rpm limits before making the call
                     - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                     """
-                    await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment)
+                    await self.async_routing_strategy_pre_call_checks(
+                        deployment=deployment
+                    )
                     response = await _response
             else:
+                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
                 response = await _response

             self.success_calls[model_name] += 1
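The hunk above is the heart of the change: the rpm check-and-increment now happens inside the per-deployment semaphore, so it is atomic with respect to concurrent requests. Reduced to a standalone asyncio sketch (Deployment, check_and_increment_rpm, and guarded_call are hypothetical names, not litellm APIs):

import asyncio

class RateLimitError(Exception):
    pass

class Deployment:
    def __init__(self, rpm_limit: int):
        self.rpm_limit = rpm_limit
        self.rpm_used = 0
        # One semaphore per deployment serializes the check-and-increment,
        # mirroring the per-deployment "rpm_client" semaphore in the diff.
        self.semaphore = asyncio.Semaphore(1)

async def check_and_increment_rpm(d: Deployment) -> None:
    # Check and increment while holding the semaphore, so two concurrent
    # requests can never both observe "one slot left" and both proceed.
    if d.rpm_used >= d.rpm_limit:
        raise RateLimitError("deployment over its rpm limit")
    d.rpm_used += 1

async def guarded_call(d: Deployment, coro):
    async with d.semaphore:
        await check_and_increment_rpm(d)
        return await coro

Without the semaphore, two coroutines could read the same counter value and both slip past the limit; the else-branch in the diff (no semaphore configured) still runs the check, just without that atomicity guarantee.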
@@ -580,6 +587,10 @@ class Router:
             model_client = potential_model_client

             self.total_calls[model_name] += 1
+
+            ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
+            self.routing_strategy_pre_call_checks(deployment=deployment)
+
             response = litellm.image_generation(
                 **{
                     **data,
@@ -658,7 +669,7 @@ class Router:
             model_client = potential_model_client

             self.total_calls[model_name] += 1
-            response = await litellm.aimage_generation(
+            response = litellm.aimage_generation(
                 **{
                     **data,
                     "prompt": prompt,
@@ -667,6 +678,28 @@ class Router:
                     **kwargs,
                 }
             )

+            ### CONCURRENCY-SAFE RPM CHECKS ###
+            rpm_semaphore = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+            )
+
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+                    """
+                    await self.async_routing_strategy_pre_call_checks(
+                        deployment=deployment
+                    )
+                    response = await response
+            else:
+                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                response = await response
+
             self.success_calls[model_name] += 1
             verbose_router_logger.info(
                 f"litellm.aimage_generation(model={model_name})\033[32m 200 OK\033[0m"
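Each async endpoint fetches its guard with self._get_client(deployment=..., kwargs=..., client_type="rpm_client") and falls back to the unguarded else-branch when that returns None. One plausible shape for the lookup, sketched with hypothetical names (get_rpm_semaphore and the cache are illustrative, not litellm's actual _get_client internals):

import asyncio
from typing import Dict, Optional

# Cache of one semaphore per deployment id; created lazily, only for
# deployments that actually configure an rpm limit.
_rpm_semaphores: Dict[str, asyncio.Semaphore] = {}

def get_rpm_semaphore(deployment: dict) -> Optional[asyncio.Semaphore]:
    if deployment.get("litellm_params", {}).get("rpm") is None:
        return None  # no limit configured -> caller takes the unguarded path
    model_id = deployment["model_info"]["id"]  # assumed deployment dict shape
    if model_id not in _rpm_semaphores:
        _rpm_semaphores[model_id] = asyncio.Semaphore(1)
    return _rpm_semaphores[model_id]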
@@ -758,7 +791,7 @@ class Router:
             model_client = potential_model_client

             self.total_calls[model_name] += 1
-            response = await litellm.atranscription(
+            response = litellm.atranscription(
                 **{
                     **data,
                     "file": file,
@@ -767,6 +800,28 @@ class Router:
                     **kwargs,
                 }
             )

+            ### CONCURRENCY-SAFE RPM CHECKS ###
+            rpm_semaphore = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+            )
+
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+                    """
+                    await self.async_routing_strategy_pre_call_checks(
+                        deployment=deployment
+                    )
+                    response = await response
+            else:
+                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                response = await response
+
             self.success_calls[model_name] += 1
             verbose_router_logger.info(
                 f"litellm.atranscription(model={model_name})\033[32m 200 OK\033[0m"
@@ -981,7 +1036,8 @@ class Router:
             else:
                 model_client = potential_model_client
             self.total_calls[model_name] += 1
-            response = await litellm.atext_completion(
+
+            response = litellm.atext_completion(
                 **{
                     **data,
                     "prompt": prompt,
@@ -991,6 +1047,27 @@ class Router:
                     **kwargs,
                 }
             )

+            rpm_semaphore = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+            )
+
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+                    """
+                    await self.async_routing_strategy_pre_call_checks(
+                        deployment=deployment
+                    )
+                    response = await response
+            else:
+                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                response = await response
+
             self.success_calls[model_name] += 1
             verbose_router_logger.info(
                 f"litellm.atext_completion(model={model_name})\033[32m 200 OK\033[0m"
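Note the pattern shared by every -await/+ pair in this diff: the async call is no longer awaited at the call site, so response first holds an unawaited coroutine; it is awaited only after the pre-call checks pass, whether inside or outside the semaphore. In miniature (provider() is a hypothetical stand-in for litellm.atext_completion and friends):

import asyncio

async def provider() -> str:
    # hypothetical stand-in for the underlying litellm async call
    await asyncio.sleep(0)
    return "ok"

async def main():
    response = provider()      # coroutine created, nothing sent yet
    # ... pre-call checks run here; raising skips the provider entirely ...
    response = await response  # only now does the request actually fire
    print(response)

asyncio.run(main())

One subtlety: if a check raises, the created coroutine is discarded unawaited, which Python reports with a "coroutine was never awaited" RuntimeWarning.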
@@ -1065,6 +1142,10 @@ class Router:
             model_client = potential_model_client

             self.total_calls[model_name] += 1
+
+            ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
+            self.routing_strategy_pre_call_checks(deployment=deployment)
+
             response = litellm.embedding(
                 **{
                     **data,
@@ -1150,7 +1231,7 @@ class Router:
             model_client = potential_model_client

             self.total_calls[model_name] += 1
-            response = await litellm.aembedding(
+            response = litellm.aembedding(
                 **{
                     **data,
                     "input": input,
@@ -1159,6 +1240,28 @@ class Router:
                     **kwargs,
                 }
             )

+            ### CONCURRENCY-SAFE RPM CHECKS ###
+            rpm_semaphore = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+            )
+
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+                    """
+                    await self.async_routing_strategy_pre_call_checks(
+                        deployment=deployment
+                    )
+                    response = await response
+            else:
+                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                response = await response
+
             self.success_calls[model_name] += 1
             verbose_router_logger.info(
                 f"litellm.aembedding(model={model_name})\033[32m 200 OK\033[0m"
@@ -1716,6 +1819,38 @@ class Router:
         verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models

+    def routing_strategy_pre_call_checks(self, deployment: dict):
+        """
+        Mimics 'async_routing_strategy_pre_call_checks'
+
+        Ensures consistent update rpm implementation for 'usage-based-routing-v2'
+
+        Returns:
+        - None
+
+        Raises:
+        - Rate Limit Exception - If the deployment is over its tpm/rpm limits
+        """
+        for _callback in litellm.callbacks:
+            if isinstance(_callback, CustomLogger):
+                response = _callback.pre_call_check(deployment)
+
+    async def async_routing_strategy_pre_call_checks(self, deployment: dict):
+        """
+        For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
+
+        -> makes the calls concurrency-safe, when rpm limits are set for a deployment
+
+        Returns:
+        - None
+
+        Raises:
+        - Rate Limit Exception - If the deployment is over its tpm/rpm limits
+        """
+        for _callback in litellm.callbacks:
+            if isinstance(_callback, CustomLogger):
+                response = await _callback.async_pre_call_check(deployment)
+
     def set_client(self, model: dict):
         """
         - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
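Both new methods fan out to every CustomLogger registered on litellm.callbacks, which is what lets the usage-based-routing-v2 logger veto a deployment. A hedged sketch of such a hook; the counter logic is a toy stand-in for the real tpm/rpm accounting:

import litellm
from litellm.integrations.custom_logger import CustomLogger

class RpmGuard(CustomLogger):
    """Toy pre-call hook: rejects a deployment once a counter passes its cap."""

    def __init__(self, rpm_limit: int):
        self.rpm_limit = rpm_limit
        self.calls = 0  # toy counter; real accounting would use a windowed cache

    def pre_call_check(self, deployment: dict):
        # sync hook, reached via routing_strategy_pre_call_checks()
        if self.calls >= self.rpm_limit:
            raise Exception(
                f"deployment {deployment.get('model_name')} over its rpm limit"
            )
        self.calls += 1

    async def async_pre_call_check(self, deployment: dict):
        # async hook, reached inside the rpm semaphore by the async endpoints
        self.pre_call_check(deployment)

litellm.callbacks = [RpmGuard(rpm_limit=60)]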
@@ -2704,6 +2839,7 @@ class Router:
             verbose_router_logger.info(
                 f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
             )
+
             return deployment

     def get_available_deployment(