Merge pull request #3153 from BerriAI/litellm_usage_based_routing_v2_improvements

usage based routing v2 improvements - unit testing + *NEW* async + sync 'pre_call_checks'
Krish Dholakia 2024-04-18 22:16:16 -07:00 committed by GitHub
commit f1340b52dc
7 changed files with 723 additions and 43 deletions
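
The new hooks fire whenever a Router configured with 'usage-based-routing-v2' handles a request: the sync entrypoints now call routing_strategy_pre_call_checks() before the underlying litellm call, and the async entrypoints await async_routing_strategy_pre_call_checks() inside the deployment's rpm semaphore. A minimal sketch of that flow from the caller's side; the model name, key, and rpm value are placeholders, not part of this commit:

    import asyncio
    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",  # placeholder deployment
                    "api_key": "sk-...",           # placeholder key
                    "rpm": 10,                     # per-deployment rpm limit
                },
            }
        ],
        routing_strategy="usage-based-routing-v2",
    )

    # Sync path: pre-call checks run before litellm.completion(); a deployment
    # that is over its limit raises instead of being called.
    response = router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )

    # Async path: the same check runs while the deployment's rpm semaphore is held.
    async def main():
        return await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hello"}],
        )

    asyncio.run(main())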


@@ -31,6 +31,7 @@ import copy
from litellm._logging import verbose_router_logger
import logging
from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors
from litellm.integrations.custom_logger import CustomLogger
class Router:
@@ -379,6 +380,9 @@ class Router:
else:
model_client = potential_model_client
### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call; raise an error if the deployment is over its limit)
self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.completion(
**{
**data,
@@ -391,6 +395,7 @@ class Router:
verbose_router_logger.info(
f"litellm.completion(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
@@ -494,18 +499,20 @@ class Router:
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
)
if (
rpm_semaphore is not None
and isinstance(rpm_semaphore, asyncio.Semaphore)
and self.routing_strategy == "usage-based-routing-v2"
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm counter (the shared value is updated in a concurrency-safe way)
"""
await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment)
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await _response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await _response
self.success_calls[model_name] += 1
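
The pattern above (build the coroutine, acquire the deployment's rpm semaphore, run the pre-call check, then await the response) is what makes the check-and-increment atomic across concurrent requests. A standalone sketch of the same idea, independent of litellm internals; the limit, counter, and call below are illustrative:

    import asyncio

    RPM_LIMIT = 2                          # illustrative per-deployment limit

    async def fake_llm_call(i: int) -> str:
        await asyncio.sleep(0.01)          # stand-in for the real network call
        return f"request {i}: 200 OK"

    async def call_deployment(i: int, sem: asyncio.Semaphore, usage: dict) -> str:
        _response = fake_llm_call(i)       # coroutine object only; nothing runs yet
        async with sem:
            # the check-and-increment is atomic: no other task can read or bump
            # usage["rpm"] while this task holds the semaphore
            if usage["rpm"] >= RPM_LIMIT:
                _response.close()          # discard the never-awaited coroutine cleanly
                raise RuntimeError(f"request {i}: deployment over its rpm limit")
            usage["rpm"] += 1
            return await _response         # the call only starts executing here

    async def main():
        sem = asyncio.Semaphore(RPM_LIMIT) # mirrors the per-deployment rpm_semaphore
        usage = {"rpm": 0}                 # mirrors the shared rpm counter
        results = await asyncio.gather(
            *(call_deployment(i, sem, usage) for i in range(4)),
            return_exceptions=True,
        )
        for r in results:
            print(r)                       # two 200 OKs, two rate-limit errors

    asyncio.run(main())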
@@ -580,6 +587,10 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call; raise an error if the deployment is over its limit)
self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.image_generation(
**{
**data,
@@ -658,7 +669,7 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = await litellm.aimage_generation(
response = litellm.aimage_generation(
**{
**data,
"prompt": prompt,
@@ -667,6 +678,28 @@ class Router:
**kwargs,
}
)
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm counter (the shared value is updated in a concurrency-safe way)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.aimage_generation(model={model_name})\033[32m 200 OK\033[0m"
@@ -758,7 +791,7 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = await litellm.atranscription(
response = litellm.atranscription(
**{
**data,
"file": file,
@@ -767,6 +800,28 @@ class Router:
**kwargs,
}
)
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm counter (the shared value is updated in a concurrency-safe way)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.atranscription(model={model_name})\033[32m 200 OK\033[0m"
@@ -981,7 +1036,8 @@ class Router:
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = await litellm.atext_completion(
response = litellm.atext_completion(
**{
**data,
"prompt": prompt,
@@ -991,6 +1047,27 @@ class Router:
**kwargs,
}
)
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm counter (the shared value is updated in a concurrency-safe way)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.atext_completion(model={model_name})\033[32m 200 OK\033[0m"
@@ -1065,6 +1142,10 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call; raise an error if the deployment is over its limit)
self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.embedding(
**{
**data,
@@ -1150,7 +1231,7 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = await litellm.aembedding(
response = litellm.aembedding(
**{
**data,
"input": input,
@@ -1159,6 +1240,28 @@ class Router:
**kwargs,
}
)
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm counter (the shared value is updated in a concurrency-safe way)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.aembedding(model={model_name})\033[32m 200 OK\033[0m"
@@ -1716,6 +1819,38 @@ class Router:
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
def routing_strategy_pre_call_checks(self, deployment: dict):
"""
Mimics 'async_routing_strategy_pre_call_checks'
Ensures a consistent rpm-update implementation for 'usage-based-routing-v2'
Returns:
- None
Raises:
- Rate Limit Exception - If the deployment is over its tpm/rpm limits
"""
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger):
response = _callback.pre_call_check(deployment)
async def async_routing_strategy_pre_call_checks(self, deployment: dict):
"""
For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
-> makes the calls concurrency-safe when rpm limits are set for a deployment
Returns:
- None
Raises:
- Rate Limit Exception - If the deployment is over its tpm/rpm limits
"""
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger):
response = await _callback.async_pre_call_check(deployment)
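
Both loops dispatch to any CustomLogger registered in litellm.callbacks, using the pre_call_check / async_pre_call_check hook names called above. A minimal sketch of a callback those loops would pick up; the rpm bookkeeping here is illustrative, not the logic of litellm's usage-based-routing-v2 logger:

    from collections import defaultdict

    import litellm
    from litellm.integrations.custom_logger import CustomLogger

    class MyRPMGuard(CustomLogger):
        """Illustrative callback with stand-in rpm bookkeeping."""

        def __init__(self):
            super().__init__()
            self._rpm_counts = defaultdict(int)   # deployment id -> calls seen

        def pre_call_check(self, deployment: dict):
            # called by Router.routing_strategy_pre_call_checks (sync path)
            self._check_and_increment(deployment)

        async def async_pre_call_check(self, deployment: dict):
            # called by Router.async_routing_strategy_pre_call_checks while the
            # deployment's rpm semaphore is held (async path)
            self._check_and_increment(deployment)

        def _check_and_increment(self, deployment: dict):
            _id = deployment.get("model_info", {}).get("id", "unknown")
            limit = deployment.get("litellm_params", {}).get("rpm")
            if limit is not None and self._rpm_counts[_id] >= limit:
                # the real v2 logger raises a rate-limit error at this point
                raise RuntimeError(f"deployment {_id} is over its rpm limit")
            self._rpm_counts[_id] += 1

    litellm.callbacks = [MyRPMGuard()]            # picked up by the loops above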
def set_client(self, model: dict):
"""
- Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
@@ -2704,6 +2839,7 @@ class Router:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
)
return deployment
def get_available_deployment(